diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 0adc8fb5ce..dc6ca70d6d 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -42,7 +42,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step) alg = unwrap_alg(integrator, true) #derivative!(dT, tf, t, du2, integrator, cache.grad_config) autodiff_alg = alg_autodiff(alg) - + autodiff_alg = if autodiff_alg isa AutoSparse ADTypes.dense_ad(autodiff_alg) else @@ -68,7 +68,15 @@ function calc_tderivative(integrator, cache) tf = cache.tf tf.u = uprev tf.p = p - dT = DI.derivative(tf, alg_autodiff(alg), t) + + autodiff_alg = alg_autodiff(alg) + autodiff_alg = if autodiff_alg isa AutoSparse + autodiff_alg = ADTypes.dense_ad(autodiff_alg) + else + autodiff_alg + end + + dT = DI.derivative(tf, autodiff_alg, t) end dT end