From bf63c1a8093ed124d395a66ccc22ef629295a0f0 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 12 Sep 2023 11:55:28 -0400 Subject: [PATCH 1/2] Implement `substitute` for `AbstractSystem`s --- src/systems/abstractsystem.jl | 18 ++++++++++++++++++ src/systems/diffeqs/odesystem.jl | 7 ------- test/odesystem.jl | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d99f5f69fa..48d647f621 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1831,3 +1831,21 @@ function missing_variable_defaults(sys::AbstractSystem, default = 0.0) return ds end + +keytype(::Type{<:Pair{T, V}}) where {T, V} = T +function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, Dict}) + if keytype(eltype(rules)) <: Symbol + dict = todict(rules) + systems = get_systems(sys) + # post-walk to avoid infinite recursion + @set! sys.systems = map(Base.Fix2(substitute, dict), systems) + something(get(rules, nameof(sys), nothing), sys) + elseif sys isa ODESystem + rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), + collect(rules))) + eqs = fast_substitute(equations(sys), rules) + ODESystem(eqs, get_iv(sys); name = nameof(sys)) + else + error("substituting symbols is not supported for $(typeof(sys))") + end +end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e368df4f1b..c123f7b15c 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -465,13 +465,6 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys)) checks = false) end -function Symbolics.substitute(sys::ODESystem, rules::Union{Vector{<:Pair}, Dict}) - rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), - collect(rules))) - eqs = fast_substitute(equations(sys), rules) - ODESystem(eqs, get_iv(sys); name = nameof(sys)) -end - """ $(SIGNATURES) diff --git a/test/odesystem.jl b/test/odesystem.jl index a45a45b1b0..e8b08fe4a2 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1012,3 +1012,22 @@ let prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0)) @test !isnothing(prob.f.sys) end + +@parameters t +# SYS 1: +vars_sub1 = @variables s1(t) +@named sub = ODESystem(Equation[], t, vars_sub1, []) + +vars1 = @variables x1(t) +@named sys1 = ODESystem(Equation[], t, vars1, [], systems = [sub]) +@named sys2 = ODESystem(Equation[], t, vars1, [], systems = [sys1, sub]) + +# SYS 2: Extension to SYS 1 +vars_sub2 = @variables s2(t) +@named partial_sub = ODESystem(Equation[], t, vars_sub2, []) +@named sub = extend(partial_sub, sub) + +new_sys2 = complete(substitute(sys2, Dict(:sub => sub))) +Set(states(new_sys2)) == Set([new_sys2.x1, new_sys2.sys1.x1, + new_sys2.sys1.sub.s1, new_sys2.sys1.sub.s2, + new_sys2.sub.s1, new_sys2.sub.s2]) From 5981d758deb69805f302cd29a50840aabe133777 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 12 Sep 2023 12:10:10 -0400 Subject: [PATCH 2/2] Format --- src/ModelingToolkit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 89dee632f4..16e9ab6548 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -3,7 +3,7 @@ $(DocStringExtensions.README) """ module ModelingToolkit using PrecompileTools, Reexport -@recompile_invalidations begin +@recompile_invalidations begin using DocStringExtensions using Compat using AbstractTrees