Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating BVProblem from ODESystem #3251

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
75 changes: 75 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,81 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
initializeprobpmap = initializeprobpmap)
end

"""
vyudu marked this conversation as resolved.
Show resolved Hide resolved
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```

Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
`ps` are used to set the order of the dependent variable and parameter vectors,
respectively. `u0map` should be used to specify the initial condition, or be a function returning an initial condition.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end

# Construct initial conditions.
_u0 = u0 isa Function ? u0(tspan[1]) : u0
vyudu marked this conversation as resolved.
Show resolved Hide resolved

# Define the boundary conditions.
bc = if iip
(residual, u, p, t) -> (residual .= u[1] - _u0)
vyudu marked this conversation as resolved.
Show resolved Hide resolved
else
(u, p, t) -> (u[1] - _u0)
end

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
end

get_callback(prob::BVProblem) = prob.kwargs[:callback]
vyudu marked this conversation as resolved.
Show resolved Hide resolved

"""
```julia
DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down
56 changes: 56 additions & 0 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using BoundaryValueDiffEq, OrdinaryDiffEq
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α = 7.5 β = 4. γ = 8. δ = 5.
@variables x(t) = 1. y(t) = 2.

eqs = [D(x) ~ α*x - β*x*y,
D(y) ~ -γ*y + δ*x*y]

u0map = [:x => 1., :y => 2.]
parammap = [:α => 7.5, :β => 4, :γ => 8., :δ => 5.]
tspan = (0., 10.)

@mtkbuild lotkavolterra = ODESystem(eqs, t)

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
sol = solve(bvp, MIRK4(), dt = 0.01);

bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)
sol2 = solve(bvp, MIRK4(), dt = 0.01);

op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
osol = solve(op, Vern9())

@test isapprox(sol.u[end],osol.u[end]; atol = 0.001)
@test isapprox(sol2.u[end],osol.u[end]; atol = 0.001)
@test sol.u[1] == [1., 2.]
@test sol2.u[1] == [1., 2.]

### Testing on pendulum

@parameters g = 9.81 L = 1.
@variables θ(t) = π/2

eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]

@mtkbuild pend = ODESystem(eqs, t)

u0map = [θ => π/2, D(θ) => π/2]
parammap = [:L => 2., :g => 9.81]
tspan = (0., 10.)

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
sol = solve(bvp, MIRK4(), dt = 0.01);

bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
sol2 = solve(bvp2, MIRK4(), dt = 0.01);

op = ODEProblem(pend, u0map, tspan, parammap)
osol = solve(op, Vern9())

@test sol.u[end] ≈ osol.u[end]
@test sol.u[1] == [π/2, π/2]
@test sol2.u[end] ≈ osol.u[end]
@test sol2.u[1] == [π/2, π/2]
Loading