diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d99f5f69fa..48d647f621 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1831,3 +1831,21 @@ function missing_variable_defaults(sys::AbstractSystem, default = 0.0) return ds end + +keytype(::Type{<:Pair{T, V}}) where {T, V} = T +function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, Dict}) + if keytype(eltype(rules)) <: Symbol + dict = todict(rules) + systems = get_systems(sys) + # post-walk to avoid infinite recursion + @set! sys.systems = map(Base.Fix2(substitute, dict), systems) + something(get(rules, nameof(sys), nothing), sys) + elseif sys isa ODESystem + rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), + collect(rules))) + eqs = fast_substitute(equations(sys), rules) + ODESystem(eqs, get_iv(sys); name = nameof(sys)) + else + error("substituting symbols is not supported for $(typeof(sys))") + end +end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e368df4f1b..c123f7b15c 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -465,13 +465,6 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys)) checks = false) end -function Symbolics.substitute(sys::ODESystem, rules::Union{Vector{<:Pair}, Dict}) - rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), - collect(rules))) - eqs = fast_substitute(equations(sys), rules) - ODESystem(eqs, get_iv(sys); name = nameof(sys)) -end - """ $(SIGNATURES) diff --git a/test/odesystem.jl b/test/odesystem.jl index a45a45b1b0..e8b08fe4a2 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1012,3 +1012,22 @@ let prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0)) @test !isnothing(prob.f.sys) end + +@parameters t +# SYS 1: +vars_sub1 = @variables s1(t) +@named sub = ODESystem(Equation[], t, vars_sub1, []) + +vars1 = @variables x1(t) +@named sys1 = ODESystem(Equation[], t, vars1, [], systems = [sub]) +@named sys2 = ODESystem(Equation[], t, vars1, [], systems = [sys1, sub]) + +# SYS 2: Extension to SYS 1 +vars_sub2 = @variables s2(t) +@named partial_sub = ODESystem(Equation[], t, vars_sub2, []) +@named sub = extend(partial_sub, sub) + +new_sys2 = complete(substitute(sys2, Dict(:sub => sub))) +Set(states(new_sys2)) == Set([new_sys2.x1, new_sys2.sys1.x1, + new_sys2.sys1.sub.s1, new_sys2.sys1.sub.s2, + new_sys2.sub.s1, new_sys2.sub.s2])