diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 84cee928cd..a854acb9b1 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -230,7 +230,7 @@ end =# function tearing_reassemble(state::TearingState, var_eq_matching, - full_var_eq_matching = nothing; simplify = false, mm = nothing) + full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true) @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure extra_vars = Int[] @@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching, # TODO: compute the dependency correctly so that we don't have to do this obs = [fast_substitute(observed(sys), obs_sub); subeqs] - # HACK: Substitute non-scalarized symbolic arrays of observed variables - # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations - # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled - # by the topological sorting and dependency identification pieces - obs_arr_subs = Dict() - - for eq in obs - lhs = eq.lhs - iscall(lhs) || continue - operation(lhs) === getindex || continue - Symbolics.shape(lhs) !== Symbolics.Unknown() || continue - arg1 = arguments(lhs)[1] - haskey(obs_arr_subs, arg1) && continue - obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]] - index_first = eachindex(arg1)[1] - - # respect non-1-indexed arrays - # TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency - obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1]) - end - for i in eachindex(neweqs) - neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator) - end - for i in eachindex(obs) - obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator) - end - for i in eachindex(subeqs) - subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator) - end - - @set! sys.eqs = neweqs - @set! sys.observed = obs - unknowns = Any[v for (i, v) in enumerate(fullvars) if diff_to_var[i] === nothing && ispresent(i)] @@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching, end end @set! sys.unknowns = unknowns + + obs, subeqs, deps = cse_and_array_hacks( + obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack) + + @set! sys.eqs = neweqs + @set! sys.observed = obs + @set! sys.substitutions = Substitutions(subeqs, deps) # Only makes sense for time-dependent @@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching, return invalidate_cache!(sys) end +""" +# HACK 1 + +Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]` +gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets +_very_ expensive. this hack performs a limited form of CSE specifically for this case to +avoid the unnecessary cost. This and the below hack are implemented simultaneously + +# HACK 2 + +Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an +equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation +if all `p[i]` are present and the unscalarized form is used in any equation (observed or +not) we first count the number of times the scalarized form of each observed variable +occurs in observed equations (and unknowns if it's split). +""" +function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true) + # HACK 1 + # mapping of rhs to temporary CSE variable + # `f(...) => tmpvar` in above example + rhs_to_tempvar = Dict() + + # HACK 2 + # map of array observed variable (unscalarized) to number of its + # scalarized terms that appear in observed equations + arr_obs_occurrences = Dict() + # to check if array variables occur in unscalarized form anywhere + all_vars = Set() + for (i, eq) in enumerate(obs) + lhs = eq.lhs + rhs = eq.rhs + vars!(all_vars, rhs) + + # HACK 1 + if cse && is_getindexed_array(rhs) + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + push!(obs, tempeq) + push!(subeqs, tempeq) + end + + # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different, + # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr` + # which fails the topological sort + neweq = lhs ~ getindex_wrapper( + rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + obs[i] = neweq + subeqi = findfirst(isequal(eq), subeqs) + if subeqi !== nothing + subeqs[subeqi] = neweq + end + end + # end HACK 1 + + array || continue + iscall(lhs) || continue + operation(lhs) === getindex || continue + Symbolics.shape(lhs) != Symbolics.Unknown() || continue + arg1 = arguments(lhs)[1] + cnt = get(arr_obs_occurrences, arg1, 0) + arr_obs_occurrences[arg1] = cnt + 1 + continue + end + + # Also do CSE for `equations(sys)` + if cse + for (i, eq) in enumerate(neweqs) + (; lhs, rhs) = eq + is_getindexed_array(rhs) || continue + rhs_arr = arguments(rhs)[1] + if !haskey(rhs_to_tempvar, rhs_arr) + tempvar = gensym(Symbol(lhs)) + N = length(rhs_arr) + tempvar = unwrap(Symbolics.variable( + tempvar; T = Symbolics.symtype(rhs_arr))) + tempvar = setmetadata( + tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr)) + tempeq = tempvar ~ rhs_arr + rhs_to_tempvar[rhs_arr] = tempvar + push!(obs, tempeq) + push!(subeqs, tempeq) + end + # don't need getindex_wrapper, but do it anyway to know that this + # hack took place + neweq = lhs ~ getindex_wrapper( + rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) + neweqs[i] = neweq + end + end + + # count variables in unknowns if they are scalarized forms of variables + # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)` + # is an observed equation. + for sym in unknowns + iscall(sym) || continue + operation(sym) === getindex || continue + Symbolics.shape(sym) != Symbolics.Unknown() || continue + arg1 = arguments(sym)[1] + cnt = get(arr_obs_occurrences, arg1, 0) + cnt == 0 && continue + arr_obs_occurrences[arg1] = cnt + 1 + end + for eq in neweqs + vars!(all_vars, eq.rhs) + end + obs_arr_eqs = Equation[] + for (arrvar, cnt) in arr_obs_occurrences + cnt == length(arrvar) || continue + arrvar in all_vars || continue + # firstindex returns 1 for multidimensional array symbolics + firstind = first(eachindex(arrvar)) + scal = [arrvar[i] for i in eachindex(arrvar)] + # respect non-1-indexed arrays + # TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency + # `change_origin` is required because `Origin(firstind)(scal)` makes codegen + # try to `create_array(OffsetArray{...}, ...)` which errors. + # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size` + # of `scal`. + push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal)) + end + append!(obs, obs_arr_eqs) + append!(subeqs, obs_arr_eqs) + + # need to re-sort subeqs + subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) + + deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) + for i in 1:length(subeqs)] + + return obs, subeqs, deps +end + +function is_getindexed_array(rhs) + (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) && + iscall(rhs) && operation(rhs) === getindex && + Symbolics.shape(rhs) != Symbolics.Unknown() +end + +# PART OF HACK 1 +getindex_wrapper(x, i) = x[i...] + +@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}}) + +# PART OF HACK 2 +function change_origin(origin, arr) + return origin(arr) +end + +@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin + size = size(arr) + eltype = eltype(arr) + ndims = ndims(arr) +end + function tearing(state::TearingState; kwargs...) state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...) complete!(state.structure) @@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu instead, which calls this function internally. """ function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing, - simplify = false, kwargs...) + simplify = false, cse_hack = true, array_hack = true, kwargs...) var_eq_matching, full_var_eq_matching = tearing(state) invalidate_cache!(tearing_reassemble( - state, var_eq_matching, full_var_eq_matching; mm, simplify)) + state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack)) end """ @@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that the system is balanced. """ function dummy_derivative(sys, state = TearingState(sys); simplify = false, - mm = nothing, kwargs...) + mm = nothing, cse_hack = true, array_hack = true, kwargs...) jac = let state = state (eqs, vars) -> begin symeqs = EquationsView(state)[eqs] @@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, end var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...) - tearing_reassemble(state, var_eq_matching; simplify, mm) + tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack) end diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index becace8ec5..eefe393acc 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -12,10 +12,10 @@ function generate_initializesystem(sys::ODESystem; algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), kwargs...) - vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)]) + trueobs, eqs = unhack_observed(observed(sys), equations(sys)) + vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup - eqs = equations(sys) idxs_diff = isdiffeq.(eqs) idxs_alge = .!idxs_diff @@ -24,7 +24,7 @@ function generate_initializesystem(sys::ODESystem; D = Differential(get_iv(sys)) diffmap = merge( Dict(eq.lhs => eq.rhs for eq in eqs_diff), - Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys)) + Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs) ) # 1) process dummy derivatives and u0map into initialization system @@ -166,15 +166,14 @@ function generate_initializesystem(sys::ODESystem; ) # 7) use observed equations for guesses of observed variables if not provided - obseqs = observed(sys) - for eq in obseqs + for eq in trueobs haskey(defs, eq.lhs) && continue any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue defs[eq.lhs] = eq.rhs end - eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,)) + eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,)) vars = [vars; collect(values(paramsubs))] for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) @@ -324,3 +323,50 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) return nothing, nothing, nothing, nothing end end + +""" +Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with +initialization. +""" +function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) + subs = Dict() + tempvars = Set() + rm_idxs = Int[] + for (i, eq) in enumerate(obseqs) + iscall(eq.rhs) || continue + if operation(eq.rhs) == StructuralTransformations.change_origin + push!(rm_idxs, i) + continue + end + if operation(eq.rhs) == StructuralTransformations.getindex_wrapper + var, idxs = arguments(eq.rhs) + subs[eq.rhs] = var[idxs...] + push!(tempvars, var) + end + end + + for (i, eq) in enumerate(eqs) + iscall(eq.rhs) || continue + if operation(eq.rhs) == StructuralTransformations.getindex_wrapper + var, idxs = arguments(eq.rhs) + subs[eq.rhs] = var[idxs...] + push!(tempvars, var) + end + end + + for (i, eq) in enumerate(obseqs) + if eq.lhs in tempvars + subs[eq.lhs] = eq.rhs + push!(rm_idxs, i) + end + end + + obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] + obseqs = map(obseqs) do eq + fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + end + eqs = map(eqs) do eq + fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + end + return obseqs, eqs +end diff --git a/src/utils.jl b/src/utils.jl index d2e8a3ea38..830ec98e44 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -389,7 +389,13 @@ function vars!(vars, O; op = Differential) f = getcalledparameter(O) push!(vars, f) for arg in arguments(O) - vars!(vars, arg; op) + if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray + for el in arg + vars!(vars, unwrap(el); op) + end + else + vars!(vars, arg; op) + end end return vars end @@ -397,7 +403,7 @@ function vars!(vars, O; op = Differential) end if symbolic_type(O) == NotSymbolic() && O isa AbstractArray for arg in O - vars!(vars, arg; op) + vars!(vars, unwrap(arg); op) end return vars end diff --git a/test/odesystem.jl b/test/odesystem.jl index 27ceafd210..9286fb3c01 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1448,6 +1448,14 @@ end @test_nowarn ODESystem(Equation[], t; parameter_dependencies = [p ~ 1.0], name = :a) end +@testset "Variable discovery in arrays of `Num` inside callable symbolic" begin + @variables x(t) y(t) + @parameters foo(::AbstractVector) + sys = @test_nowarn ODESystem(D(x) ~ foo([x, 2y]), t; name = :sys) + @test length(unknowns(sys)) == 2 + @test any(isequal(y), unknowns(sys)) +end + @testset "Inplace observed" begin @variables x(t) @parameters p[1:2] q diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 8644d96945..2704559f72 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -40,3 +40,115 @@ end @test ModelingToolkit.đť‘ neighbors(g, 1) == [2] @test ModelingToolkit.đť‘‘neighbors(g, 2) == [1] end + +@testset "array observed used unscalarized in another observed" begin + @variables x(t) y(t)[1:2] z(t)[1:2] + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @mtkbuild sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + @test length(equations(sys)) == 1 + @test length(observed(sys)) == 7 + @test any(eq -> isequal(eq.lhs, y), observed(sys)) + @test any(eq -> isequal(eq.lhs, z), observed(sys)) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 5 + @test length(equations(isys)) == 4 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end +end + +@testset "scalarized array observed calling same function multiple times" begin + @variables x(t) y(t)[1:2] + @parameters foo(::Real)[1:2] + val = Ref(0) + function _tmp_fn2(x) + val[] += 1 + return [x, 2x] + end + @mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t) + @test length(equations(sys)) == 1 + @test length(observed(sys)) == 3 + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 1 + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 3 + @test length(equations(isys)) == 2 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end + + @testset "CSE hack in equations(sys)" begin + val[] = 0 + @variables z(t)[1:2] + @mtkbuild sys = ODESystem( + [D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t) + @test length(equations(sys)) == 5 + @test length(observed(sys)) == 2 + prob = ODEProblem( + sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2]) + @test_nowarn prob.f(prob.u0, prob.p, 0.0) + @test val[] == 2 + + isys = ModelingToolkit.generate_initializesystem(sys) + @test length(unknowns(isys)) == 5 + @test length(equations(isys)) == 2 + @test !any(equations(isys)) do eq + iscall(eq.rhs) && + operation(eq.rhs) in [StructuralTransformations.getindex_wrapper, + StructuralTransformations.change_origin] + end + end +end + +@testset "array and cse hacks can be disabled" begin + @testset "fully_determined = true" begin + @variables x(t) y(t)[1:2] z(t)[1:2] + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @named sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + + sys1 = structural_simplify(sys; cse_hack = false) + @test length(observed(sys1)) == 6 + @test !any(observed(sys1)) do eq + iscall(eq.rhs) && + operation(eq.rhs) == StructuralTransformations.getindex_wrapper + end + + sys2 = structural_simplify(sys; array_hack = false) + @test length(observed(sys2)) == 5 + @test !any(observed(sys2)) do eq + iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin + end + end + + @testset "fully_determined = false" begin + @variables x(t) y(t)[1:2] z(t)[1:2] w(t) + @parameters foo(::AbstractVector)[1:2] + _tmp_fn(x) = 2x + @named sys = ODESystem( + [D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t) + + sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false) + @test length(observed(sys1)) == 6 + @test !any(observed(sys1)) do eq + iscall(eq.rhs) && + operation(eq.rhs) == StructuralTransformations.getindex_wrapper + end + + sys2 = structural_simplify(sys; array_hack = false, fully_determined = false) + @test length(observed(sys2)) == 5 + @test !any(observed(sys2)) do eq + iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin + end + end +end