From 10253638d71ad9101f113b0dcaf1828e43615541 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 24 Oct 2024 20:16:14 +0530 Subject: [PATCH] feat: extend CSE hack to non-observed equations --- .../symbolics_tearing.jl | 45 ++++++++++++++++--- src/systems/nonlinear/initializesystem.jl | 24 +++++++--- test/structural_transformation/utils.jl | 22 +++++++++ 3 files changed, 79 insertions(+), 12 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index cabf0415ea..a854acb9b1 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -584,7 +584,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end @set! sys.unknowns = unknowns - obs, subeqs = cse_and_array_hacks( + obs, subeqs, deps = cse_and_array_hacks( obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack) @set! sys.eqs = neweqs @@ -637,10 +637,7 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = vars!(all_vars, rhs) # HACK 1 - if cse && - (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && - iscall(rhs) && operation(rhs) === getindex && - Symbolics.shape(rhs) != Symbolics.Unknown() + if cse && is_getindexed_array(rhs) rhs_arr = arguments(rhs)[1] if !haskey(rhs_to_tempvar, rhs_arr) tempvar = gensym(Symbol(lhs)) @@ -677,6 +674,33 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = arr_obs_occurrences[arg1] = cnt + 1 continue end + + # Also do CSE for `equations(sys)` + if cse + for (i, eq) in enumerate(neweqs) + (; lhs, rhs) = eq + is_getindexed_array(rhs) || continue + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + push!(obs, tempeq) + push!(subeqs, tempeq) + end + # don't need getindex_wrapper, but do it anyway to know that this + # hack took place + neweq = lhs ~ getindex_wrapper( + rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + neweqs[i] = neweq + end + end + # count variables in unknowns if they are scalarized forms of variables # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` # is an observed equation. @@ -713,7 +737,16 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = # need to re-sort subeqs subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) - return obs, subeqs + deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) + for i in 1:length(subeqs)] + + return obs, subeqs, deps +end + +function is_getindexed_array(rhs) + (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && + iscall(rhs) && operation(rhs) === getindex && + Symbolics.shape(rhs) != Symbolics.Unknown() end # PART OF HACK 1 diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 6c7457e49b..eefe393acc 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -12,11 +12,10 @@ function generate_initializesystem(sys::ODESystem; algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), kwargs...) - trueobs = unhack_observed(observed(sys)) + trueobs, eqs = unhack_observed(observed(sys), equations(sys)) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup - eqs = equations(sys) idxs_diff = isdiffeq.(eqs) idxs_alge = .!idxs_diff @@ -329,11 +328,11 @@ end Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with initialization. """ -function unhack_observed(eqs::Vector{Equation}) +function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) subs = Dict() tempvars = Set() rm_idxs = Int[] - for (i, eq) in enumerate(eqs) + for (i, eq) in enumerate(obseqs) iscall(eq.rhs) || continue if operation(eq.rhs) == StructuralTransformations.change_origin push!(rm_idxs, i) @@ -347,14 +346,27 @@ function unhack_observed(eqs::Vector{Equation}) end for (i, eq) in enumerate(eqs) + iscall(eq.rhs) || continue + if operation(eq.rhs) == StructuralTransformations.getindex_wrapper + var, idxs = arguments(eq.rhs) + subs[eq.rhs] = var[idxs...] + push!(tempvars, var) + end + end + + for (i, eq) in enumerate(obseqs) if eq.lhs in tempvars subs[eq.lhs] = eq.rhs push!(rm_idxs, i) end end - eqs = eqs[setdiff(eachindex(eqs), rm_idxs)] - return map(eqs) do eq + obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] + obseqs = map(obseqs) do eq + fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + end + eqs = map(eqs) do eq fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) end + return obseqs, eqs end diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 04600a7a6b..2704559f72 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -85,6 +85,28 @@ end iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, StructuralTransformations.change_origin] end + + @testset "CSE hack in equations(sys)" begin + val[] = 0 + @variables z(t)[1:2] + @mtkbuild sys = ODESystem( + [D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t) + @test length(equations(sys)) == 5 + @test length(observed(sys)) == 2 + prob = ODEProblem( + sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 2 + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 5 + @test length(equations(isys)) == 2 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && + operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end + end end @testset "array and cse hacks can be disabled" begin