Skip to content

Commit

Permalink
fix: better handling of (possibly scalarized) array parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Aug 2, 2024
1 parent 323380f commit 7291dc8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 30 deletions.
94 changes: 65 additions & 29 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,27 +241,52 @@ function wrap_array_vars(
end
# tunables are scalarized and concatenated, so we need to have assignments
# for the non-scalarized versions
array_tunables = Dict{Any, AbstractArray{Int}}()
for p in ps
idx = parameter_index(sys, p)
idx isa ParameterIndex || continue
idx.portion isa SciMLStructures.Tunable || continue
idx.idx isa AbstractArray || continue
array_tunables[p] = idx.idx
end
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
# Other parameters may be scalarized arrays but used in the vector form
other_array_parameters = Assignment[]
other_array_parameters = Dict{Any, Any}()

if ps isa Tuple && eltype(ps) <: AbstractArray
ps = Iterators.flatten(ps)
end
for p in ps
p = unwrap(p)
if iscall(p) && operation(p) == getindex
p = arguments(p)[1]
end
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
scal = collect(p)
# all scalarized variables are in `ps`
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
if Symbolics.isarraysymbolic(p)
idx === nothing || continue
push!(other_array_parameters, p collect(p))
elseif iscall(p) && operation(p) == getindex
idx === nothing && continue
# all of the scalarized variables are in `ps`
all(x -> any(isequal(x), ps), collect(p))|| continue
push!(other_array_parameters, p collect(p))
idx isa Int && continue
if idx isa ParameterIndex
if idx.portion != SciMLStructures.Tunable()
continue
end
idxs = vec(idx.idx)
sz = size(idx.idx)
else
# idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
idxs = map(x -> x.idx, idxs)
end
if !all(x -> x isa Int, idxs)
other_array_parameters[p] = scal
continue
end

sz = size(idxs)
if vec(idxs) == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
idxs = idxs[begin]:-1:idxs[end]
end
idxs = vec(idxs)
end
array_tunables[p] = (idxs, sz)
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
Expand All @@ -276,9 +301,10 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_tunables],
other_array_parameters
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand All @@ -293,8 +319,10 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_tunables]
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand All @@ -309,8 +337,10 @@ function wrap_array_vars(
vcat(
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(view($(expr.args[uind + 2].name), $v))
for (k, v) in array_tunables]
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand Down Expand Up @@ -499,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
return unwrap(sym) in 1:length(parameter_symbols(sys))
end
return any(isequal(sym), parameter_symbols(sys)) ||
hasname(sym) && is_parameter(sys, getname(sym))
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return is_parameter(ic, sym)
end

named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
named_parameters = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
return any(isequal(sym), named_parameters) ||
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
count(isequal(sym),
Expand Down Expand Up @@ -543,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
return sym
end
idx = findfirst(isequal(sym), parameter_symbols(sys))
if idx === nothing && hasname(sym)
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
idx = parameter_index(sys, getname(sym))
end
return idx
Expand All @@ -559,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
return idx
end
end
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
pnames = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
idx = findfirst(isequal(sym), pnames)
if idx !== nothing
return idx
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
end
return nothing
end
Expand Down
3 changes: 2 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = outputidxs
end

_ps = ps
ps = reorder_parameters(sys, ps)
if checkvars
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand All @@ -440,7 +441,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps),
wrap_array_vars(sys, rhss; dvs, ps = _ps),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand Down
14 changes: 14 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1250,3 +1250,17 @@ end
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
@test prob[x] == prob[y] == prob[z] == 1.0
end

@testset "Scalarized parameters in array functions" begin
@variables u(t)[1:2] x(t)[1:2] o(t)[1:2]
@parameters p[1:2, 1:2] [tunable = false]
@named sys = ODESystem(
[D(u) ~ (sum(u) + sum(x) + sum(p) + sum(o)) * x, o ~ prod(u) * x],
t, [u..., x..., o...], [p...])
sys1, = structural_simplify(sys, ([x...], []))
fn1, = ModelingToolkit.generate_function(sys1; expression = Val{false})
@test_nowarn fn1(ones(4), 2ones(2), 3ones(2, 2), 4.0)
sys2, = structural_simplify(sys, ([x...], []); split = false)
fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false})
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
end

0 comments on commit 7291dc8

Please sign in to comment.