Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support callable parameters provided to discretes list of callback #3237

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}

const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}

struct IndexCache
unknown_idx::UnknownIndexMap
# sym => (bufferidx, idx_in_buffer)
discrete_idx::Dict{BasicSymbolic, DiscreteIndex}
discrete_idx::Dict{SymbolicParam, DiscreteIndex}
# sym => (clockidx, idx_in_clockbuffer)
callback_to_clocks::Dict{Any, Vector{Int}}
tunable_idx::TunableIndexMap
Expand All @@ -56,13 +58,13 @@ struct IndexCache
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
symbol_to_variable::Dict{Symbol, SymbolicParam}
end

function IndexCache(sys::AbstractSystem)
unks = solved_unknowns(sys)
unk_idxs = UnknownIndexMap()
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
symbol_to_variable = Dict{Symbol, SymbolicParam}()

let idx = 1
for sym in unks
Expand Down Expand Up @@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()

function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
sym = unwrap(sym)
buf = get!(buffers, ctype, S())
push!(buf, sym)
end

disc_param_callbacks = Dict{BasicSymbolic, Set{Int}}()
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
events = vcat(continuous_events(sys), discrete_events(sys))
for (i, event) in enumerate(events)
discs = Set{BasicSymbolic}()
discs = Set{SymbolicParam}()
affs = affects(event)
if !(affs isa AbstractArray)
affs = [affs]
Expand All @@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
isequal(only(arguments(sym)), get_iv(sys))
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
push!(clocks, i)
else
elseif is_variable_floatingpoint(sym)
insert_by_type!(constant_buffers, sym, symtype(sym))
else
stype = symtype(sym)
if stype <: FnType
stype = fntype_to_function_type(stype)
end
insert_by_type!(nonnumeric_buffers, sym, stype)
end
end
end
clock_partitions = unique(collect(values(disc_param_callbacks)))
disc_symtypes = unique(symtype.(keys(disc_param_callbacks)))
disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes))
disc_syms_by_symtype = [BasicSymbolic[] for _ in disc_symtypes]
disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes]
for sym in keys(disc_param_callbacks)
push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym)
end
disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic}[] for _ in disc_symtypes]
disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes]
for (i, buffer) in enumerate(disc_syms_by_symtype)
for partition in clock_partitions
push!(disc_syms_by_symtype_by_partition[i],
[sym for sym in buffer if disc_param_callbacks[sym] == partition])
end
end
disc_idxs = Dict{BasicSymbolic, DiscreteIndex}()
disc_idxs = Dict{SymbolicParam, DiscreteIndex}()
callback_to_clocks = Dict{
Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition)
Expand Down Expand Up @@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
end
haskey(disc_idxs, p) && continue
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
insert_by_type!(
if ctype <: Real || ctype <: AbstractArray{<:Real}
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() &&
Expand Down
26 changes: 26 additions & 0 deletions test/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,29 @@ end
reorder_dimension_by_tunables!(dst, sys, src, [r, q, p]; dim = 2)
@test dst ≈ stack([vcat(4ones(4), 3ones(3), 1.0) for i in 1:5]; dims = 1)
end

mutable struct ParamTest
y::Any
end
(pt::ParamTest)(x) = pt.y - x
@testset "Issue#3215: Callable discrete parameter" begin
function update_affect!(integ, u, p, ctx)
integ.p[p.p_1].y = integ.t
end

tp1 = typeof(ParamTest(1))
@parameters (p_1::tp1)(..) = ParamTest(1)
@variables x(ModelingToolkit.t_nounits) = 0

event1 = [1.0, 2, 3] => (update_affect!, [], [p_1], [p_1], nothing)

@named sys = ODESystem([
ModelingToolkit.D_nounits(x) ~ p_1(x)
],
ModelingToolkit.t_nounits;
discrete_events = [event1]
)
ss = @test_nowarn complete(sys)
@test length(parameters(ss)) == 1
@test !is_timeseries_parameter(ss, p_1)
end
Loading