diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 1e3490ee69..ab0dd08764 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic, Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} +const SymbolicParam = Union{BasicSymbolic, CallWithMetadata} + struct IndexCache unknown_idx::UnknownIndexMap # sym => (bufferidx, idx_in_buffer) - discrete_idx::Dict{BasicSymbolic, DiscreteIndex} + discrete_idx::Dict{SymbolicParam, DiscreteIndex} # sym => (clockidx, idx_in_clockbuffer) callback_to_clocks::Dict{Any, Vector{Int}} tunable_idx::TunableIndexMap @@ -56,13 +58,13 @@ struct IndexCache tunable_buffer_size::BufferTemplate constant_buffer_sizes::Vector{BufferTemplate} nonnumeric_buffer_sizes::Vector{BufferTemplate} - symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}} + symbol_to_variable::Dict{Symbol, SymbolicParam} end function IndexCache(sys::AbstractSystem) unks = solved_unknowns(sys) unk_idxs = UnknownIndexMap() - symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}() + symbol_to_variable = Dict{Symbol, SymbolicParam}() let idx = 1 for sym in unks @@ -95,7 +97,7 @@ function IndexCache(sys::AbstractSystem) tunable_buffers = Dict{Any, Set{BasicSymbolic}}() constant_buffers = Dict{Any, Set{BasicSymbolic}}() - nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}() + nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}() function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S} sym = unwrap(sym) @@ -103,10 +105,10 @@ function IndexCache(sys::AbstractSystem) push!(buf, sym) end - disc_param_callbacks = Dict{BasicSymbolic, Set{Int}}() + disc_param_callbacks = Dict{SymbolicParam, Set{Int}}() events = vcat(continuous_events(sys), discrete_events(sys)) for (i, event) in enumerate(events) - discs = Set{BasicSymbolic}() + discs = Set{SymbolicParam}() affs = affects(event) if !(affs isa AbstractArray) affs = [affs] @@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem) isequal(only(arguments(sym)), get_iv(sys)) clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym) push!(clocks, i) - else + elseif is_variable_floatingpoint(sym) insert_by_type!(constant_buffers, sym, symtype(sym)) + else + stype = symtype(sym) + if stype <: FnType + stype = fntype_to_function_type(stype) + end + insert_by_type!(nonnumeric_buffers, sym, stype) end end end clock_partitions = unique(collect(values(disc_param_callbacks))) disc_symtypes = unique(symtype.(keys(disc_param_callbacks))) disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes)) - disc_syms_by_symtype = [BasicSymbolic[] for _ in disc_symtypes] + disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes] for sym in keys(disc_param_callbacks) push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym) end - disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic}[] for _ in disc_symtypes] + disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes] for (i, buffer) in enumerate(disc_syms_by_symtype) for partition in clock_partitions push!(disc_syms_by_symtype_by_partition[i], [sym for sym in buffer if disc_param_callbacks[sym] == partition]) end end - disc_idxs = Dict{BasicSymbolic, DiscreteIndex}() + disc_idxs = Dict{SymbolicParam, DiscreteIndex}() callback_to_clocks = Dict{ Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}() for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition) @@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem) end haskey(disc_idxs, p) && continue haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue + haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() && diff --git a/test/index_cache.jl b/test/index_cache.jl index c479075797..455203d759 100644 --- a/test/index_cache.jl +++ b/test/index_cache.jl @@ -92,3 +92,29 @@ end reorder_dimension_by_tunables!(dst, sys, src, [r, q, p]; dim = 2) @test dst ≈ stack([vcat(4ones(4), 3ones(3), 1.0) for i in 1:5]; dims = 1) end + +mutable struct ParamTest + y::Any +end +(pt::ParamTest)(x) = pt.y - x +@testset "Issue#3215: Callable discrete parameter" begin + function update_affect!(integ, u, p, ctx) + integ.p[p.p_1].y = integ.t + end + + tp1 = typeof(ParamTest(1)) + @parameters (p_1::tp1)(..) = ParamTest(1) + @variables x(ModelingToolkit.t_nounits) = 0 + + event1 = [1.0, 2, 3] => (update_affect!, [], [p_1], [p_1], nothing) + + @named sys = ODESystem([ + ModelingToolkit.D_nounits(x) ~ p_1(x) + ], + ModelingToolkit.t_nounits; + discrete_events = [event1] + ) + ss = @test_nowarn complete(sys) + @test length(parameters(ss)) == 1 + @test !is_timeseries_parameter(ss, p_1) +end