diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4abb345822..726d171bd0 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -260,7 +260,16 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newp = remake_buffer( oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) end - initprob = remake(oldinitprob; u0 = newu0, p = newp) + if oldinitprob.f.resid_prototype === nothing + newf = oldinitprob.f + else + newf = NonlinearFunction{ + SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( + oldinitprob.f; + resid_prototype = calculate_resid_prototype( + length(oldinitprob.f.resid_prototype), newu0, newp)) + end + initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 94b44e0a6f..1289388197 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -283,6 +283,16 @@ function hessian_sparsity(sys::NonlinearSystem) unknowns(sys)) for eq in equations(sys)] end +function calculate_resid_prototype(N, u0, p) + u0ElType = u0 === nothing ? Float64 : eltype(u0) + if SciMLStructures.isscimlstructure(p) + u0ElType = promote_type( + eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), + u0ElType) + end + return zeros(u0ElType, N) +end + """ ```julia SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), @@ -337,13 +347,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s if length(dvs) == length(equations(sys)) resid_prototype = nothing else - u0ElType = u0 === nothing ? Float64 : eltype(u0) - if SciMLStructures.isscimlstructure(p) - u0ElType = promote_type( - eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), - u0ElType) - end - resid_prototype = zeros(u0ElType, length(equations(sys))) + resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p) end NonlinearFunction{iip}(f, diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 0b0dc42c1e..f3015f7db0 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1032,3 +1032,20 @@ end @test prob3.f.initialization_data !== nothing @test init(prob3)[x] ≈ 0.5 end + +@testset "Issue#3246: type promotion with parameter dependent initialization_eqs" begin + @variables x(t)=1 y(t)=1 + @parameters a = 1 + @named sys = ODESystem([D(x) ~ 0, D(y) ~ x + a], t; initialization_eqs = [y ~ a]) + + ssys = structural_simplify(sys) + prob = ODEProblem(ssys, [], (0, 1), []) + + @test SciMLBase.successful_retcode(solve(prob)) + + seta = setsym_oop(prob, [a]) + (newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1)) + newprob = remake(prob, u0 = newu0, p = newp) + + @test SciMLBase.successful_retcode(solve(newprob)) +end