Skip to content

Commit

Permalink
Merge pull request #3252 from AayushSabharwal/as/initprob-resid-proto…
Browse files Browse the repository at this point in the history
…type

fix: recalculate resid_prototype in remake_initialization_data
  • Loading branch information
ChrisRackauckas authored Dec 3, 2024
2 parents 4626fe7 + c578da1 commit 1f669d9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
11 changes: 10 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,16 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p,
newp = remake_buffer(
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
end
initprob = remake(oldinitprob; u0 = newu0, p = newp)
if oldinitprob.f.resid_prototype === nothing
newf = oldinitprob.f
else
newf = NonlinearFunction{
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
oldinitprob.f;
resid_prototype = calculate_resid_prototype(
length(oldinitprob.f.resid_prototype), newu0, newp))
end
initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp)
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
odefn.initializeprobmap, odefn.initializeprobpmap)
end
Expand Down
18 changes: 11 additions & 7 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,16 @@ function hessian_sparsity(sys::NonlinearSystem)
unknowns(sys)) for eq in equations(sys)]
end

function calculate_resid_prototype(N, u0, p)
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end
return zeros(u0ElType, N)
end

"""
```julia
SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -337,13 +347,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
if length(dvs) == length(equations(sys))
resid_prototype = nothing
else
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end
resid_prototype = zeros(u0ElType, length(equations(sys)))
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
end

NonlinearFunction{iip}(f,
Expand Down
17 changes: 17 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,3 +1032,20 @@ end
@test prob3.f.initialization_data !== nothing
@test init(prob3)[x] 0.5
end

@testset "Issue#3246: type promotion with parameter dependent initialization_eqs" begin
@variables x(t)=1 y(t)=1
@parameters a = 1
@named sys = ODESystem([D(x) ~ 0, D(y) ~ x + a], t; initialization_eqs = [y ~ a])

ssys = structural_simplify(sys)
prob = ODEProblem(ssys, [], (0, 1), [])

@test SciMLBase.successful_retcode(solve(prob))

seta = setsym_oop(prob, [a])
(newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1))
newprob = remake(prob, u0 = newu0, p = newp)

@test SciMLBase.successful_retcode(solve(newprob))
end

0 comments on commit 1f669d9

Please sign in to comment.