From 4ccd1d9f94f44584c92b8996f3b56b8e111923ea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 26 Jul 2024 16:09:26 +0530 Subject: [PATCH] fix: fix remaking scalarized array parameters with non-scalarized symbolic map --- src/systems/parameter_buffer.jl | 29 ++++++++++++++++++++++++++--- test/mtkparameters.jl | 17 +++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index e491344fcd..1cc944e1d9 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -506,15 +506,38 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va @set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf)) for buf in newbuf.nonnumeric) + syms = collect(keys(vals)) + vals = Dict{Any, Any}(vals) + for sym in syms + symbolic_type(sym) == ArraySymbolic() || continue + is_parameter(indp, sym) && continue + stype = symtype(unwrap(sym)) + stype <: AbstractArray || continue + Symbolics.shape(sym) == Symbolics.Unknown() && continue + for i in eachindex(sym) + vals[sym[i]] = vals[sym][i] + end + end + # If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill # down to an `AbstractSystem` using `symbolic_container`. We leverage this to get # the index cache. ic = get_index_cache(indp_to_system(indp)) for (p, val) in vals idx = parameter_index(indp, p) - validate_parameter_type(ic, p, idx, val) - _set_parameter_unchecked!( - newbuf, val, idx; update_dependent = false) + if idx !== nothing + validate_parameter_type(ic, p, idx, val) + _set_parameter_unchecked!( + newbuf, val, idx; update_dependent = false) + elseif symbolic_type(p) == ArraySymbolic() + for (i, j) in zip(eachindex(p), eachindex(val)) + pi = p[i] + idx = parameter_index(indp, pi) + validate_parameter_type(ic, pi, idx, val[j]) + _set_parameter_unchecked!( + newbuf, val[j], idx; update_dependent = false) + end + end end @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.( diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index d5e16bb071..30bbb27ede 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -290,3 +290,20 @@ end sys = complete(sys) @test_throws ["Could not evaluate", "b", "Missing", "2c"] MTKParameters(sys, [a => 1.0]) end + +@testset "Issue#3804" begin + @parameters k[1:4] + @variables (V(t))[1:2] + eqs = [ + D(V[1]) ~ k[1] - k[2] * V[1], + D(V[2]) ~ k[3] - k[4] * V[2] + ] + @mtkbuild osys_scal = ODESystem(eqs, t, [V[1], V[2]], [k[1], k[2], k[3], k[4]]) + + u0 = [V => [10.0, 20.0]] + ps_vec = [k => [2.0, 3.0, 4.0, 5.0]] + ps_scal = [k[1] => 1.0, k[2] => 2.0, k[3] => 3.0, k[4] => 4.0] + oprob_scal_scal = ODEProblem(osys_scal, u0, 1.0, ps_scal) + newoprob = remake(oprob_scal_scal; p = ps_vec) + @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] +end