Skip to content

Commit

Permalink
feat: add simple CSE for array scalarization case
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 20, 2024
1 parent a5d7a48 commit d1f61f1
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
56 changes: 54 additions & 2 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
@set! sys.unknowns = unknowns

# HACK: Since we don't support array equations, any equation of the sort
# `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
# calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
# for this case to avoid the unnecessary cost.
# This and the below hack are implemented simultaneously

# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
rhs_to_tempvar = Dict()

# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
# are equations, add an equation `p ~ [p[1], p[2], ...]`
# allow topsort to reorder them
Expand All @@ -597,9 +607,44 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
arr_obs_occurrences = Dict()
# to check if array variables occur in unscalarized form anywhere
all_vars = Set()
for eq in obs
vars!(all_vars, eq.rhs)
for (i, eq) in enumerate(obs)
lhs = eq.lhs
rhs = eq.rhs
vars!(all_vars, rhs)

# HACK 1
if iscall(rhs) && operation(rhs) === getindex &&
Symbolics.shape(rhs) != Symbolics.Unknown()
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
# ideally we would like to do this:
push!(obs, tempeq)
push!(subeqs, tempeq)
# and let topsort_equations handle it, but that treats `x` and `x[1]`
# as different variables and thus doesn
end

# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
# which fails the topological sort
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
obs[i] = neweq
subeqi = findfirst(isequal(eq), subeqs)
if subeqi !== nothing
subeqs[subeqi] = neweq
end
end
# end HACK 1

iscall(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
Expand Down Expand Up @@ -640,6 +685,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
append!(obs, obs_arr_eqs)
append!(subeqs, obs_arr_eqs)

# need to re-sort subeqs
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])

Expand All @@ -659,6 +705,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
return invalidate_cache!(sys)
end

# PART OF HACK 1
getindex_wrapper(x, i) = x[i...]

@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})

# PART OF HACK 2
function change_origin(origin, arr)
return origin(arr)
end
Expand Down
18 changes: 17 additions & 1 deletion test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,25 @@ end
@mtkbuild sys = ODESystem(
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
@test length(equations(sys)) == 1
@test length(observed(sys)) == 6
@test length(observed(sys)) == 7
@test any(eq -> isequal(eq.lhs, y), observed(sys))
@test any(eq -> isequal(eq.lhs, z), observed(sys))
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
end

@testset "scalarized array observed calling same function multiple times" begin
@variables x(t) y(t)[1:2]
@parameters foo(::Real)[1:2]
val = Ref(0)
function _tmp_fn2(x)
val[] += 1
return [x, 2x]
end
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
@test length(equations(sys)) == 1
@test length(observed(sys)) == 3
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
@test val[] == 1
end

0 comments on commit d1f61f1

Please sign in to comment.