From 016406276b5e9c9f3a24f273f4d25608afdc33df Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Tue, 30 Jul 2024 17:03:26 -0700 Subject: [PATCH] Support more of the SciMLBase events API --- src/systems/callbacks.jl | 188 +++++++++++++++++++++++++------- test/symbolic_events.jl | 228 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 375 insertions(+), 41 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 3fe1f7f006..247a5f9ffe 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -76,11 +76,44 @@ end #################################### continuous events ##################################### const NULL_AFFECT = Equation[] +""" + SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind) + +A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq` +as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied. +By default `affect_neg = affect`; to only get rising edges specify `affect_neg = nothing`. + +Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`. +For simplicty, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`. +A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`. +Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a +sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not +gauranteed to be triggered. + +Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used +is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved +into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become +active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules. + +The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be +triggered iff an edge is detected `prev_sign > 0`. + +Affects (i.e. `affect` and `affect_neg`) can be specified as either: +* A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations. +* A tuple `(f!, unknowns, read_parameters, modified_parameters, ctx)`, where: + + `f!` is a function with signature `(integ, u, p, ctx)` that is called with the integrator, a state *index* vector `u` derived from `unknowns`, a parameter *index* vector `p` derived from `read_parameters`, and the `ctx` that was given at construction time. Note that `ctx` is aliased between instances. + + `unknowns` is a vector of symbolic unknown variables and optionally their aliases (e.g. if the model was defined with `@variables x(t)` then a valid value for `unknowns` would be `[x]`). A variable can be aliased with a pair `x => :y`. The indices of these `unknowns` will be passed to `f!` in `u` in a named tuple; in the earlier example, if we pass `[x]` as `unknowns` then `f!` can access `x` as `integ.u[u.x]`. If no alias is specified the name of the index will be the symbol version of the variable name. + + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`. + + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition. + + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem. +""" struct SymbolicContinuousCallback eqs::Vector{Equation} affect::Union{Vector{Equation}, FunctionalAffect} - function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT) - new(eqs, make_affect(affect)) + affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing} + rootfind::SciMLBase.RootfindOpt + function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind=SciMLBase.LeftRootFind) + new(eqs, make_affect(affect), make_affect(affect_neg), rootfind) end # Default affect to nothing end make_affect(affect) = affect @@ -88,12 +121,14 @@ make_affect(affect::Tuple) = FunctionalAffect(affect...) make_affect(affect::NamedTuple) = FunctionalAffect(; affect...) function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) - isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) + isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind) end Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) function Base.hash(cb::SymbolicContinuousCallback, s::UInt) s = foldr(hash, cb.eqs, init = s) - cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s) + s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s) + s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : hash(cb.affect_neg, s) + hash(cb.rootfind, s) end to_equation_vector(eq::Equation) = [eq] @@ -108,6 +143,8 @@ function SymbolicContinuousCallback(args...) end # wrap eq in vector SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough +SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback([eqs], affect, affect_neg, rootfind) +SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback(eqs, affect, affect_neg, rootfind) SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs @@ -130,12 +167,20 @@ function affects(cbs::Vector{SymbolicContinuousCallback}) mapreduce(affects, vcat, cbs, init = Equation[]) end +affect_negs(cb::SymbolicContinuousCallback) = cb.affect_neg +function affect_negs(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(affect_negs, vcat, cbs, init = Equation[]) +end + namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) +namespace_affects(::Nothing, s) = nothing function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback - SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)), - namespace_affects(affects(cb), s)) + SymbolicContinuousCallback( + namespace_equation.(equations(cb), (s,)), + namespace_affects(affects(cb), s), + namespace_affects(affect_negs(cb), s)) end """ @@ -159,7 +204,7 @@ function continuous_events(sys::AbstractSystem) filter(!isempty, cbs) end -#################################### continuous events ##################################### +#################################### discrete events ##################################### struct SymbolicDiscreteCallback # condition can be one of: @@ -461,12 +506,34 @@ function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sy isempty(cbs) && return nothing generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...) end +""" +Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to +generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback. +""" +function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dvs = unknowns(sys), + ps = full_parameters(sys); kwargs...) + if !isequal(eq.lhs, 0) + eq = 0 ~ eq.lhs - eq.rhs + end + + rf_oop, rf_ip = generate_custom_function(sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...) + affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs) + cond = function (u, t, integ) + if DiffEqBase.isinplace(integ.sol.prob) + tmp, = DiffEqBase.get_tmp_cache(integ) + rf_ip(tmp, u, parameter_values(integ), t) + tmp[1] + else + rf_oop(u, parameter_values(integ), t) + end + end + return ContinuousCallback(cond, affect_function.affect, affect_function.affect_neg, rootfind=cb.rootfind) +end -function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), - ps = full_parameters(sys); kwargs...) +function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), + ps = full_parameters(sys); rootfind=SciMLBase.RightRootFind, kwargs...) eqs = map(cb -> flatten_equations(cb.eqs), cbs) num_eqs = length.(eqs) - (isempty(eqs) || sum(num_eqs) == 0) && return nothing # fuse equations to create VectorContinuousCallback eqs = reduce(vcat, eqs) # rewrite all equations as 0 ~ interesting stuff @@ -476,45 +543,85 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow end rhss = map(x -> x.rhs, eqs) - root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss)))) - - rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false}, - kwargs...) + _, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false}, kwargs...) - affect_functions = map(cbs) do cb # Keep affect function separate - eq_aff = affects(cb) - affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...) + affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs] + cond = function (out, u, t, integ) + rf_ip(out, u, parameter_values(integ), t) end - if length(eqs) == 1 - cond = function (u, t, integ) - if DiffEqBase.isinplace(integ.sol.prob) - tmp, = DiffEqBase.get_tmp_cache(integ) - rf_ip(tmp, u, parameter_values(integ), t) - tmp[1] - else - rf_oop(u, parameter_values(integ), t) + # since there may be different number of conditions and affects, + # we build a map that translates the condition eq. number to the affect number + eq_ind2affect = reduce(vcat, + [fill(i, num_eqs[i]) for i in eachindex(affect_functions)]) + @assert length(eq_ind2affect) == length(eqs) + @assert maximum(eq_ind2affect) == length(affect_functions) + + affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect + function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations + affect_functions[eq_ind2affect[eq_ind]].affect(integ) + end + end + affect_neg = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect + function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations + affect_neg = affect_functions[eq_ind2affect[eq_ind]].affect_neg + if isnothing(affect_neg) + return # skip if the neg function doesn't exist - don't want to split this into a separate VCC because that'd break ordering end + affect_neg(integ) end - ContinuousCallback(cond, affect_functions[]) + end + return VectorContinuousCallback(cond, affect, affect_neg, length(eqs), rootfind=rootfind) +end + +""" +Compile a single continous callback affect function(s). +""" +function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) + eq_aff = affects(cb) + eq_neg_aff = affect_negs(cb) + affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...) + if eq_neg_aff === eq_aff + affect_neg = affect + elseif isnothing(eq_neg_aff) + affect_neg = nothing else - cond = function (out, u, t, integ) - rf_ip(out, u, parameter_values(integ), t) + affect_neg = compile_affect(eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...) + end + (affect=affect, affect_neg=affect_neg) +end + +function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), + ps = full_parameters(sys); kwargs...) + eqs = map(cb -> flatten_equations(cb.eqs), cbs) + num_eqs = length.(eqs) + total_eqs = sum(num_eqs) + (isempty(eqs) || total_eqs == 0) && return nothing + if total_eqs == 1 + # find the callback with only one eq + cb_ind = findfirst(>(0), num_eqs) + if isnothing(cb_ind) + error("Inconsistent state in affect compilation; one equation but no callback with equations?") end + cb = cbs[cb_ind] + return generate_single_rootfinding_callback(cb.eqs[], cb, sys, dvs, ps; kwargs...) + end - # since there may be different number of conditions and affects, - # we build a map that translates the condition eq. number to the affect number - eq_ind2affect = reduce(vcat, - [fill(i, num_eqs[i]) for i in eachindex(affect_functions)]) - @assert length(eq_ind2affect) == length(eqs) - @assert maximum(eq_ind2affect) == length(affect_functions) + # group the cbs by what rootfind op they use + # groupby would be very useful here, but alas + cb_classes = Dict{@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}() + for cb in cbs + push!(get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind=cb.rootfind, )), cb) + end - affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect - function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations - affect_functions[eq_ind2affect[eq_ind]](integ) - end - end - VectorContinuousCallback(cond, affect, length(eqs)) + # generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order + compiled_callbacks = map(collect(pairs(sort!(OrderedDict(cb_classes); by=p->p.rootfind)))) do (equiv_class, cbs_in_class) + return generate_vector_rootfinding_callback(cbs_in_class, sys, dvs, ps; rootfind=equiv_class.rootfind, kwargs...) + end + if length(compiled_callbacks) == 1 + return compiled_callbacks[] + else + return CallbackSet(compiled_callbacks...) end end @@ -528,7 +635,6 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...) ps_ind = Dict(reverse(en) for en in enumerate(ps)) p_inds = map(sym -> ps_ind[sym], parameters(affect)) end - # HACK: filter out eliminated symbols. Not clear this is the right thing to do # (MTK should keep these symbols) u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |> diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index c9b4946d30..3332ceda0d 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -5,6 +5,7 @@ using ModelingToolkit: SymbolicContinuousCallback, t_nounits as t, D_nounits as D using StableRNGs +import SciMLBase using SymbolicIndexingInterface rng = StableRNG(12345) @@ -12,6 +13,7 @@ rng = StableRNG(12345) eqs = [D(x) ~ 1] affect = [x ~ 0] +affect_neg = [x ~ 1] ## Test SymbolicContinuousCallback @testset "SymbolicContinuousCallback constructors" begin @@ -19,31 +21,43 @@ affect = [x ~ 0] @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs, NULL_AFFECT) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs[], NULL_AFFECT) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs => NULL_AFFECT) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs[] => NULL_AFFECT) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == NULL_AFFECT + @test e.affect_neg == NULL_AFFECT + @test e.rootfind == SciMLBase.LeftRootFind ## With affect @@ -51,32 +65,126 @@ affect = [x ~ 0] @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs, affect) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs, affect) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs[], affect) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs => affect) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind e = SymbolicContinuousCallback(eqs[] => affect) @test e isa SymbolicContinuousCallback @test isequal(e.eqs, eqs) @test e.affect == affect + @test e.affect_neg == affect + @test e.rootfind == SciMLBase.LeftRootFind + # with only positive edge affect + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=nothing) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test isnothing(e.affect_neg) + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs, affect, affect_neg=nothing) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test isnothing(e.affect_neg) + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs, affect, affect_neg=nothing) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test isnothing(e.affect_neg) + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=nothing) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test isnothing(e.affect_neg) + @test e.rootfind == SciMLBase.LeftRootFind + + # with explicit edge affects + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=affect_neg) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs, affect, affect_neg=affect_neg) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs, affect, affect_neg=affect_neg) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=affect_neg) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.LeftRootFind + + # with different root finding ops + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=affect_neg, rootfind=SciMLBase.LeftRootFind) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.LeftRootFind + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=affect_neg, rootfind=SciMLBase.RightRootFind) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.RightRootFind + + e = SymbolicContinuousCallback(eqs[], affect, affect_neg=affect_neg, rootfind=SciMLBase.NoRootFind) + @test e isa SymbolicContinuousCallback + @test isequal(e.eqs, eqs) + @test e.affect == affect + @test e.affect_neg == affect_neg + @test e.rootfind == SciMLBase.NoRootFind # test plural constructor e = SymbolicContinuousCallbacks(eqs[]) @@ -605,3 +713,123 @@ let @test sol[1, 6] < 1.0 # test whether x(t) decreases over time @test sol[1, 18] > 0.5 # test whether event happened end + +@testset "Additional SymbolicContinuousCallback options" begin + # baseline affect (pos + neg + left root find) + @variables c1(t)=1.0 c2(t)=1.0 # c1 = cos(t), c2 = cos(3t) + eqs = [D(c1) ~ -sin(t); D(c2) ~ -3*sin(3*t)] + record_crossings(i, u, _, c) = push!(c, i.t => i.u[u.v]) + cr1 = []; cr2 = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1)) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2)) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5()) + required_crossings_c1 = [π/2, 3*π/2] + required_crossings_c2 = [π/6, π/2, 5*π/6, 7*π/6, 3*π/2, 11*π/6] + @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 + @test maximum(abs.(first.(cr2) .- required_crossings_c2)) < 1e-4 + @test sign.(cos.(required_crossings_c1 .- 1e-6)) == sign.(last.(cr1)) + @test sign.(cos.(3*(required_crossings_c2 .- 1e-6))) == sign.(last.(cr2)) + + # with neg affect (pos * neg + left root find) + cr1p = []; cr2p = [] + cr1n = []; cr2n = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1p); affect_neg = (record_crossings, [c1 => :v], [], [], cr1n)) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2p); affect_neg = (record_crossings, [c2 => :v], [], [], cr2n)) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5(); dtmax = 0.01) + c1_pc = filter((<=)(0) ∘ sin, required_crossings_c1) + c1_nc = filter((>=)(0) ∘ sin, required_crossings_c1) + c2_pc = filter(c -> -sin(3c) > 0, required_crossings_c2) + c2_nc = filter(c -> -sin(3c) < 0, required_crossings_c2) + @test maximum(abs.(c1_pc .- first.(cr1p))) < 1e-5 + @test maximum(abs.(c1_nc .- first.(cr1n))) < 1e-5 + @test maximum(abs.(c2_pc .- first.(cr2p))) < 1e-5 + @test maximum(abs.(c2_nc .- first.(cr2n))) < 1e-5 + @test sign.(cos.(c1_pc .- 1e-6)) == sign.(last.(cr1p)) + @test sign.(cos.(c1_nc .- 1e-6)) == sign.(last.(cr1n)) + @test sign.(cos.(3*(c2_pc .- 1e-6))) == sign.(last.(cr2p)) + @test sign.(cos.(3*(c2_nc .- 1e-6))) == sign.(last.(cr2n)) + + # with nothing neg affect (pos * neg + left root find) + cr1p = []; cr2p = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1p); affect_neg = nothing) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2p); affect_neg = nothing) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5(); dtmax = 0.01) + @test maximum(abs.(c1_pc .- first.(cr1p))) < 1e-5 + @test maximum(abs.(c2_pc .- first.(cr2p))) < 1e-5 + @test sign.(cos.(c1_pc .- 1e-6)) == sign.(last.(cr1p)) + @test sign.(cos.(3*(c2_pc .- 1e-6))) == sign.(last.(cr2p)) + + + #mixed + cr1p = []; cr2p = [] + cr1n = []; cr2n = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1p); affect_neg = nothing) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2p); affect_neg = (record_crossings, [c2 => :v], [], [], cr2n)) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5(); dtmax = 0.01) + c1_pc = filter((<=)(0) ∘ sin, required_crossings_c1) + c2_pc = filter(c -> -sin(3c) > 0, required_crossings_c2) + c2_nc = filter(c -> -sin(3c) < 0, required_crossings_c2) + @test maximum(abs.(c1_pc .- first.(cr1p))) < 1e-5 + @test maximum(abs.(c2_pc .- first.(cr2p))) < 1e-5 + @test maximum(abs.(c2_nc .- first.(cr2n))) < 1e-5 + @test sign.(cos.(c1_pc .- 1e-6)) == sign.(last.(cr1p)) + @test sign.(cos.(3*(c2_pc .- 1e-6))) == sign.(last.(cr2p)) + @test sign.(cos.(3*(c2_nc .- 1e-6))) == sign.(last.(cr2n)) + + + # baseline affect w/ right rootfind (pos + neg + right root find) + @variables c1(t)=1.0 c2(t)=1.0 # c1 = cos(t), c2 = cos(3t) + cr1 = []; cr2 = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1); rootfind=SciMLBase.RightRootFind) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2); rootfind=SciMLBase.RightRootFind) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5()) + required_crossings_c1 = [π/2, 3*π/2] + required_crossings_c2 = [π/6, π/2, 5*π/6, 7*π/6, 3*π/2, 11*π/6] + @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 + @test maximum(abs.(first.(cr2) .- required_crossings_c2)) < 1e-4 + @test sign.(cos.(required_crossings_c1 .+ 1e-6)) == sign.(last.(cr1)) + @test sign.(cos.(3*(required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2)) + + + + # baseline affect w/ mixed rootfind (pos + neg + right root find) + cr1 = []; cr2 = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1); rootfind=SciMLBase.LeftRootFind) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2); rootfind=SciMLBase.RightRootFind) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5()) + @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 + @test maximum(abs.(first.(cr2) .- required_crossings_c2)) < 1e-4 + @test sign.(cos.(required_crossings_c1 .- 1e-6)) == sign.(last.(cr1)) + @test sign.(cos.(3*(required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2)) + + #flip order and ensure results are okay + cr1 = []; cr2 = [] + evt1 = ModelingToolkit.SymbolicContinuousCallback([c1 ~ 0], (record_crossings, [c1 => :v], [], [], cr1); rootfind=SciMLBase.LeftRootFind) + evt2 = ModelingToolkit.SymbolicContinuousCallback([c2 ~ 0], (record_crossings, [c2 => :v], [], [], cr2); rootfind=SciMLBase.RightRootFind) + @named trigsys = ODESystem(eqs, t; continuous_events = [evt2, evt1]) + trigsys_ss = structural_simplify(trigsys) + prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) + sol = solve(prob, Tsit5()) + @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 + @test maximum(abs.(first.(cr2) .- required_crossings_c2)) < 1e-4 + @test sign.(cos.(required_crossings_c1 .- 1e-6)) == sign.(last.(cr1)) + @test sign.(cos.(3*(required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2)) +end \ No newline at end of file