Skip to content

Commit

Permalink
feat: extend CSE hack to non-observed equations
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 24, 2024
1 parent d2fe0eb commit 1025363
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 12 deletions.
45 changes: 39 additions & 6 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
@set! sys.unknowns = unknowns

obs, subeqs = cse_and_array_hacks(
obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
Expand Down Expand Up @@ -637,10 +637,7 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
vars!(all_vars, rhs)

# HACK 1
if cse &&
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
iscall(rhs) && operation(rhs) === getindex &&
Symbolics.shape(rhs) != Symbolics.Unknown()
if cse && is_getindexed_array(rhs)
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
Expand Down Expand Up @@ -677,6 +674,33 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
arr_obs_occurrences[arg1] = cnt + 1
continue
end

# Also do CSE for `equations(sys)`
if cse
for (i, eq) in enumerate(neweqs)
(; lhs, rhs) = eq
is_getindexed_array(rhs) || continue
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
push!(subeqs, tempeq)
end
# don't need getindex_wrapper, but do it anyway to know that this
# hack took place
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
neweqs[i] = neweq
end
end

# count variables in unknowns if they are scalarized forms of variables
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
# is an observed equation.
Expand Down Expand Up @@ -713,7 +737,16 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
# need to re-sort subeqs
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])

return obs, subeqs
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
for i in 1:length(subeqs)]

return obs, subeqs, deps
end

function is_getindexed_array(rhs)
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
iscall(rhs) && operation(rhs) === getindex &&
Symbolics.shape(rhs) != Symbolics.Unknown()
end

# PART OF HACK 1
Expand Down
24 changes: 18 additions & 6 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ function generate_initializesystem(sys::ODESystem;
algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
trueobs = unhack_observed(observed(sys))
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff

Expand Down Expand Up @@ -329,11 +328,11 @@ end
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
initialization.
"""
function unhack_observed(eqs::Vector{Equation})
function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
subs = Dict()
tempvars = Set()
rm_idxs = Int[]
for (i, eq) in enumerate(eqs)
for (i, eq) in enumerate(obseqs)
iscall(eq.rhs) || continue
if operation(eq.rhs) == StructuralTransformations.change_origin
push!(rm_idxs, i)
Expand All @@ -347,14 +346,27 @@ function unhack_observed(eqs::Vector{Equation})
end

for (i, eq) in enumerate(eqs)
iscall(eq.rhs) || continue
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
var, idxs = arguments(eq.rhs)
subs[eq.rhs] = var[idxs...]
push!(tempvars, var)
end
end

for (i, eq) in enumerate(obseqs)
if eq.lhs in tempvars
subs[eq.lhs] = eq.rhs
push!(rm_idxs, i)
end
end

eqs = eqs[setdiff(eachindex(eqs), rm_idxs)]
return map(eqs) do eq
obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)]
obseqs = map(obseqs) do eq
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
end
eqs = map(eqs) do eq
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
end
return obseqs, eqs
end
22 changes: 22 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,28 @@ end
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
StructuralTransformations.change_origin]
end

@testset "CSE hack in equations(sys)" begin
val[] = 0
@variables z(t)[1:2]
@mtkbuild sys = ODESystem(
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
@test length(equations(sys)) == 5
@test length(observed(sys)) == 2
prob = ODEProblem(
sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2])
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
@test val[] == 2

isys = ModelingToolkit.generate_initializesystem(sys)
@test length(unknowns(isys)) == 5
@test length(equations(isys)) == 2
@test !any(equations(isys)) do eq
iscall(eq.rhs) &&
operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
StructuralTransformations.change_origin]
end
end
end

@testset "array and cse hacks can be disabled" begin
Expand Down

0 comments on commit 1025363

Please sign in to comment.