Skip to content

Commit

Permalink
Merge pull request #3253 from AayushSabharwal/as/init-everywhere
Browse files Browse the repository at this point in the history
feat: create initialization systems for all problem types
  • Loading branch information
ChrisRackauckas authored Dec 25, 2024
2 parents dad05e5 + 2e07200 commit d5a48a4
Show file tree
Hide file tree
Showing 17 changed files with 750 additions and 393 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DelayDiffEq = "5.50"
DiffEqBase = "6.157"
DiffEqCallbacks = "2.16, 3, 4"
DiffEqNoiseProcess = "5"
Expand Down Expand Up @@ -117,7 +118,7 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
NaNMath = "0.3, 1"
NonlinearSolve = "3.14, 4"
NonlinearSolve = "4.3"
OffsetArrays = "1"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
Expand All @@ -129,15 +130,17 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.66"
SciMLBase = "2.68.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.35"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.19"
URIs = "1"
Expand Down
61 changes: 28 additions & 33 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
initializeprob = nothing,
update_initializeprob! = nothing,
initializeprobmap = nothing,
initializeprobpmap = nothing,
initialization_data = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
Expand Down Expand Up @@ -463,10 +460,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic,
initializeprob = initializeprob,
update_initializeprob! = update_initializeprob!,
initializeprobmap = initializeprobmap,
initializeprobpmap = initializeprobpmap)
initialization_data)
end

"""
Expand Down Expand Up @@ -496,10 +490,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
sparse = false, simplify = false,
eval_module = @__MODULE__,
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprobpmap = nothing,
update_initializeprob! = nothing,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -547,15 +538,12 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
nothing
end

DAEFunction{iip}(f,
DAEFunction{iip}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
jac_prototype = jac_prototype,
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap,
initializeprobpmap = initializeprobpmap,
update_initializeprob! = update_initializeprob!)
initialization_data)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -567,6 +555,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
Expand All @@ -579,7 +568,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)

DDEFunction{iip}(f, sys = sys)
DDEFunction{iip}(f; sys = sys, initialization_data)
end

function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand All @@ -591,6 +580,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
initialization_data = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
Expand All @@ -609,7 +599,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
g(u, h, p, t) = g_oop(u, h, p, t)
g(du, u, h, p, t) = g_iip(du, u, h, p, t)

SDDEFunction{iip}(f, g, sys = sys)
SDDEFunction{iip}(f, g; sys = sys, initialization_data)
end

"""
Expand Down Expand Up @@ -933,7 +923,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
h(p, t) = h_oop(p, t)
h(p::MTKParameters, t) = h_oop(p..., t)
u0 = h(p, tspan[1])
u0 = float.(h(p, tspan[1]))
if u0 !== nothing
u0 = u0_constructor(u0)
end
Expand Down Expand Up @@ -1257,23 +1247,23 @@ Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem
which represents the initialization, i.e. the calculation of the consistent
initial conditions for the given DAE.
"""
function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{true}(sys, args...; kwargs...)
end

function InitializationProblem(sys::AbstractODESystem, t,
function InitializationProblem(sys::AbstractSystem, t,
u0map::StaticArray,
args...;
kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(
sys, t, u0map, args...; kwargs...)
end

function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem{true}(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...)
function InitializationProblem{false}(sys::AbstractSystem, args...; kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

Expand All @@ -1292,8 +1282,8 @@ function Base.showerror(io::IO, e::IncompleteInitializationError)
println(io, e.uninit)
end

function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
t::Number, u0map = [],
function InitializationProblem{iip, specialize}(sys::AbstractSystem,
t, u0map = [],
parammap = DiffEqBase.NullParameters();
guesses = [],
check_length = true,
Expand All @@ -1320,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
end

meta = get_metadata(isys)
if meta isa InitializationSystemMetadata
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys)
end

ts = get_tearing_state(isys)
unassigned_vars = StructuralTransformations.singular_check(ts)
if warn_initialize_determined && !isempty(unassigned_vars)
Expand Down Expand Up @@ -1357,13 +1352,13 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true"
end

parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))
parammap = Dict(k => v for (k, v) in parammap if v !== missing)
if isempty(u0map)
u0map = Dict()
parammap = recursive_unwrap(anydict(parammap))
if t !== nothing
parammap[get_iv(sys)] = t
end
filter!(kvp -> kvp[2] !== missing, parammap)

u0map = to_varmap(u0map, unknowns(sys))
if isempty(guesses)
guesses = Dict()
end
Expand Down Expand Up @@ -1405,5 +1400,5 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
else
NonlinearLeastSquaresProblem
end
TProb(isys, u0map, parammap; kwargs...)
TProb(isys, u0map, parammap; kwargs..., build_initializeprob = false)
end
27 changes: 7 additions & 20 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,29 +256,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
:ODESystem, force = true)
end
defaults = Dict{Any, Any}(todict(defaults))
guesses = Dict{Any, Any}(todict(guesses))
var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, guesses, dvs′)
process_variables!(var_to_name, defaults, guesses, ps′)
process_variables!(
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
process_variables!(
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if v !== nothing)

sysdvsguesses = [ModelingToolkit.getguess(st) for st in dvs′]
hasaguess = findall(!isnothing, sysdvsguesses)
var_guesses = dvs′[hasaguess] .=> sysdvsguesses[hasaguess]
sysdvsguesses = isempty(var_guesses) ? Dict() : todict(var_guesses)
syspsguesses = [ModelingToolkit.getguess(st) for st in ps′]
hasaguess = findall(!isnothing, syspsguesses)
ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess]
syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses)
syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies]
hasaguess = findall(!isnothing, syspdepguesses)
pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=>
syspdepguesses[hasaguess]
syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses)

guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses))
guesses = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(guesses) if v !== nothing)

Expand Down
62 changes: 44 additions & 18 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
"""
defaults::Dict
"""
The guesses to use as the initial conditions for the
initialization system.
"""
guesses::Dict
"""
The system for performing the initialization.
"""
initializesystem::Union{Nothing, NonlinearSystem}
"""
Extra equations to be enforced during the initialization sequence.
"""
initialization_eqs::Vector{Equation}
"""
Type of the system.
"""
connector_type::Any
Expand Down Expand Up @@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
isscheduled::Bool

function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad,
jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
guesses, initializesystem, initialization_eqs, connector_type,
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
is_dde = false,
Expand All @@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
check_units(u, deqs, neqs)
end
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac,
Wfact, Wfact_t, name, description, systems,
defaults, connector_type, cevents, devents,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
devents,
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
is_dde, isscheduled)
end
Expand All @@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
default_u0 = Dict(),
default_p = Dict(),
defaults = _merge(Dict(default_u0), Dict(default_p)),
guesses = Dict(),
initializesystem = nothing,
initialization_eqs = Equation[],
name = nothing,
description = "",
connector_type = nothing,
Expand All @@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
dvs′ = value.(dvs)
ps′ = value.(ps)
ctrl′ = value.(controls)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)

sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
Expand All @@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:SDESystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

defaults = Dict{Any, Any}(todict(defaults))
guesses = Dict{Any, Any}(todict(guesses))
var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, guesses, dvs′)
process_variables!(var_to_name, defaults, guesses, ps′)
process_variables!(
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
process_variables!(
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if v !== nothing)
guesses = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(guesses) if v !== nothing)

isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))

tgrad = RefValue(EMPTY_TGRAD)
Expand All @@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
Wfact_t = RefValue(EMPTY_JAC)
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
initializesystem, initialization_eqs, connector_type,
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
end
Expand Down Expand Up @@ -520,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
version = nothing, tgrad = false, sparse = false,
jac = false, Wfact = false, eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
checkbounds = false, initialization_data = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
Expand Down Expand Up @@ -591,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

SDEFunction{iip, specialize}(f, g,
SDEFunction{iip, specialize}(f, g;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
Wfact = _Wfact === nothing ? nothing : _Wfact,
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
mass_matrix = _M,
mass_matrix = _M, initialization_data,
observed = observedfun)
end

Expand Down Expand Up @@ -714,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
end
f, u0, p = process_SciMLProblem(
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
kwargs...)
t = tspan === nothing ? nothing : tspan[1], kwargs...)
cbs = process_events(sys; callback, kwargs...)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

Expand All @@ -736,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
noise = nothing
end

kwargs = filter_kwargs(kwargs)

SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...)
end
Expand Down
Loading

0 comments on commit d5a48a4

Please sign in to comment.