Skip to content

Commit

Permalink
Support more of the SciMLBase events API
Browse files Browse the repository at this point in the history
  • Loading branch information
BenChung committed Jul 31, 2024
1 parent c2e6e4a commit 0164062
Show file tree
Hide file tree
Showing 2 changed files with 375 additions and 41 deletions.
188 changes: 147 additions & 41 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,59 @@ end
#################################### continuous events #####################################

const NULL_AFFECT = Equation[]
"""
SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind)
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
By default `affect_neg = affect`; to only get rising edges specify `affect_neg = nothing`.
Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`.
For simplicty, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`.

Check warning on line 87 in src/systems/callbacks.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"simplicty" should be "simplicity".
A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`.
Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a
sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not
gauranteed to be triggered.

Check warning on line 91 in src/systems/callbacks.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"gauranteed" should be "guaranteed".
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved
into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become
active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules.
The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be
triggered iff an edge is detected `prev_sign > 0`.
Affects (i.e. `affect` and `affect_neg`) can be specified as either:
* A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations.
* A tuple `(f!, unknowns, read_parameters, modified_parameters, ctx)`, where:
+ `f!` is a function with signature `(integ, u, p, ctx)` that is called with the integrator, a state *index* vector `u` derived from `unknowns`, a parameter *index* vector `p` derived from `read_parameters`, and the `ctx` that was given at construction time. Note that `ctx` is aliased between instances.
+ `unknowns` is a vector of symbolic unknown variables and optionally their aliases (e.g. if the model was defined with `@variables x(t)` then a valid value for `unknowns` would be `[x]`). A variable can be aliased with a pair `x => :y`. The indices of these `unknowns` will be passed to `f!` in `u` in a named tuple; in the earlier example, if we pass `[x]` as `unknowns` then `f!` can access `x` as `integ.u[u.x]`. If no alias is specified the name of the index will be the symbol version of the variable name.
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
"""
struct SymbolicContinuousCallback
eqs::Vector{Equation}
affect::Union{Vector{Equation}, FunctionalAffect}
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
new(eqs, make_affect(affect))
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
rootfind::SciMLBase.RootfindOpt
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind=SciMLBase.LeftRootFind)
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
end # Default affect to nothing
end
make_affect(affect) = affect
make_affect(affect::Tuple) = FunctionalAffect(affect...)
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)

function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
end
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
s = foldr(hash, cb.eqs, init = s)
cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : hash(cb.affect_neg, s)
hash(cb.rootfind, s)
end

to_equation_vector(eq::Equation) = [eq]
Expand All @@ -108,6 +143,8 @@ function SymbolicContinuousCallback(args...)
end # wrap eq in vector
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback([eqs], affect, affect_neg, rootfind)
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback(eqs, affect, affect_neg, rootfind)

SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
Expand All @@ -130,12 +167,20 @@ function affects(cbs::Vector{SymbolicContinuousCallback})
mapreduce(affects, vcat, cbs, init = Equation[])
end

affect_negs(cb::SymbolicContinuousCallback) = cb.affect_neg
function affect_negs(cbs::Vector{SymbolicContinuousCallback})
mapreduce(affect_negs, vcat, cbs, init = Equation[])
end

namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
namespace_affects(::Nothing, s) = nothing

function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s))
SymbolicContinuousCallback(
namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s),
namespace_affects(affect_negs(cb), s))
end

"""
Expand All @@ -159,7 +204,7 @@ function continuous_events(sys::AbstractSystem)
filter(!isempty, cbs)
end

#################################### continuous events #####################################
#################################### discrete events #####################################

struct SymbolicDiscreteCallback
# condition can be one of:
Expand Down Expand Up @@ -461,12 +506,34 @@ function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sy
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
end
"""
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
"""
function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
end

rf_oop, rf_ip = generate_custom_function(sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs)
cond = function (u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, parameter_values(integ), t)
tmp[1]
else
rf_oop(u, parameter_values(integ), t)
end
end
return ContinuousCallback(cond, affect_function.affect, affect_function.affect_neg, rootfind=cb.rootfind)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); rootfind=SciMLBase.RightRootFind, kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
# fuse equations to create VectorContinuousCallback
eqs = reduce(vcat, eqs)
# rewrite all equations as 0 ~ interesting stuff
Expand All @@ -476,45 +543,85 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
end

rhss = map(x -> x.rhs, eqs)
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))

rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
kwargs...)
_, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false}, kwargs...)

affect_functions = map(cbs) do cb # Keep affect function separate
eq_aff = affects(cb)
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs]
cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ), t)
end

if length(eqs) == 1
cond = function (u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, parameter_values(integ), t)
tmp[1]
else
rf_oop(u, parameter_values(integ), t)
# since there may be different number of conditions and affects,
# we build a map that translates the condition eq. number to the affect number
eq_ind2affect = reduce(vcat,
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
@assert length(eq_ind2affect) == length(eqs)
@assert maximum(eq_ind2affect) == length(affect_functions)

affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_functions[eq_ind2affect[eq_ind]].affect(integ)
end
end
affect_neg = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_neg = affect_functions[eq_ind2affect[eq_ind]].affect_neg
if isnothing(affect_neg)
return # skip if the neg function doesn't exist - don't want to split this into a separate VCC because that'd break ordering
end
affect_neg(integ)
end
ContinuousCallback(cond, affect_functions[])
end
return VectorContinuousCallback(cond, affect, affect_neg, length(eqs), rootfind=rootfind)
end

"""
Compile a single continous callback affect function(s).

Check warning on line 578 in src/systems/callbacks.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"continous" should be "continuous".
"""
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
eq_aff = affects(cb)
eq_neg_aff = affect_negs(cb)
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
if eq_neg_aff === eq_aff
affect_neg = affect
elseif isnothing(eq_neg_aff)
affect_neg = nothing
else
cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ), t)
affect_neg = compile_affect(eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
end
(affect=affect, affect_neg=affect_neg)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
total_eqs = sum(num_eqs)
(isempty(eqs) || total_eqs == 0) && return nothing
if total_eqs == 1
# find the callback with only one eq
cb_ind = findfirst(>(0), num_eqs)
if isnothing(cb_ind)
error("Inconsistent state in affect compilation; one equation but no callback with equations?")
end
cb = cbs[cb_ind]
return generate_single_rootfinding_callback(cb.eqs[], cb, sys, dvs, ps; kwargs...)
end

# since there may be different number of conditions and affects,
# we build a map that translates the condition eq. number to the affect number
eq_ind2affect = reduce(vcat,
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
@assert length(eq_ind2affect) == length(eqs)
@assert maximum(eq_ind2affect) == length(affect_functions)
# group the cbs by what rootfind op they use
# groupby would be very useful here, but alas
cb_classes = Dict{@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
for cb in cbs
push!(get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind=cb.rootfind, )), cb)
end

affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_functions[eq_ind2affect[eq_ind]](integ)
end
end
VectorContinuousCallback(cond, affect, length(eqs))
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
compiled_callbacks = map(collect(pairs(sort!(OrderedDict(cb_classes); by=p->p.rootfind)))) do (equiv_class, cbs_in_class)
return generate_vector_rootfinding_callback(cbs_in_class, sys, dvs, ps; rootfind=equiv_class.rootfind, kwargs...)
end
if length(compiled_callbacks) == 1
return compiled_callbacks[]
else
return CallbackSet(compiled_callbacks...)
end
end

Expand All @@ -528,7 +635,6 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
ps_ind = Dict(reverse(en) for en in enumerate(ps))
p_inds = map(sym -> ps_ind[sym], parameters(affect))
end

# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>
Expand Down
Loading

0 comments on commit 0164062

Please sign in to comment.