Skip to content

Commit

Permalink
another calc_tderivative
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Nov 27, 2024
1 parent f10b234 commit 80f9a2d
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
alg = unwrap_alg(integrator, true)
#derivative!(dT, tf, t, du2, integrator, cache.grad_config)
autodiff_alg = alg_autodiff(alg)

autodiff_alg = if autodiff_alg isa AutoSparse
ADTypes.dense_ad(autodiff_alg)
else
Expand All @@ -68,7 +68,15 @@ function calc_tderivative(integrator, cache)
tf = cache.tf
tf.u = uprev
tf.p = p
dT = DI.derivative(tf, alg_autodiff(alg), t)

autodiff_alg = alg_autodiff(alg)
autodiff_alg = if autodiff_alg isa AutoSparse
autodiff_alg = ADTypes.dense_ad(autodiff_alg)
else
autodiff_alg
end

dT = DI.derivative(tf, autodiff_alg, t)
end
dT
end
Expand Down

0 comments on commit 80f9a2d

Please sign in to comment.