Skip to content

Commit

Permalink
Merge pull request #3216 from AayushSabharwal/as/fix-daeprob
Browse files Browse the repository at this point in the history
fix: fix DAEProblem with array parameters
  • Loading branch information
ChrisRackauckas authored Nov 19, 2024
2 parents 37254ac + 8dccf9d commit 57e1a43
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
10 changes: 3 additions & 7 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,11 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
pre, sol_states = get_substitutions_and_solved_unknowns(sys)

if implicit_dae
# inputs = [] makes `wrap_array_vars` offset by 1 since there is an extra
# argument
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps, inputs = []) .∘
wrap_parameter_dependencies(sys, false),
kwargs...)
else
Expand Down Expand Up @@ -790,12 +792,6 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

struct DiscreteSaveAffect{F, S} <: Function
f::F
s::S
end
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)

function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
Expand Down
3 changes: 1 addition & 2 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,7 @@ function process_SciMLProblem(
ddvs = map(Differential(iv), dvs)
du0map = to_varmap(du0map, ddvs)
merge!(op, du0map)

du0 = varmap_to_vars(du0map, ddvs; toterm = identity,
du0 = varmap_to_vars(op, ddvs; toterm = identity,
tofloat = true)
kwargs = merge(kwargs, (; ddvs))
else
Expand Down
9 changes: 9 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1515,3 +1515,12 @@ end
sol = solve(prob, Tsit5())
@test sol[obs] 1:7
end

@testset "DAEProblem with array parameters" begin
@variables x(t)=1.0 y(t) [guess = 1.0]
@parameters p[1:2] = [1.0, 2.0]
@mtkbuild sys = ODESystem([D(x) ~ x, y^2 ~ x + sum(p)], t)
prob = DAEProblem(sys, [D(x) => x, D(y) => D(x) / 2y], [], (0.0, 1.0))
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
@test sol[x]sol[y^2 - sum(p)] atol=1e-5
end

0 comments on commit 57e1a43

Please sign in to comment.