Skip to content

Commit

Permalink
Merge pull request #239 from lxvm/forward
Browse files Browse the repository at this point in the history
ForwardDiff directly on all non-C solvers
  • Loading branch information
ChrisRackauckas authored Feb 22, 2024
2 parents 12c635f + e7192fb commit 4593763
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
16 changes: 12 additions & 4 deletions docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The in-place interface allows evaluating vector-valued integrands without
allocating an output array. This can be beneficial for reducing allocations when
integrating many functions simultaneously or to make use of existing in-place
code. However, note that not all algorithms use in-place operations under the
hood, i.e. `HCubatureJL()`, and may still allocate.
hood, i.e. [`HCubatureJL`](@ref), and may still allocate.

You can construct an `IntegralFunction(f, prototype)`, where `f` is of the form
`f(y, u, p)` where `prototype` is of the desired type and shape of `y`.
Expand All @@ -22,16 +22,17 @@ different points, which maximizes the parallelism for a given algorithm.
You can construct an out-of-place `BatchIntegralFunction(bf)` where `bf` is of
the form `bf(u, p) = stack(x -> f(x, p), eachslice(u; dims=ndims(u)))`, where
`f` is the (unbatched) integrand.
For interoperability with as many algorithms as possible, it is important that your out-of-place batch integrand accept an **empty** array of quadrature points and still return an output with a size and type consistent with the non-empty case.

You can construct an in-place `BatchIntegralFunction(bf, prototype)`, where `bf`
is of the form `bf(y, u, p) = foreach((y,x) -> f(y,x,p), eachslice(y, dims=ndims(y)), eachslice(x, dims=ndims(x)))`.

Note that not all algorithms use in-place batched operations under the hood,
i.e. `QuadGKJL()`.
i.e. [`QuadGKJL`](@ref).

## What should I do if my solution is not converged?

Certain algorithms, such as `QuadratureRule` used a fixed number of points to
Certain algorithms, such as [`QuadratureRule`](@ref) used a fixed number of points to
calculate an integral and cannot provide an error estimate. In this case, you
have to increase the number of points and check the convergence yourself, which
will depend on the accuracy of the rule you choose.
Expand All @@ -47,7 +48,7 @@ precision arithmetic may help.

## How can I integrate arbitrarily-spaced data?

See `SampledIntegralProblem`.
See [`SampledIntegralProblem`](@ref).

## How can I integrate on arbitrary geometries?

Expand All @@ -59,6 +60,13 @@ because that is what lower-level packages implement.
Fixed quadrature rules from other packages can be used with `QuadratureRule`.
Otherwise, feel free to open an issue or pull request.

## My integrand works with algorithm X but fails on algorithm Y

While bugs are not out of the question, certain algorithms, especially those implemented in C, are not compatible with arbitrary Julia types and have to return specific numeric types or arrays thereof.
In some cases, such as [`ArblibJL`](@ref), it is also expected that the integrand work with a custom quadrature point type.
Moreover, some algorithms, such as [`VEGAS`](@ref), only support scalar integrands.
For more details see the [solver page](@ref solvers).

## Can I take derivatives with respect to the limits of integration?

Currently this is not implemented.
41 changes: 26 additions & 15 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,37 @@ using Integrals
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

#= Direct AD on solvers with QuadGK and HCubature
# incompatible with iip since types must change
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
# Default to direct AD on solvers
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, domain,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
if isinplace(cache.f)
prototype = cache.f.integrand_prototype
elt = eltype(prototype)
ForwardDiff.can_dual(elt) || throw(ArgumentError("ForwardDiff of in-place integrands only supports prototypes with real elements"))
dprototype = similar(prototype, replace_dualvaltype(D, elt))
df = if cache.f isa BatchIntegralFunction
BatchIntegralFunction{true}(cache.f.f, dprototype)
else
IntegralFunction{true}(cache.f.f, dprototype)
end
prob = Integrals.build_problem(cache)
dprob = remake(prob, f = df)
dcache = init(dprob, alg; sensealg = sensealg, do_inf_transformation=Val(false), kwargs...)
Integrals.__solvebp_call(dcache, alg, sensealg, domain, p; kwargs...)
else
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
end
=#


# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)

# Manually split for the pushforward
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::Union{D, AbstractArray{<:D}};
kwargs...) where {T, V, P, D <: ForwardDiff.Dual{T, V, P}}
function Integrals.__solvebp(cache, alg::Integrals.AbstractIntegralCExtensionAlgorithm, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
Expand Down Expand Up @@ -73,6 +83,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
end
end

DT <: Real || throw(ArgumentError("differentiating algorithms in C"))
ForwardDiff.can_dual(elt) || ForwardDiff.throw_cannot_dual(elt)
rawp = p isa D ? reinterpret(V, [p]) : copy(reinterpret(V, vec(p)))

Expand Down
7 changes: 4 additions & 3 deletions src/algorithms_extension.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
## Extension Algorithms

abstract type AbstractIntegralExtensionAlgorithm <: SciMLBase.AbstractIntegralAlgorithm end
abstract type AbstractIntegralCExtensionAlgorithm <: AbstractIntegralExtensionAlgorithm end

abstract type AbstractCubaAlgorithm <: AbstractIntegralExtensionAlgorithm end
abstract type AbstractCubaAlgorithm <: AbstractIntegralCExtensionAlgorithm end

"""
CubaVegas()
Expand Down Expand Up @@ -152,7 +153,7 @@ function CubaCuhre(; flags = 0, minevals = 0, key = 0)
return CubaCuhre(flags, minevals, key)
end

abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralExtensionAlgorithm end
abstract type AbstractCubatureJLAlgorithm <: AbstractIntegralCExtensionAlgorithm end

"""
CubatureJLh(; error_norm=Cubature.INDIVIDUAL)
Expand Down Expand Up @@ -219,7 +220,7 @@ documentation for additional details the algorithm arguments and on implementing
high-precision integrands. Additionally, the error estimate is included in the return value
of the integral, representing a ball.
"""
struct ArblibJL{O} <: AbstractIntegralExtensionAlgorithm
struct ArblibJL{O} <: AbstractIntegralCExtensionAlgorithm
check_analytic::Bool
take_prec::Bool
warn_on_no_convergence::Bool
Expand Down

0 comments on commit 4593763

Please sign in to comment.