Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for OverrideInit and CheckInit #464

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "4.26.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -15,15 +17,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Accessors = "0.1.38"
ArrayInterface = "7.17.1"
CEnum = "0.5"
DataStructures = "0.18"
DiffEqBase = "6.154"
ModelingToolkit = "9.54"
PrecompileTools = "1"
Reexport = "1.0"
SciMLBase = "2.9"
SciMLBase = "2.63.1"
Sundials_jll = "5.2"
SymbolicIndexingInterface = "0.3.35"
julia = "1.9"

[extras]
Expand All @@ -34,9 +41,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit"]
test = ["Test", "AlgebraicMultigrid", "DiffEqCallbacks", "ODEProblemLibrary", "DAEProblemLibrary", "ForwardDiff", "SparseDiffTools", "SparseConnectivityTracer", "IncompleteLU", "ModelingToolkit", "SafeTestsets"]
5 changes: 5 additions & 0 deletions src/Sundials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ module Sundials
import Reexport
Reexport.@reexport using DiffEqBase
using SciMLBase: AbstractSciMLOperator
import Accessors: @reset
import ArrayInterface
import SymbolicIndexingInterface as SII
import SymbolicIndexingInterface: ParameterIndexingProxy
import DataStructures
import Logging
import DiffEqBase
Expand Down Expand Up @@ -81,6 +85,7 @@ include("common_interface/verbosity.jl")
include("common_interface/algorithms.jl")
include("common_interface/integrator_types.jl")
include("common_interface/integrator_utils.jl")
include("common_interface/initialize_dae.jl")
include("common_interface/solve.jl")

import PrecompileTools
Expand Down
78 changes: 78 additions & 0 deletions src/common_interface/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
struct SundialsDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initializealg = integrator.initializealg)
_initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
end

function _initialize_dae!(integrator::IDAIntegrator, prob,
initializealg::IDADefaultInit, isinplace)
if integrator.u_modified
IDAReinit!(integrator)
end
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
tstart, tend = integrator.sol.prob.tspan
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
end
if integrator.alg.init_all
init_type = IDA_Y_INIT
else
init_type = IDA_YA_YDP_INIT
integrator.flag = IDASetId(integrator.mem,
vec(integrator.sol.prob.differential_vars))
end
dt = integrator.dt == tstart ? tend : integrator.dt
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)

# Reflect consistent initial conditions back into the integrator's
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
end
if integrator.t == tstart && integrator.flag < 0
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

function _initialize_dae!(integrator, prob, ::SundialsDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
_initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
elseif integrator isa IDAIntegrator
_initialize_dae!(integrator, prob, IDADefaultInit(), isinplace)
end
end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.NoInit, isinplace) end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.OverrideInit, isinplace::Union{Val{true}, Val{false}})
nlsolve_alg = KINSOL()
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)

if isinplace === Val{true}()
integrator.u .= u0
if length(integrator.sol.u) == 1
integrator.sol.u[1] .= u0
end
else
integrator.u = u0
if length(integrator.sol.u) == 1
integrator.sol.u[1] = u0
end
end
integrator.p = p
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol

if !success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure)
end
end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.CheckInit, isinplace::Union{Val{true}, Val{false}})
SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; abstol = integrator.opts.abstol)
end
8 changes: 6 additions & 2 deletions src/common_interface/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ mutable struct CVODEIntegrator{N,
oType,
LStype,
Atype,
CallbackCacheType} <: AbstractSundialsIntegrator{algType}
CallbackCacheType,
IA} <: AbstractSundialsIntegrator{algType}
u::Array{Float64, N}
u_nvec::NVector
p::pType
Expand All @@ -66,6 +67,7 @@ mutable struct CVODEIntegrator{N,
vector_event_last_time::Int
callback_cache::CallbackCacheType
last_event_error::Float64
initializealg::IA
end

function (integrator::CVODEIntegrator)(t::Number,
Expand Down Expand Up @@ -96,7 +98,8 @@ mutable struct ARKODEIntegrator{N,
Atype,
MLStype,
Mtype,
CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE}
CallbackCacheType,
IA} <: AbstractSundialsIntegrator{ARKODE}
u::Array{Float64, N}
u_nvec::NVector
p::pType
Expand Down Expand Up @@ -124,6 +127,7 @@ mutable struct ARKODEIntegrator{N,
vector_event_last_time::Int
callback_cache::CallbackCacheType
last_event_error::Float64
initializealg::IA
end

function (integrator::ARKODEIntegrator)(t::Number,
Expand Down
41 changes: 5 additions & 36 deletions src/common_interface/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ end
@inline function Base.getproperty(integrator::AbstractSundialsIntegrator, sym::Symbol)
if sym == :dt
return integrator.t - integrator.tprev
elseif sym == :ps
return ParameterIndexingProxy(integrator)
else
return getfield(integrator, sym)
end
Expand All @@ -185,42 +187,6 @@ end
# Required for callbacks
DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing

DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing

struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
end

function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
initializealg::IDADefaultInit)
if integrator.u_modified
IDAReinit!(integrator)
end
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
tstart, tend = integrator.sol.prob.tspan
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
end
if integrator.alg.init_all
init_type = IDA_Y_INIT
else
init_type = IDA_YA_YDP_INIT
integrator.flag = IDASetId(integrator.mem,
vec(integrator.sol.prob.differential_vars))
end
dt = integrator.dt == tstart ? tend : integrator.dt
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)

# Reflect consistent initial conditions back into the integrator's
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
end
if integrator.t == tstart && integrator.flag < 0
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

DiffEqBase.has_reinit(integrator::AbstractSundialsIntegrator) = true
function DiffEqBase.reinit!(integrator::AbstractSundialsIntegrator,
u0 = integrator.sol.prob.u0;
Expand Down Expand Up @@ -294,3 +260,6 @@ DiffEqBase.get_tstops(integ::AbstractSundialsIntegrator) = integ.opts.tstops
DiffEqBase.get_tstops_array(integ::AbstractSundialsIntegrator) = get_tstops(integ).valtree
DiffEqBase.get_tstops_max(integ::AbstractSundialsIntegrator) =
maximum(get_tstops_array(integ))

# SII
SII.symbolic_container(integ::AbstractSundialsIntegrator) = integ.sol
14 changes: 10 additions & 4 deletions src/common_interface/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
stop_at_next_tstop = false,
userdata = nothing,
alias_u0 = false,
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, tupType, isinplace, Method, LinearSolver
}
tType = eltype(tupType)
Expand Down Expand Up @@ -457,7 +458,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
0,
1,
callback_cache,
0.0)
0.0,
initializealg)
DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator)
integrator
end # function solve
Expand Down Expand Up @@ -499,6 +502,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
stop_at_next_tstop = false,
userdata = nothing,
alias_u0 = false,
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, tupType, isinplace, Method,
LinearSolver,
MassLinearSolver}
Expand Down Expand Up @@ -945,8 +949,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
0,
1,
callback_cache,
0.0)
0.0,
initializealg)

DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator)
integrator
end # function solve
Expand Down Expand Up @@ -1010,7 +1016,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu
advance_to_tstop = false,
stop_at_next_tstop = false,
userdata = nothing,
initializealg = IDADefaultInit(),
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, duType, tupType, isinplace, LinearSolver
}
tType = eltype(tupType)
Expand Down Expand Up @@ -1313,7 +1319,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu
dutmp,
initializealg)

DiffEqBase.initialize_dae!(integrator, initializealg)
DiffEqBase.initialize_dae!(integrator)
integrator.u_modified && IDAReinit!(integrator)

if save_start
Expand Down
61 changes: 61 additions & 0 deletions test/common_interface/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using ModelingToolkit, SciMLBase, Sundials, Test
using SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D

@testset "ODE" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p = missing [guess = 1.0] q = missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t; initialization_eqs = [p ^2 + q^2 ~ 3, x^3 + y^3 ~ 5])

@testset "IIP: $iip" for iip in [true, false]
prob = ODEProblem{iip}(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])

@testset "$alg" for alg in [CVODE_BDF, CVODE_Adams, ARKODE]
integ = init(prob, alg())
@test integ.initializealg isa Sundials.SundialsDefaultInit
@test integ[x] ≈ 1.0
@test integ[y] ≈ cbrt(4)
@test integ.ps[p] ≈ 1.0
@test integ.ps[q] ≈ sqrt(2)
sol = solve(prob, alg())
@test SciMLBase.successful_retcode(sol)
@test sol[x, 1] ≈ 1.0
@test sol[y, 1] ≈ cbrt(4)
@test sol.ps[p] ≈ 1.0
@test sol.ps[q] ≈ sqrt(2)
end
end
end
@testset "DAE" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p = missing [guess = 1.0] q = missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, x^3 + y^3 ~ 5], t; initialization_eqs = [p ^2 + q^2 ~ 3])

@testset "DAEProblem{$iip}" for iip in [true, false]
prob = DAEProblem{iip}(sys, [D(x) => cbrt(4), D(y) => -1 / cbrt(4)], [x => 1.0], (0.0, 1.0), [p => 1.0])

@testset "OverrideInit" begin
integ = init(prob, IDA())
@test integ.initializealg isa Sundials.SundialsDefaultInit
@test integ[x] ≈ 1.0
@test integ[y] ≈ cbrt(4)
@test integ.ps[p] ≈ 1.0
@test integ.ps[q] ≈ sqrt(2)
sol = solve(prob, IDA())
@test SciMLBase.successful_retcode(sol)
@test sol[x, 1] ≈ 1.0
@test sol[y, 1] ≈ cbrt(4)
@test sol.ps[p] ≈ 1.0
@test sol.ps[q] ≈ sqrt(2)
end
@testset "CheckInit" begin
@test_throws SciMLBase.CheckInitFailureError init(prob, IDA(); initializealg = SciMLBase.CheckInit())
prob[x] = 1.0
prob[y] = cbrt(4)
prob.ps[p] = 1
prob.ps[q] = sqrt(2)
@test_nowarn init(prob, IDA(); initializealg = SciMLBase.CheckInit())

end
end
end
Loading
Loading