Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: some fixes related to usage of array symbolics #3126

Merged
merged 8 commits into from
Oct 24, 2024
212 changes: 174 additions & 38 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ end
=#

function tearing_reassemble(state::TearingState, var_eq_matching,
full_var_eq_matching = nothing; simplify = false, mm = nothing)
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
@unpack fullvars, sys, structure = state
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
extra_vars = Int[]
Expand Down Expand Up @@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); subeqs]

# HACK: Substitute non-scalarized symbolic arrays of observed variables
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
# by the topological sorting and dependency identification pieces
obs_arr_subs = Dict()

for eq in obs
lhs = eq.lhs
iscall(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
arg1 = arguments(lhs)[1]
haskey(obs_arr_subs, arg1) && continue
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)] # e.g. p => [p[1], p[2]]
index_first = eachindex(arg1)[1]

# respect non-1-indexed arrays
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
obs_arr_subs[arg1] = Origin(index_first)(obs_arr_subs[arg1])
end
for i in eachindex(neweqs)
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(obs)
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(subeqs)
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
end

@set! sys.eqs = neweqs
@set! sys.observed = obs

unknowns = Any[v
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be good to refactor this into a function call that is able to be turned on/off.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I'll pull it out into a function and propagate a keyword argument to toggle it. Should it be on by default or off? I feel like HACK2 should be on by default, and CSE off since that's the more bug-prone of the two

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CSE one seems pretty essential to a lot of applications? I would try to see if we can get them on by default. But yes making an option would then make it much easier to isolate any potential bugs to it.

for (i, v) in enumerate(fullvars)
if diff_to_var[i] === nothing && ispresent(i)]
Expand All @@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end
end
@set! sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
Expand All @@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
return invalidate_cache!(sys)
end

"""
# HACK 1

Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
avoid the unnecessary cost. This and the below hack are implemented simultaneously

# HACK 2

Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
not) we first count the number of times the scalarized form of each observed variable
occurs in observed equations (and unknowns if it's split).
"""
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
# HACK 1
# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
rhs_to_tempvar = Dict()

# HACK 2
# map of array observed variable (unscalarized) to number of its
# scalarized terms that appear in observed equations
arr_obs_occurrences = Dict()
# to check if array variables occur in unscalarized form anywhere
all_vars = Set()
for (i, eq) in enumerate(obs)
lhs = eq.lhs
rhs = eq.rhs
vars!(all_vars, rhs)

# HACK 1
if cse && is_getindexed_array(rhs)
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
push!(subeqs, tempeq)
end

# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
# which fails the topological sort
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
obs[i] = neweq
subeqi = findfirst(isequal(eq), subeqs)
if subeqi !== nothing
subeqs[subeqi] = neweq
end
end
# end HACK 1

array || continue
iscall(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
arg1 = arguments(lhs)[1]
cnt = get(arr_obs_occurrences, arg1, 0)
arr_obs_occurrences[arg1] = cnt + 1
continue
end

# Also do CSE for `equations(sys)`
if cse
for (i, eq) in enumerate(neweqs)
(; lhs, rhs) = eq
is_getindexed_array(rhs) || continue
rhs_arr = arguments(rhs)[1]
if !haskey(rhs_to_tempvar, rhs_arr)
tempvar = gensym(Symbol(lhs))
N = length(rhs_arr)
tempvar = unwrap(Symbolics.variable(
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
push!(subeqs, tempeq)
end
# don't need getindex_wrapper, but do it anyway to know that this
# hack took place
neweq = lhs ~ getindex_wrapper(
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
neweqs[i] = neweq
end
end

# count variables in unknowns if they are scalarized forms of variables
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
# is an observed equation.
for sym in unknowns
iscall(sym) || continue
operation(sym) === getindex || continue
Symbolics.shape(sym) != Symbolics.Unknown() || continue
arg1 = arguments(sym)[1]
cnt = get(arr_obs_occurrences, arg1, 0)
cnt == 0 && continue
arr_obs_occurrences[arg1] = cnt + 1
end
for eq in neweqs
vars!(all_vars, eq.rhs)
end
obs_arr_eqs = Equation[]
for (arrvar, cnt) in arr_obs_occurrences
cnt == length(arrvar) || continue
arrvar in all_vars || continue
# firstindex returns 1 for multidimensional array symbolics
firstind = first(eachindex(arrvar))
scal = [arrvar[i] for i in eachindex(arrvar)]
# respect non-1-indexed arrays
# TODO: get rid of this hack together with the above hack, then remove OffsetArrays dependency
# `change_origin` is required because `Origin(firstind)(scal)` makes codegen
# try to `create_array(OffsetArray{...}, ...)` which errors.
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
# of `scal`.
push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal))
end
append!(obs, obs_arr_eqs)
append!(subeqs, obs_arr_eqs)

# need to re-sort subeqs
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])

deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
for i in 1:length(subeqs)]

return obs, subeqs, deps
end

function is_getindexed_array(rhs)
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
iscall(rhs) && operation(rhs) === getindex &&
Symbolics.shape(rhs) != Symbolics.Unknown()
end

# PART OF HACK 1
getindex_wrapper(x, i) = x[i...]

@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})

# PART OF HACK 2
function change_origin(origin, arr)
return origin(arr)
end

@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
size = size(arr)
eltype = eltype(arr)
ndims = ndims(arr)
end

function tearing(state::TearingState; kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
Expand All @@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
instead, which calls this function internally.
"""
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
simplify = false, kwargs...)
simplify = false, cse_hack = true, array_hack = true, kwargs...)
var_eq_matching, full_var_eq_matching = tearing(state)
invalidate_cache!(tearing_reassemble(
state, var_eq_matching, full_var_eq_matching; mm, simplify))
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
end

"""
Expand All @@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
the system is balanced.
"""
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
mm = nothing, kwargs...)
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
jac = let state = state
(eqs, vars) -> begin
symeqs = EquationsView(state)[eqs]
Expand All @@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
end
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
kwargs...)
tearing_reassemble(state, var_eq_matching; simplify, mm)
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
end
58 changes: 52 additions & 6 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ function generate_initializesystem(sys::ODESystem;
algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff

Expand All @@ -24,7 +24,7 @@ function generate_initializesystem(sys::ODESystem;
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
)

# 1) process dummy derivatives and u0map into initialization system
Expand Down Expand Up @@ -166,15 +166,14 @@ function generate_initializesystem(sys::ODESystem;
)

# 7) use observed equations for guesses of observed variables if not provided
obseqs = observed(sys)
for eq in obseqs
for eq in trueobs
haskey(defs, eq.lhs) && continue
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue

defs[eq.lhs] = eq.rhs
end

eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,))
eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
vars = [vars; collect(values(paramsubs))]
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
Expand Down Expand Up @@ -324,3 +323,50 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
return nothing, nothing, nothing, nothing
end
end

"""
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
initialization.
"""
function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
subs = Dict()
tempvars = Set()
rm_idxs = Int[]
for (i, eq) in enumerate(obseqs)
iscall(eq.rhs) || continue
if operation(eq.rhs) == StructuralTransformations.change_origin
push!(rm_idxs, i)
continue
end
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
var, idxs = arguments(eq.rhs)
subs[eq.rhs] = var[idxs...]
push!(tempvars, var)
end
end

for (i, eq) in enumerate(eqs)
iscall(eq.rhs) || continue
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
var, idxs = arguments(eq.rhs)
subs[eq.rhs] = var[idxs...]
push!(tempvars, var)
end
end

for (i, eq) in enumerate(obseqs)
if eq.lhs in tempvars
subs[eq.lhs] = eq.rhs
push!(rm_idxs, i)
end
end

obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)]
obseqs = map(obseqs) do eq
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
end
eqs = map(eqs) do eq
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
end
return obseqs, eqs
end
10 changes: 8 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,21 @@ function vars!(vars, O; op = Differential)
f = getcalledparameter(O)
push!(vars, f)
for arg in arguments(O)
vars!(vars, arg; op)
if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
for el in arg
vars!(vars, unwrap(el); op)
end
else
vars!(vars, arg; op)
end
end
return vars
end
return push!(vars, O)
end
if symbolic_type(O) == NotSymbolic() && O isa AbstractArray
for arg in O
vars!(vars, arg; op)
vars!(vars, unwrap(arg); op)
end
return vars
end
Expand Down
8 changes: 8 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,14 @@ end
@test_nowarn ODESystem(Equation[], t; parameter_dependencies = [p ~ 1.0], name = :a)
end

@testset "Variable discovery in arrays of `Num` inside callable symbolic" begin
@variables x(t) y(t)
@parameters foo(::AbstractVector)
sys = @test_nowarn ODESystem(D(x) ~ foo([x, 2y]), t; name = :sys)
@test length(unknowns(sys)) == 2
@test any(isequal(y), unknowns(sys))
end

@testset "Inplace observed" begin
@variables x(t)
@parameters p[1:2] q
Expand Down
Loading
Loading