Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add symbolic tstops support to ODESystem #3219

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ for prop in [:eqs
:split_idxs
:parent
:is_dde
:tstops
:index_cache
:is_scalar_noise
:isscheduled]
Expand Down Expand Up @@ -1377,6 +1378,14 @@ function namespace_initialization_equations(
map(eq -> namespace_equation(eq, sys; ivs), eqs)
end

function namespace_tstops(sys::AbstractSystem)
tstops = symbolic_tstops(sys)
isempty(tstops) && return tstops
map(tstops) do val
namespace_expr(val, sys)
end
end

function namespace_equation(eq::Equation,
sys,
n = nameof(sys);
Expand Down Expand Up @@ -1632,6 +1641,14 @@ function initialization_equations(sys::AbstractSystem)
end
end

function symbolic_tstops(sys::AbstractSystem)
tstops = get_tstops(sys)
systems = get_systems(sys)
isempty(systems) && return tstops
tstops = [tstops; reduce(vcat, namespace_tstops.(get_systems(sys)); init = [])]
return tstops
end

function preface(sys::AbstractSystem)
has_preface(sys) || return nothing
pre = get_preface(sys)
Expand Down
49 changes: 47 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,39 @@ function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
DAEFunctionExpr{true}(sys, args...; kwargs...)
end

struct SymbolicTstops{F}
fn::F
end

function (st::SymbolicTstops)(p, tspan)
unique!(sort!(reduce(vcat, st.fn(p..., tspan...))))
end

function SymbolicTstops(
sys::AbstractSystem; eval_expression = false, eval_module = @__MODULE__)
tstops = symbolic_tstops(sys)
isempty(tstops) && return nothing
t0 = gensym(:t0)
t1 = gensym(:t1)
tstops = map(tstops) do val
if is_array_of_symbolics(val) || val isa AbstractArray
collect(val)
else
term(:, t0, unwrap(val), t1; type = AbstractArray{Real})
end
end
rps = reorder_parameters(sys, parameters(sys))
tstops, _ = build_function(tstops,
rps...,
t0,
t1;
expression = Val{true},
wrap_code = wrap_array_vars(sys, tstops; dvs = nothing) .∘
wrap_parameter_dependencies(sys, false))
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
return SymbolicTstops(tstops)
end

"""
```julia
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down Expand Up @@ -817,6 +850,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
kwargs1 = merge(kwargs1, (callback = cbs,))
end

tstops = SymbolicTstops(sys; eval_expression, eval_module)
if tstops !== nothing
kwargs1 = merge(kwargs1, (; tstops))
end

return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
Expand All @@ -843,7 +881,7 @@ end
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
parammap = DiffEqBase.NullParameters();
warn_initialize_determined = true,
check_length = true, kwargs...) where {iip}
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
end
Expand All @@ -856,8 +894,15 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
differential_vars = map(Base.Fix2(in, diffvars), sts)
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)

tstops = SymbolicTstops(sys; eval_expression, eval_module)
if tstops !== nothing
kwargs1 = merge(kwargs1, (; tstops))
end

DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
kwargs...)
kwargs..., kwargs1...)
end

function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
Expand Down
15 changes: 11 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ struct ODESystem <: AbstractODESystem
"""
is_dde::Bool
"""
A list of points to provide to the solver as tstops. Uses the same syntax as discrete
events.
"""
tstops::Vector{Any}
"""
Cache for intermediate tearing state.
"""
tearing_state::Any
Expand Down Expand Up @@ -187,7 +192,7 @@ struct ODESystem <: AbstractODESystem
connector_type, preface, cevents,
devents, parameter_dependencies,
metadata = nothing, gui_metadata = nothing, is_dde = false,
tearing_state = nothing,
tstops = [], tearing_state = nothing,
substitutions = nothing, complete = false, index_cache = nothing,
discrete_subsystems = nothing, solved_unknowns = nothing,
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
Expand All @@ -206,7 +211,7 @@ struct ODESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, metadata,
gui_metadata, is_dde, tearing_state, substitutions, complete, index_cache,
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
discrete_subsystems, solved_unknowns, split_idxs, parent)
end
end
Expand All @@ -233,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
checks = true,
metadata = nothing,
gui_metadata = nothing,
is_dde = nothing)
is_dde = nothing,
tstops = [])
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
Expand Down Expand Up @@ -299,7 +305,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies,
metadata, gui_metadata, is_dde, checks = checks)
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; kwargs...)
Expand Down Expand Up @@ -402,6 +408,7 @@ function flatten(sys::ODESystem, noeqs = false)
description = description(sys),
initialization_eqs = initialization_equations(sys),
is_dde = is_dde(sys),
tstops = symbolic_tstops(sys),
checks = false)
end
end
Expand Down
28 changes: 28 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1524,3 +1524,31 @@ end
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
@test sol[x]≈sol[y^2 - sum(p)] atol=1e-5
end

@testset "Symbolic tstops" begin
@variables x(t) = 1.0
@parameters p=0.15 q=0.25 r[1:2]=[0.35, 0.45]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * t + sum(r)], t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
prob = ODEProblem(sys, [], (0.0, 5.0))
sol = solve(prob)
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
prob2 = remake(prob; tspan = (0.0, 10.0))
sol2 = solve(prob2)
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)

@variables y(t) [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * x + q * t + sum(r), y^3 ~ 2x + 1],
t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
prob = DAEProblem(
sys, [D(y) => 2D(x) / 3y^2, D(x) => p * x + q * t + sum(r)], [], (0.0, 5.0))
sol = solve(prob, DImplicitEuler())
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
prob2 = remake(prob; tspan = (0.0, 10.0))
sol2 = solve(prob2, DImplicitEuler())
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
end
Loading