From 6af4c998b65c0f4f845b8a073c9db84474ffbe27 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 1 Aug 2024 15:18:19 +0530 Subject: [PATCH] fix: better handling of (possibly scalarized) array parameters --- src/systems/abstractsystem.jl | 76 ++++++++++++++++++++++++----------- test/odesystem.jl | 14 +++++++ 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e6d120bce4..9d31dd4c85 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -242,25 +242,42 @@ function wrap_array_vars( # tunables are scalarized and concatenated, so we need to have assignments # for the non-scalarized versions array_tunables = Dict{Any, AbstractArray{Int}}() - for p in ps - idx = parameter_index(sys, p) - idx isa ParameterIndex || continue - idx.portion isa SciMLStructures.Tunable || continue - idx.idx isa AbstractArray || continue - array_tunables[p] = idx.idx - end # Other parameters may be scalarized arrays but used in the vector form - other_array_parameters = Assignment[] + other_array_parameters = Dict{Any, Any}() + for p in ps + p = unwrap(p) + if iscall(p) && operation(p) == getindex + p = arguments(p)[1] + end + symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue + scal = collect(p) + # all scalarized variables are in `ps` + all(x -> any(isequal(x), ps), scal) || continue + (haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue + idx = parameter_index(sys, p) - if Symbolics.isarraysymbolic(p) - idx === nothing || continue - push!(other_array_parameters, p ← collect(p)) - elseif iscall(p) && operation(p) == getindex - idx === nothing && continue - # all of the scalarized variables are in `ps` - all(x -> any(isequal(x), ps), collect(p))|| continue - push!(other_array_parameters, p ← collect(p)) + if idx === nothing + idxs = map(Base.Fix1(parameter_index, sys), scal) + if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs) + idxs = map(x -> x.idx, idxs) + end + if all(x -> x isa Int, idxs) + if vec(idxs) == idxs[begin]:idxs[end] + idxs = reshape(idxs[begin]:idxs[end], size(idxs)) + elseif vec(idxs) == idxs[begin]:-1:idxs[end] + idxs = reshape(idxs[begin]:-1:idxs[end], size(idxs)) + end + array_tunables[p] = idxs + else + other_array_parameters[p] = scal + end + elseif idx isa Int + continue + elseif idx.portion != SciMLStructures.Tunable() + other_array_parameters[p] = scal + else + array_tunables[p] = idx.idx end end for (k, inds) in array_vars @@ -278,7 +295,8 @@ function wrap_array_vars( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], [k ← :(view($(expr.args[uind + 1].name), $v)) for (k, v) in array_tunables], - other_array_parameters + [k ← Code.MakeArray(v, typeof(v)) + for (k, v) in other_array_parameters] ), expr.body, false @@ -294,7 +312,9 @@ function wrap_array_vars( vcat( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], [k ← :(view($(expr.args[uind + 1].name), $v)) - for (k, v) in array_tunables] + for (k, v) in array_tunables], + [k ← Code.MakeArray(v, typeof(v)) + for (k, v) in other_array_parameters] ), expr.body, false @@ -310,7 +330,9 @@ function wrap_array_vars( [k ← :(view($(expr.args[uind + 1].name), $v)) for (k, v) in array_vars], [k ← :(view($(expr.args[uind + 2].name), $v)) - for (k, v) in array_tunables] + for (k, v) in array_tunables], + [k ← Code.MakeArray(v, typeof(v)) + for (k, v) in other_array_parameters] ), expr.body, false @@ -499,7 +521,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) return unwrap(sym) in 1:length(parameter_symbols(sys)) end return any(isequal(sym), parameter_symbols(sys)) || - hasname(sym) && is_parameter(sys, getname(sym)) + hasname(sym) && !(iscall(sym) && operation(sym) == getindex) && + is_parameter(sys, getname(sym)) end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol) @@ -507,7 +530,9 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol return is_parameter(ic, sym) end - named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)] + named_parameters = [getname(x) + for x in parameter_symbols(sys) + if hasname(x) && !(iscall(x) && operation(x) == getindex)] return any(isequal(sym), named_parameters) || count(NAMESPACE_SEPARATOR, string(sym)) == 1 && count(isequal(sym), @@ -543,7 +568,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) return sym end idx = findfirst(isequal(sym), parameter_symbols(sys)) - if idx === nothing && hasname(sym) + if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex) idx = parameter_index(sys, getname(sym)) end return idx @@ -559,13 +584,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym return idx end end - idx = findfirst(isequal(sym), getname.(parameter_symbols(sys))) + pnames = [getname(x) + for x in parameter_symbols(sys) + if hasname(x) && !(iscall(x) && operation(x) == getindex)] + idx = findfirst(isequal(sym), pnames) if idx !== nothing return idx elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1 return findfirst(isequal(sym), Symbol.( - nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys)))) + nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames)) end return nothing end diff --git a/test/odesystem.jl b/test/odesystem.jl index 0f675c49e7..7888a29f21 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1250,3 +1250,17 @@ end prob = ODEProblem(ssys, [], (0.0, 1.0), []) @test prob[x] == prob[y] == prob[z] == 1.0 end + +@testset "Scalarized parameters in array functions" begin + @variables u(t)[1:2] x(t)[1:2] o(t)[1:2] + @parameters p[1:2, 1:2] [tunable = false] + @named sys = ODESystem( + [D(u) ~ (sum(u) + sum(x) + sum(p) + sum(o)) * x, o ~ prod(u) * x], + t, [u..., x..., o...], [p...]) + sys1, = structural_simplify(sys, ([x...], [])) + fn1, = ModelingToolkit.generate_function(sys1; expression = Val{false}) + @test_nowarn fn1(ones(4), 2ones(2), 3ones(2, 2), 4.0) + sys2, = structural_simplify(sys, ([x...], []); split = false) + fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false}) + @test_nowarn fn2(ones(4), 2ones(6), 4.0) +end