Skip to content

Commit

Permalink
Merge pull request #2861 from hersle/init_defaults_flipping
Browse files Browse the repository at this point in the history
Left-to-right expand observed equations into defaults during initialization
  • Loading branch information
ChrisRackauckas authored Jul 17, 2024
2 parents 6d8e86e + 8862b97 commit 2f5c718
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
14 changes: 10 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,10 +725,16 @@ function get_u0(
defs = mergedefaults(defs, parammap, ps)
end

obs = filter!(x -> !(x[1] isa Number),
map(x -> isparameter(x.rhs) ? x.lhs => x.rhs : x.rhs => x.lhs, observed(sys)))
observedmap = isempty(obs) ? Dict() : todict(obs)
defs = mergedefaults(defs, observedmap, u0map, dvs)
# Convert observed equations "lhs ~ rhs" into defaults.
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
# if "lhs" is known by other means (parameter, another default, ...)
# TODO: Is there a better way to determine which equations to flip?
obs = map(x -> x.lhs => x.rhs, observed(sys))
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
obsmap = isempty(obs) ? Dict() : todict(obs)

defs = mergedefaults(defs, obsmap, u0map, dvs)
if symbolic_u0
u0 = varmap_to_vars(
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
Expand Down
12 changes: 6 additions & 6 deletions test/guess_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,21 @@ prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
@variables y(t) = x0
@mtkbuild sys = ODESystem([x ~ x0, D(y) ~ x], t)
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
prob[x] == 1.0
prob[y] == 1.0
@test prob[x] == 1.0
@test prob[y] == 1.0

@parameters x0
@variables x(t)
@variables y(t) = x0
@mtkbuild sys = ODESystem([x ~ y, D(y) ~ x], t)
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
prob[x] == 1.0
prob[y] == 1.0
@test prob[x] == 1.0
@test prob[y] == 1.0

@parameters x0
@variables x(t) = x0
@variables y(t) = x
@mtkbuild sys = ODESystem([x ~ y, D(y) ~ x], t)
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
prob[x] == 1.0
prob[y] == 1.0
@test prob[x] == 1.0
@test prob[y] == 1.0
20 changes: 20 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1194,3 +1194,23 @@ end
@test_nowarn obsfn(buffer, [1.0], ps..., 3.0)
@test buffer [2.0, 3.0, 4.0]
end

# https://github.com/SciML/ModelingToolkit.jl/issues/2859
@testset "Initialization with defaults from observed equations (edge case)" begin
@variables x(t) y(t) z(t)
eqs = [D(x) ~ 0, y ~ x, D(z) ~ 0]
defaults = [x => 1, z => y]
@named sys = ODESystem(eqs, t; defaults)
ssys = structural_simplify(sys)
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
@test prob[x] == prob[y] == prob[z] == 1.0

@parameters y0
@variables x(t) y(t) z(t)
eqs = [D(x) ~ 0, y ~ y0 / x, D(z) ~ y]
defaults = [y0 => 1, x => 1, z => y]
@named sys = ODESystem(eqs, t; defaults)
ssys = structural_simplify(sys)
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
@test prob[x] == prob[y] == prob[z] == 1.0
end

0 comments on commit 2f5c718

Please sign in to comment.