Skip to content

Commit

Permalink
Merge pull request #2992 from AayushSabharwal/as/bugs
Browse files Browse the repository at this point in the history
fix: fix substitute duplicating equations
  • Loading branch information
ChrisRackauckas authored Aug 28, 2024
2 parents 677cf30 + d7c673a commit 2813251
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
36 changes: 28 additions & 8 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2934,12 +2934,12 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
elseif sys isa ODESystem
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
collect(rules)))
eqs = fast_substitute(equations(sys), rules)
pdeps = fast_substitute(parameter_dependencies(sys), rules)
eqs = fast_substitute(get_eqs(sys), rules)
pdeps = fast_substitute(get_parameter_dependencies(sys), rules)
defs = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
for (k, v) in defaults(sys))
for (k, v) in get_defaults(sys))
guess = Dict(fast_substitute(k, rules) => fast_substitute(v, rules)
for (k, v) in guesses(sys))
for (k, v) in get_guesses(sys))
subsys = map(s -> substitute(s, rules), get_systems(sys))
ODESystem(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
guesses = guess, parameter_dependencies = pdeps, systems = subsys)
Expand All @@ -2948,14 +2948,34 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
end
end

struct InvalidParameterDependenciesType
got::Any
end

function Base.showerror(io::IO, err::InvalidParameterDependenciesType)
print(
io, "Parameter dependencies must be a `Dict`, or an array of `Pair` or `Equation`.")
if err.got !== nothing
print(io, " Got ", err.got)
end
end

function process_parameter_dependencies(pdeps, ps)
if pdeps === nothing || isempty(pdeps)
return Equation[], ps
elseif eltype(pdeps) <: Pair
pdeps = [lhs ~ rhs for (lhs, rhs) in pdeps]
end
if !(eltype(pdeps) <: Equation)
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
if pdeps isa Dict
pdeps = [k ~ v for (k, v) in pdeps]
else
pdeps isa AbstractArray || throw(InvalidParameterDependenciesType(pdeps))
pdeps = [if p isa Pair
p[1] ~ p[2]
elseif p isa Equation
p
else
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
end
for p in pdeps]
end
lhss = BasicSymbolic[]
for p in pdeps
Expand Down
2 changes: 1 addition & 1 deletion test/dq_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,4 @@ end
@variables X(tt) [unit = u"L"]
DD = Differential(tt)
eqs = [DD(X) ~ p - d * X + d * X]
@test ModelingToolkit.validate(eqs)
@test ModelingToolkit.validate(eqs)
24 changes: 24 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1331,3 +1331,27 @@ end
@test length(ModelingToolkit.guesses(sys2)) == 2
@test ModelingToolkit.guesses(sys2)[p3] == 2.0
end

@testset "Substituting with nested systems" begin
@parameters p1 p2
@variables x(t) y(t)
@named innersys = ODESystem([D(x) ~ y + p2], t; parameter_dependencies = [p2 ~ 2p1],
defaults = [p1 => 1.0, p2 => 2.0], guesses = [p1 => 2.0, p2 => 3.0])
@parameters p3 p4
@named outersys = ODESystem(
[D(innersys.y) ~ innersys.y + p4], t; parameter_dependencies = [p4 ~ 3p3],
defaults = [p3 => 3.0, p4 => 9.0], guesses = [p4 => 10.0], systems = [innersys])
@test_nowarn structural_simplify(outersys)
@parameters p5
sys2 = substitute(outersys, [p4 => p5])
@test_nowarn structural_simplify(sys2)
@test length(equations(sys2)) == 2
@test length(parameters(sys2)) == 2
@test length(full_parameters(sys2)) == 4
@test all(!isequal(p4), full_parameters(sys2))
@test any(isequal(p5), full_parameters(sys2))
@test length(ModelingToolkit.defaults(sys2)) == 4
@test ModelingToolkit.defaults(sys2)[p5] == 9.0
@test length(ModelingToolkit.guesses(sys2)) == 3
@test ModelingToolkit.guesses(sys2)[p5] == 10.0
end

0 comments on commit 2813251

Please sign in to comment.