From 697048a6e582f00cd9ed269df91f37ba3b131bc4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Oct 2024 20:33:59 +0530 Subject: [PATCH] feat: add simple CSE for array scalarization case --- .../symbolics_tearing.jl | 55 ++++++++++++++++++- test/structural_transformation/utils.jl | 18 +++++- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 889b75c611..b9c30a0bc1 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end @set! sys.unknowns = unknowns + # HACK: Since we don't support array equations, any equation of the sort + # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly + # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically + # for this case to avoid the unnecessary cost. + # This and the below hack are implemented simultaneously + + # mapping of rhs to temporary CSE variable + # `f(...) => tmpvar` in above example + rhs_to_tempvar = Dict() + # HACK: Add equations for array observed variables. If `p[i] ~ (...)` # are equations, add an equation `p ~ [p[1], p[2], ...]` # allow topsort to reorder them @@ -597,9 +607,43 @@ function tearing_reassemble(state::TearingState, var_eq_matching, arr_obs_occurrences = Dict() # to check if array variables occur in unscalarized form anywhere all_vars = Set() - for eq in obs - vars!(all_vars, eq.rhs) + for (i, eq) in enumerate(obs) lhs = eq.lhs + rhs = eq.rhs + vars!(all_vars, rhs) + + # HACK 1 + if iscall(rhs) && operation(rhs) === getindex && + Symbolics.shape(rhs) != Symbolics.Unknown() + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + # ideally we would like to do this: + push!(obs, tempeq) + push!(subeqs, tempeq) + # and let topsort_equations handle it, but that treats `x` and `x[1]` + # as different variables and thus doesn + end + + # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different, + # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr` + # which fails the topological sort + neweq = lhs ~ getindex_wrapper(rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + obs[i] = neweq + subeqi = findfirst(isequal(eq), subeqs) + if subeqi !== nothing + subeqs[subeqi] = neweq + end + end + # end HACK 1 + iscall(lhs) || continue operation(lhs) === getindex || continue Symbolics.shape(lhs) != Symbolics.Unknown() || continue @@ -640,6 +684,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end append!(obs, obs_arr_eqs) append!(subeqs, obs_arr_eqs) + # need to re-sort subeqs subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) @@ -659,6 +704,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching, return invalidate_cache!(sys) end +# PART OF HACK 1 +getindex_wrapper(x, i) = x[i...] + +@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}}) + +# PART OF HACK 2 function change_origin(origin, arr) return origin(arr) end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index c7146bab65..dded8333f2 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -48,9 +48,25 @@ end @mtkbuild sys = ODESystem( [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) @test length(equations(sys)) == 1 - @test length(observed(sys)) == 6 + @test length(observed(sys)) == 7 @test any(eq -> isequal(eq.lhs, y), observed(sys)) @test any(eq -> isequal(eq.lhs, z), observed(sys)) prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) @test_nowarn prob.f(prob.u0, prob.p, 0.0) end + +@testset "scalarized array observed calling same function multiple times" begin + @variables x(t) y(t)[1:2] + @parameters foo(::Real)[1:2] + val = Ref(0) + function _tmp_fn2(x) + val[] += 1 + return [x, 2x] + end + @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) + @test length(equations(sys)) == 1 + @test length(observed(sys)) == 3 + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 1 +end