Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
vyudu committed Dec 20, 2024
1 parent b2bf7c0 commit 3751c2a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 86 deletions.
165 changes: 81 additions & 84 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,90 +469,6 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
initializeprobpmap = initializeprobpmap)
end

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```
Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
`ps` are used to set the order of the dependent variable and parameter vectors,
respectively. `u0map` should be used to specify the initial condition, or be a function returning an initial condition.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end

# Construct initial conditions.
_u0 = u0 isa Function ? u0(tspan[1]) : u0

# Define the boundary conditions.
bc = if iip
(residual, u, p, t) -> (residual .= u[1] .- _u0)
else
(u, p, t) -> (u[1] - _u0)
end

return BVProblem{iip}(f, bc, _u0, tspan, p; kwargs1..., kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
::Val{1}, ::Val{dims}, elems...) where {dims}
[elems...]
end

@inline function SymbolicUtils.Code.create_array(
::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
T[elems...]
end

"""
```julia
DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -943,6 +859,87 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```
Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
`ps` are used to set the order of the dependent variable and parameter vectors,
respectively. `u0map` should be used to specify the initial condition.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end

# Define the boundary conditions.
bc = if iip
(residual, u, p, t) -> (residual .= u[1] .- u0)
else
(u, p, t) -> (u[1] - u0)
end

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing,
::Val{1}, ::Val{dims}, elems...) where {dims}
[elems...]
end

@inline function SymbolicUtils.Code.create_array(
::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
T[elems...]
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
Expand Down
4 changes: 2 additions & 2 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
eqs = [D(x) ~ α * x - β * x * y,
D(y) ~ -γ * y + δ * x * y]

u0map = [:x => 1.0, :y => 2.0]
parammap = [:α => 7.5, :β => 4, :γ => 8.0, :δ => 5.0]
u0map = [x => 1.0, y => 2.0]
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
tspan = (0.0, 10.0)

@mtkbuild lotkavolterra = ODESystem(eqs, t)
Expand Down

0 comments on commit 3751c2a

Please sign in to comment.