Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: propagate ODEProblem guesses to remake #3226

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.57.1"
SciMLBase = "2.64"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(
generate_initializesystem(
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
else
isys = structural_simplify(
generate_initializesystem(
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
end

ts = get_tearing_state(isys)
Expand Down
136 changes: 64 additions & 72 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem;
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
additional_guesses = anydict(guesses)
guesses = merge(get_guesses(sys), additional_guesses)
schedule = getfield(sys, :schedule)
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
Expand Down Expand Up @@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem;
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
end
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses)
return NonlinearSystem(eqs_ics,
vars,
pars;
Expand All @@ -193,6 +194,7 @@ end
struct InitializationSystemMetadata
u0map::Dict{Any, Any}
pmap::Dict{Any, Any}
additional_guesses::Dict{Any, Any}
end

function is_parameter_solvable(p, pmap, defs, guesses)
Expand All @@ -208,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses)
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
end

function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp)
if u0 === missing && p === missing
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
return odefn.initialization_data
end
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
oldinitprob = odefn.initializeprob
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
!(oldinitprob.f.sys isa NonlinearSystem)
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
oldinitprob === nothing && return nothing
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!,
odefn.initializeprobmap, odefn.initializeprobpmap)
end
pidxs = ParameterIndex[]
pvals = []
Expand Down Expand Up @@ -260,78 +261,69 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
end
initprob = remake(oldinitprob; u0 = newu0, p = newp)
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
odefn.initializeprobmap, odefn.initializeprobpmap)
end
if u0 === missing || isempty(u0)
u0 = Dict()
elseif !(eltype(u0) <: Pair)
u0 = Dict(unknowns(sys) .=> u0)
end
if p === missing
p = Dict()
end
if t0 === nothing
t0 = 0.0
end
u0 = todict(u0)
dvs = unknowns(sys)
ps = parameters(sys)
u0map = to_varmap(u0, dvs)
symbols_to_symbolics!(sys, u0map)
pmap = to_varmap(p, ps)
symbols_to_symbolics!(sys, pmap)
guesses = Dict()
defs = defaults(sys)
varmap = merge(defs, u0)
for k in collect(keys(varmap))
if varmap[k] === nothing
delete!(varmap, k)
if SciMLBase.has_initializeprob(odefn)
oldsys = odefn.initializeprob.f.sys
meta = get_metadata(oldsys)
if meta isa InitializationSystemMetadata
u0map = merge(meta.u0map, u0map)
pmap = merge(meta.pmap, pmap)
merge!(guesses, meta.additional_guesses)
end
end
varmap = canonicalize_varmap(varmap)
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
setobserved = filter(keys(varmap)) do var
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
end
p = todict(p)
guesses = ModelingToolkit.guesses(sys)
solvablepars = [par
for par in parameters(sys)
if is_parameter_solvable(par, p, defs, guesses)]
pvarmap = merge(defs, p)
setparobserved = filter(keys(pvarmap)) do var
has_parameter_dependency_with_lhs(sys, var)
end
if (((!isempty(missingvars) || !isempty(solvablepars) ||
!isempty(setobserved) || !isempty(setparobserved)) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys)))
if SciMLBase.has_initializeprob(odefn)
oldsys = odefn.initializeprob.f.sys
meta = get_metadata(oldsys)
if meta isa InitializationSystemMetadata
u0 = merge(meta.u0map, u0)
p = merge(meta.pmap, p)
else
# there is no initializeprob, so the original problem construction
# had no solvable parameters and had the differential variables
# specified in `u0map`.
if u0 === missing
# the user didn't pass `u0` to `remake`, so they want to retain
# existing values. Fill the differential variables in `u0map`,
# initialization will either be elided or solve for the algebraic
# variables
diff_idxs = isdiffeq.(equations(sys))
for i in eachindex(dvs)
diff_idxs[i] || continue
u0map[dvs[i]] = newu0[i]
end
end
for k in collect(keys(u0))
if u0[k] === nothing
delete!(u0, k)
if p === missing
# the user didn't pass `p` to `remake`, so they want to retain
# existing values. Fill all parameters in `pmap` so that none of
# them are solvable.
for p in ps
pmap[p] = getp(sys, p)(newp)
end
end
for k in collect(keys(p))
if p[k] === nothing
delete!(p, k)
end
# all non-solvable parameters need values regardless
for p in ps
haskey(pmap, p) && continue
is_parameter_solvable(p, pmap, defs, guesses) && continue
pmap[p] = getp(sys, p)(newp)
end

initprob = InitializationProblem(sys, t0, u0, p)
initprobmap = getu(initprob, unknowns(sys))
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
getpunknowns = getu(initprob, punknowns)
setpunknowns = setp(sys, punknowns)
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
reqd_syms = parameter_symbols(initprob)
update_initializeprob! = UpdateInitializeprob(
getu(sys, reqd_syms), setu(initprob, reqd_syms))
return initprob, update_initializeprob!, initprobmap, initprobpmap
else
return nothing, nothing, nothing, nothing
end
if t0 === nothing
t0 = 0.0
end
filter_missing_values!(u0map)
filter_missing_values!(pmap)
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
kws = f.kwargs
initprob = get(kws, :initializeprob, nothing)
if initprob === nothing
return nothing
end
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing),
get(kws, :initializeprobmap, nothing),
get(kws, :initializeprobpmap, nothing))
end

"""
Expand Down
3 changes: 2 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ function OptimizationSystem(objective; constraints = [], kwargs...)
push!(new_ps, p)
end
end
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
return OptimizationSystem(
objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
end

function flatten(sys::OptimizationSystem)
Expand Down
62 changes: 57 additions & 5 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any}
$(TYPEDSIGNATURES)

If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`
and `nothing`.
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`,
`missing` and `nothing`.
"""
anydict() = AnyDict()
anydict(::SciMLBase.NullParameters) = AnyDict()
anydict(::Nothing) = AnyDict()
anydict(::Missing) = AnyDict()
anydict(x::AnyDict) = x
anydict(x) = AnyDict(x)

Expand Down Expand Up @@ -51,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm)
return cp
end

"""
$(TYPEDSIGNATURES)

Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any
symbols that cannot be converted are ignored.
"""
function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict)
if is_split(sys)
ic = get_index_cache(sys)
for k in collect(keys(varmap))
k isa Symbol || continue
newk = get(ic.symbol_to_variable, k, nothing)
newk === nothing && continue
varmap[newk] = varmap[k]
delete!(varmap, k)
end
else
syms = all_symbols(sys)
for k in collect(keys(varmap))
k isa Symbol || continue
idx = findfirst(syms) do sym
hasname(sym) || return false
name = getname(sym)
return name == k
end
idx === nothing && continue
newk = syms[idx]
if iscall(newk) && operation(newk) === getindex
newk = arguments(newk)[1]
end
varmap[newk] = varmap[k]
delete!(varmap, k)
end
end
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -388,6 +425,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
end
end

"""
$(TYPEDSIGNATURES)

Remove keys in `varmap` whose values are `nothing`.
"""
function filter_missing_values!(varmap::AbstractDict)
filter!(kvp -> kvp[2] !== nothing, varmap)
end

struct GetUpdatedMTKParameters{G, S}
# `getu` functor which gets parameters that are unknowns during initialization
getpunknowns::G
Expand Down Expand Up @@ -431,12 +477,16 @@ end
$(TYPEDEF)

A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in
case constructing a SciMLFunction is not required.
case constructing a SciMLFunction is not required. The arguments passed to it are available
in the `args` field, and the keyword arguments in the `kwargs` field.
"""
struct EmptySciMLFunction end
struct EmptySciMLFunction{A, K}
args::A
kwargs::K
end

function EmptySciMLFunction(args...; kwargs...)
return nothing
return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs)
end

"""
Expand Down Expand Up @@ -516,8 +566,10 @@ function process_SciMLProblem(
pType = typeof(pmap)
_u0map = u0map
u0map = to_varmap(u0map, dvs)
symbols_to_symbolics!(sys, u0map)
_pmap = pmap
pmap = to_varmap(pmap, ps)
symbols_to_symbolics!(sys, pmap)
defs = add_toterms(recursive_unwrap(defaults(sys)))
cmap, cs = get_cmap(sys)
kwargs = NamedTuple(kwargs)
Expand Down
1 change: 0 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
return nothing
end


function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down
57 changes: 57 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,60 @@ end
@test integ.ps[p] ≈ 1.0
@test integ.ps[q]≈cbrt(2) rtol=1e-6
end

@testset "Guesses provided to `ODEProblem` are used in `remake`" begin
@variables x(t) y(t)=2x
@parameters p q=3x
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
prob = ODEProblem(
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
@test prob[x] == 0.0
@test prob[y] == 0.0
@test prob.ps[p] == 1.0
@test prob.ps[q] == 0.0
integ = init(prob)
@test integ[x] ≈ 1 / cbrt(3)
@test integ[y] ≈ 2 / cbrt(3)
@test integ.ps[p] == 1.0
@test integ.ps[q] ≈ 3 / cbrt(3)
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
integ2 = init(prob2)
@test integ2[x] ≈ cbrt(3 / 28)
@test integ2[y] ≈ 3cbrt(3 / 28)
@test integ2.ps[p] == 1.0
@test integ2.ps[q] ≈ 2cbrt(3 / 28)
end

@testset "Remake problem with no initializeprob" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p [guess = 1.0] q [guess = 1.0]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
@test prob.f.initialization_data === nothing
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2[x] == 2.0
@test prob2.f.initialization_data === nothing
prob3 = remake(prob; u0 = [y => 2.0])
@test prob3.f.initialization_data !== nothing
@test init(prob3)[x] ≈ 1.0
prob4 = remake(prob; p = [p => 1.0])
@test prob4.f.initialization_data === nothing
prob5 = remake(prob; p = [p => missing, q => 2.0])
@test prob5.f.initialization_data !== nothing
@test init(prob5).ps[p] ≈ 1.0
end

@testset "Variables provided as symbols" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p [guess = 1.0] q [guess = 1.0]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0])
@test prob.f.initialization_data === nothing
prob2 = remake(prob; u0 = [:x => 2.0])
@test prob2.f.initialization_data === nothing
prob3 = remake(prob; u0 = [:y => 1.0])
@test prob3.f.initialization_data !== nothing
@test init(prob3)[x] ≈ 0.5
end
Loading