Skip to content

Commit

Permalink
feat: allow CSE and array hacks to be disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 24, 2024
1 parent 79a7fc9 commit d2fe0eb
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 33 deletions.
79 changes: 46 additions & 33 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ end
=#

function tearing_reassemble(state::TearingState, var_eq_matching,
full_var_eq_matching = nothing; simplify = false, mm = nothing)
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
@unpack fullvars, sys, structure = state
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
extra_vars = Int[]
Expand Down Expand Up @@ -584,24 +584,48 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
@set! sys.unknowns = unknowns

# HACK: Since we don't support array equations, any equation of the sort
# `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
# calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
# for this case to avoid the unnecessary cost.
# This and the below hack are implemented simultaneously
obs, subeqs = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
sys = schedule(sys)
@set! state.sys = sys
@set! sys.tearing_state = state
return invalidate_cache!(sys)
end

"""
# HACK 1
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
avoid the unnecessary cost. This and the below hack are implemented simultaneously
# HACK 2
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
not) we first count the number of times the scalarized form of each observed variable
occurs in observed equations (and unknowns if it's split).
"""
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
# HACK 1
# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
rhs_to_tempvar = Dict()

# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
# are equations, add an equation `p ~ [p[1], p[2], ...]`
# allow topsort to reorder them
# only add the new equation if all `p[i]` are present and the unscalarized
# form is used in any equation (observed or not)
# we first count the number of times the scalarized form of each observed
# variable occurs in observed equations (and unknowns if it's split).

# HACK 2
# map of array observed variable (unscalarized) to number of its
# scalarized terms that appear in observed equations
arr_obs_occurrences = Dict()
Expand All @@ -613,7 +637,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
vars!(all_vars, rhs)

# HACK 1
if (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
if cse &&
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
iscall(rhs) && operation(rhs) === getindex &&
Symbolics.shape(rhs) != Symbolics.Unknown()
rhs_arr = arguments(rhs)[1]
Expand Down Expand Up @@ -643,6 +668,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
# end HACK 1

array || continue
iscall(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
Expand Down Expand Up @@ -687,20 +713,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
# need to re-sort subeqs
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])

@set! sys.eqs = neweqs
@set! sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
sys = schedule(sys)
@set! state.sys = sys
@set! sys.tearing_state = state
return invalidate_cache!(sys)
return obs, subeqs
end

# PART OF HACK 1
Expand Down Expand Up @@ -733,10 +746,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
instead, which calls this function internally.
"""
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
simplify = false, kwargs...)
simplify = false, cse_hack = true, array_hack = true, kwargs...)
var_eq_matching, full_var_eq_matching = tearing(state)
invalidate_cache!(tearing_reassemble(
state, var_eq_matching, full_var_eq_matching; mm, simplify))
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
end

"""
Expand All @@ -758,7 +771,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
the system is balanced.
"""
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
mm = nothing, kwargs...)
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
jac = let state = state
(eqs, vars) -> begin
symeqs = EquationsView(state)[eqs]
Expand All @@ -782,5 +795,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
end
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
kwargs...)
tearing_reassemble(state, var_eq_matching; simplify, mm)
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
end
44 changes: 44 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,47 @@ end
StructuralTransformations.change_origin]
end
end

@testset "array and cse hacks can be disabled" begin
@testset "fully_determined = true" begin
@variables x(t) y(t)[1:2] z(t)[1:2]
@parameters foo(::AbstractVector)[1:2]
_tmp_fn(x) = 2x
@named sys = ODESystem(
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)

sys1 = structural_simplify(sys; cse_hack = false)
@test length(observed(sys1)) == 6
@test !any(observed(sys1)) do eq
iscall(eq.rhs) &&
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
end

sys2 = structural_simplify(sys; array_hack = false)
@test length(observed(sys2)) == 5
@test !any(observed(sys2)) do eq
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
end
end

@testset "fully_determined = false" begin
@variables x(t) y(t)[1:2] z(t)[1:2] w(t)
@parameters foo(::AbstractVector)[1:2]
_tmp_fn(x) = 2x
@named sys = ODESystem(
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)

sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false)
@test length(observed(sys1)) == 6
@test !any(observed(sys1)) do eq
iscall(eq.rhs) &&
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
end

sys2 = structural_simplify(sys; array_hack = false, fully_determined = false)
@test length(observed(sys2)) == 5
@test !any(observed(sys2)) do eq
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
end
end
end

0 comments on commit d2fe0eb

Please sign in to comment.