Skip to content

Commit

Permalink
feat: add symbolic tstops support to ODESystem
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 19, 2024
1 parent 57e1a43 commit 85ca1a9
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 6 deletions.
17 changes: 17 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ for prop in [:eqs
:split_idxs
:parent
:is_dde
:tstops
:index_cache
:is_scalar_noise
:isscheduled]
Expand Down Expand Up @@ -1377,6 +1378,14 @@ function namespace_initialization_equations(
map(eq -> namespace_equation(eq, sys; ivs), eqs)
end

function namespace_tstops(sys::AbstractSystem)
tstops = symbolic_tstops(sys)
isempty(tstops) && return tstops
map(tstops) do val
namespace_expr(val, sys)
end
end

function namespace_equation(eq::Equation,
sys,
n = nameof(sys);
Expand Down Expand Up @@ -1632,6 +1641,14 @@ function initialization_equations(sys::AbstractSystem)
end
end

function symbolic_tstops(sys::AbstractSystem)
tstops = get_tstops(sys)
systems = get_systems(sys)
isempty(systems) && return tstops
tstops = [tstops; reduce(vcat, namespace_tstops.(get_systems(sys)); init = [])]
return tstops
end

function preface(sys::AbstractSystem)
has_preface(sys) || return nothing
pre = get_preface(sys)
Expand Down
49 changes: 47 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,39 @@ function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
DAEFunctionExpr{true}(sys, args...; kwargs...)
end

struct SymbolicTstops{F}
fn::F
end

function (st::SymbolicTstops)(p, tspan)
unique!(sort!(reduce(vcat, st.fn(p..., tspan...))))
end

function SymbolicTstops(
sys::AbstractSystem; eval_expression = false, eval_module = @__MODULE__)
tstops = symbolic_tstops(sys)
isempty(tstops) && return nothing
t0 = gensym(:t0)
t1 = gensym(:t1)
tstops = map(tstops) do val
if is_array_of_symbolics(val) || val isa AbstractArray
collect(val)
else
term(:, t0, unwrap(val), t1; type = AbstractArray{Real})
end
end
rps = reorder_parameters(sys, parameters(sys))
tstops, _ = build_function(tstops,
rps...,
t0,
t1;
expression = Val{true},
wrap_code = wrap_array_vars(sys, tstops; dvs = nothing) .∘
wrap_parameter_dependencies(sys, false))
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
return SymbolicTstops(tstops)
end

"""
```julia
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down Expand Up @@ -817,6 +850,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
kwargs1 = merge(kwargs1, (callback = cbs,))
end

tstops = SymbolicTstops(sys; eval_expression, eval_module)
if tstops !== nothing
kwargs1 = merge(kwargs1, (; tstops))
end

return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
Expand All @@ -843,7 +881,7 @@ end
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
parammap = DiffEqBase.NullParameters();
warn_initialize_determined = true,
check_length = true, kwargs...) where {iip}
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
end
Expand All @@ -856,8 +894,15 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
differential_vars = map(Base.Fix2(in, diffvars), sts)
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)

tstops = SymbolicTstops(sys; eval_expression, eval_module)
if tstops !== nothing
kwargs1 = merge(kwargs1, (; tstops))
end

DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs...)
kwargs..., kwargs1...)
end

function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
Expand Down
15 changes: 11 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ struct ODESystem <: AbstractODESystem
"""
is_dde::Bool
"""
A list of points to provide to the solver as tstops. Uses the same syntax as discrete
events.
"""
tstops::Vector{Any}
"""
Cache for intermediate tearing state.
"""
tearing_state::Any
Expand Down Expand Up @@ -187,7 +192,7 @@ struct ODESystem <: AbstractODESystem
connector_type, preface, cevents,
devents, parameter_dependencies,
metadata = nothing, gui_metadata = nothing, is_dde = false,
tearing_state = nothing,
tstops = [], tearing_state = nothing,
substitutions = nothing, complete = false, index_cache = nothing,
discrete_subsystems = nothing, solved_unknowns = nothing,
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
Expand All @@ -206,7 +211,7 @@ struct ODESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, metadata,
gui_metadata, is_dde, tearing_state, substitutions, complete, index_cache,
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
discrete_subsystems, solved_unknowns, split_idxs, parent)
end
end
Expand All @@ -233,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
checks = true,
metadata = nothing,
gui_metadata = nothing,
is_dde = nothing)
is_dde = nothing,
tstops = [])
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
Expand Down Expand Up @@ -299,7 +305,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies,
metadata, gui_metadata, is_dde, checks = checks)
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; kwargs...)
Expand Down Expand Up @@ -402,6 +408,7 @@ function flatten(sys::ODESystem, noeqs = false)
description = description(sys),
initialization_eqs = initialization_equations(sys),
is_dde = is_dde(sys),
tstops = symbolic_tstops(sys),
checks = false)
end
end
Expand Down
28 changes: 28 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1524,3 +1524,31 @@ end
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
@test sol[x]sol[y^2 - sum(p)] atol=1e-5
end

@testset "Symbolic tstops" begin
@variables x(t) = 1.0
@parameters p=0.15 q=0.25 r[1:2]=[0.35, 0.45]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * t + sum(r)], t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
prob = ODEProblem(sys, [], (0.0, 5.0))
sol = solve(prob)
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
prob2 = remake(prob; tspan = (0.0, 10.0))
sol2 = solve(prob2)
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)

@variables y(t) [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * x + q * t + sum(r), y^3 ~ 2x + 1],
t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
prob = DAEProblem(
sys, [D(y) => 2D(x) / 3y^2, D(x) => p * x + q * t + sum(r)], [], (0.0, 5.0))
sol = solve(prob, DImplicitEuler())
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
prob2 = remake(prob; tspan = (0.0, 10.0))
sol2 = solve(prob2, DImplicitEuler())
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
end

0 comments on commit 85ca1a9

Please sign in to comment.