Skip to content

Commit

Permalink
Implement substitute for AbstractSystems
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Sep 12, 2023
1 parent 80e1b95 commit bf63c1a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
18 changes: 18 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1831,3 +1831,21 @@ function missing_variable_defaults(sys::AbstractSystem, default = 0.0)

return ds
end

keytype(::Type{<:Pair{T, V}}) where {T, V} = T
function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, Dict})
if keytype(eltype(rules)) <: Symbol
dict = todict(rules)
systems = get_systems(sys)
# post-walk to avoid infinite recursion
@set! sys.systems = map(Base.Fix2(substitute, dict), systems)
something(get(rules, nameof(sys), nothing), sys)
elseif sys isa ODESystem
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
collect(rules)))
eqs = fast_substitute(equations(sys), rules)
ODESystem(eqs, get_iv(sys); name = nameof(sys))
else
error("substituting symbols is not supported for $(typeof(sys))")

Check warning on line 1849 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1849

Added line #L1849 was not covered by tests
end
end
7 changes: 0 additions & 7 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,6 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys))
checks = false)
end

function Symbolics.substitute(sys::ODESystem, rules::Union{Vector{<:Pair}, Dict})
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
collect(rules)))
eqs = fast_substitute(equations(sys), rules)
ODESystem(eqs, get_iv(sys); name = nameof(sys))
end

"""
$(SIGNATURES)
Expand Down
19 changes: 19 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,3 +1012,22 @@ let
prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0))
@test !isnothing(prob.f.sys)
end

@parameters t
# SYS 1:
vars_sub1 = @variables s1(t)
@named sub = ODESystem(Equation[], t, vars_sub1, [])

vars1 = @variables x1(t)
@named sys1 = ODESystem(Equation[], t, vars1, [], systems = [sub])
@named sys2 = ODESystem(Equation[], t, vars1, [], systems = [sys1, sub])

# SYS 2: Extension to SYS 1
vars_sub2 = @variables s2(t)
@named partial_sub = ODESystem(Equation[], t, vars_sub2, [])
@named sub = extend(partial_sub, sub)

new_sys2 = complete(substitute(sys2, Dict(:sub => sub)))
Set(states(new_sys2)) == Set([new_sys2.x1, new_sys2.sys1.x1,
new_sys2.sys1.sub.s1, new_sys2.sys1.sub.s2,
new_sys2.sub.s1, new_sys2.sub.s2])

0 comments on commit bf63c1a

Please sign in to comment.