Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 22, 2024
1 parent 4593763 commit 8313934
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)

# Default to direct AD on solvers
function Integrals.__solvebp(cache, alg, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}

p::Union{D, AbstractArray{<:D}};
kwargs...) where {T, V, P, D <: ForwardDiff.Dual{T, V, P}}
if isinplace(cache.f)
prototype = cache.f.integrand_prototype
elt = eltype(prototype)
ForwardDiff.can_dual(elt) || throw(ArgumentError("ForwardDiff of in-place integrands only supports prototypes with real elements"))
ForwardDiff.can_dual(elt) ||
throw(ArgumentError("ForwardDiff of in-place integrands only supports prototypes with real elements"))
dprototype = similar(prototype, replace_dualvaltype(D, elt))
df = if cache.f isa BatchIntegralFunction
BatchIntegralFunction{true}(cache.f.f, dprototype)
Expand All @@ -20,20 +20,21 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
end
prob = Integrals.build_problem(cache)
dprob = remake(prob, f = df)
dcache = init(dprob, alg; sensealg = sensealg, do_inf_transformation=Val(false), kwargs...)
dcache = init(
dprob, alg; sensealg = sensealg, do_inf_transformation = Val(false), kwargs...)
Integrals.__solvebp_call(dcache, alg, sensealg, domain, p; kwargs...)
else
Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...)
end
end


# TODO: add the pushforward for derivative w.r.t lb, and ub (and then combinations?)

# Manually split for the pushforward
function Integrals.__solvebp(cache, alg::Integrals.AbstractIntegralCExtensionAlgorithm, sensealg, domain,
p::Union{D,AbstractArray{<:D}};
kwargs...) where {T, V, P, D<:ForwardDiff.Dual{T, V, P}}
function Integrals.__solvebp(
cache, alg::Integrals.AbstractIntegralCExtensionAlgorithm, sensealg, domain,
p::Union{D, AbstractArray{<:D}};
kwargs...) where {T, V, P, D <: ForwardDiff.Dual{T, V, P}}

# we need the output type to avoid perturbation confusion while unwrapping nested duals
# We compute a vector-valued integral of the primal and dual simultaneously
Expand Down

0 comments on commit 8313934

Please sign in to comment.