From 80f9a2d1bfbb7d50f8629ccb7df679c95dfaa41b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 27 Nov 2024 10:59:51 -0500 Subject: [PATCH] another calc_tderivative --- .../src/derivative_utils.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 0adc8fb5ce..dc6ca70d6d 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -42,7 +42,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step) alg = unwrap_alg(integrator, true) #derivative!(dT, tf, t, du2, integrator, cache.grad_config) autodiff_alg = alg_autodiff(alg) - + autodiff_alg = if autodiff_alg isa AutoSparse ADTypes.dense_ad(autodiff_alg) else @@ -68,7 +68,15 @@ function calc_tderivative(integrator, cache) tf = cache.tf tf.u = uprev tf.p = p - dT = DI.derivative(tf, alg_autodiff(alg), t) + + autodiff_alg = alg_autodiff(alg) + autodiff_alg = if autodiff_alg isa AutoSparse + autodiff_alg = ADTypes.dense_ad(autodiff_alg) + else + autodiff_alg + end + + dT = DI.derivative(tf, autodiff_alg, t) end dT end