diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 3d143d49fb..d8ff71c324 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -469,90 +469,6 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, initializeprobpmap = initializeprobpmap) end -""" -```julia -SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan, - parammap = DiffEqBase.NullParameters(); - version = nothing, tgrad = false, - jac = true, sparse = true, - simplify = false, - kwargs...) where {iip} -``` - -Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and -`ps` are used to set the order of the dependent variable and parameter vectors, -respectively. `u0map` should be used to specify the initial condition, or be a function returning an initial condition. -""" -function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...) - BVProblem{true}(sys, args...; kwargs...) -end - -function SciMLBase.BVProblem(sys::AbstractODESystem, - u0map::StaticArray, - args...; - kwargs...) - BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...) -end - -function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...) - BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) -end - -function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...) - BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) -end - -function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [], - tspan = get_tspan(sys), - parammap = DiffEqBase.NullParameters(); - version = nothing, tgrad = false, - callback = nothing, - check_length = true, - warn_initialize_determined = true, - eval_expression = false, - eval_module = @__MODULE__, - kwargs...) where {iip, specialize} - if !iscomplete(sys) - error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`") - end - - f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap; - t = tspan !== nothing ? tspan[1] : tspan, - check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...) - - cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) - kwargs = filter_kwargs(kwargs) - - kwargs1 = (;) - if cbs !== nothing - kwargs1 = merge(kwargs1, (callback = cbs,)) - end - - # Construct initial conditions. - _u0 = u0 isa Function ? u0(tspan[1]) : u0 - - # Define the boundary conditions. - bc = if iip - (residual, u, p, t) -> (residual .= u[1] .- _u0) - else - (u, p, t) -> (u[1] - _u0) - end - - return BVProblem{iip}(f, bc, _u0, tspan, p; kwargs1..., kwargs...) -end - -get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.") - -@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing, - ::Val{1}, ::Val{dims}, elems...) where {dims} - [elems...] -end - -@inline function SymbolicUtils.Code.create_array( - ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims} - T[elems...] -end - """ ```julia DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys), @@ -943,6 +859,87 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = end get_callback(prob::ODEProblem) = prob.kwargs[:callback] +""" +```julia +SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan, + parammap = DiffEqBase.NullParameters(); + version = nothing, tgrad = false, + jac = true, sparse = true, + simplify = false, + kwargs...) where {iip} +``` + +Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and +`ps` are used to set the order of the dependent variable and parameter vectors, +respectively. `u0map` should be used to specify the initial condition. +""" +function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...) + BVProblem{true}(sys, args...; kwargs...) +end + +function SciMLBase.BVProblem(sys::AbstractODESystem, + u0map::StaticArray, + args...; + kwargs...) + BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...) +end + +function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...) + BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...) + BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end + +function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [], + tspan = get_tspan(sys), + parammap = DiffEqBase.NullParameters(); + version = nothing, tgrad = false, + callback = nothing, + check_length = true, + warn_initialize_determined = true, + eval_expression = false, + eval_module = @__MODULE__, + kwargs...) where {iip, specialize} + if !iscomplete(sys) + error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`") + end + + f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap; + t = tspan !== nothing ? tspan[1] : tspan, + check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...) + + cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) + kwargs = filter_kwargs(kwargs) + + kwargs1 = (;) + if cbs !== nothing + kwargs1 = merge(kwargs1, (callback = cbs,)) + end + + # Define the boundary conditions. + bc = if iip + (residual, u, p, t) -> (residual .= u[1] .- u0) + else + (u, p, t) -> (u[1] - u0) + end + + return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...) +end + +get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.") + +@inline function SymbolicUtils.Code.create_array(::Type{<:Base.ReinterpretArray}, ::Nothing, + ::Val{1}, ::Val{dims}, elems...) where {dims} + [elems...] +end + +@inline function SymbolicUtils.Code.create_array( + ::Type{<:Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims} + T[elems...] +end + """ ```julia DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan, diff --git a/test/bvproblem.jl b/test/bvproblem.jl index 1072874917..c5a302147d 100644 --- a/test/bvproblem.jl +++ b/test/bvproblem.jl @@ -10,8 +10,8 @@ solvers = [MIRK4, RadauIIa5, LobattoIIIa3] eqs = [D(x) ~ α * x - β * x * y, D(y) ~ -γ * y + δ * x * y] -u0map = [:x => 1.0, :y => 2.0] -parammap = [:α => 7.5, :β => 4, :γ => 8.0, :δ => 5.0] +u0map = [x => 1.0, y => 2.0] +parammap = [α => 7.5, β => 4, γ => 8.0, δ => 5.0] tspan = (0.0, 10.0) @mtkbuild lotkavolterra = ODESystem(eqs, t)