Skip to content

Commit

Permalink
tell SDEProblem that the system contains scalar noise
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Jul 21, 2024
1 parent 8a713fe commit 1c40b9e
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 21 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"

[extensions]
MTKBifurcationKitExt = "BifurcationKit"
MTKDeepDiffsExt = "DeepDiffs"
MTKDiffEqNoiseProcess = "DiffEqNoiseProcess"

[compat]
AbstractTrees = "0.3, 0.4"
Expand Down
8 changes: 8 additions & 0 deletions ext/MTKDiffEqNoiseProcess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module MTKDiffEqNoiseProcess

using ModelingToolkit: ModelingToolkit
using DiffEqNoiseProcess: WienerProcess

ModelingToolkit.scalar_noise() = WienerProcess(0.0, 0.0, 0.0)

end
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,8 @@ for prop in [:eqs
:solved_unknowns
:split_idxs
:parent
:index_cache]
:index_cache
:is_scalar_noise]
fname_get = Symbol(:get_, prop)
fname_has = Symbol(:has_, prop)
@eval begin
Expand Down
48 changes: 39 additions & 9 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem
The hierarchical parent system before simplification.
"""
parent::Any

"""
Signal for whether the noise equations should be treated as a scalar process. This should only
be `true` when `noiseeqs isa Vector`.
"""
is_scalar_noise::Bool

function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
tgrad,
jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
complete = false, index_cache = nothing, parent = nothing;
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise=false;
checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_independent_variables([iv])
Expand All @@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem
throw(ArgumentError("Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size(neqs,1)) != length(deqs) = $(length(deqs))"))
end
check_equations(equations(cevents), iv)
if is_scalar_noise && neqs isa AbstractMatrix
throw(ArgumentError("Noise equations ill-formed. Recieved a matrix of noise equations of size $(size(neqs)), but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations."))

Check warning on line 155 in src/systems/diffeqs/sdesystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Recieved" should be "Received".
end
end
if checks == true || (checks & CheckUnits) > 0
u = __get_unit_type(dvs, ps, iv)
Expand All @@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac,
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise)
end
end

Expand All @@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
discrete_events = nothing,
parameter_dependencies = nothing,
metadata = nothing,
gui_metadata = nothing)
gui_metadata = nothing,
complete = false,
index_cache = nothing,
parent = nothing,
is_scalar_noise=false)
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
iv′ = value(iv)
Expand Down Expand Up @@ -208,9 +220,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
complete, index_cache, parent, is_scalar_noise; checks = checks)
end

function SDESystem(sys::ODESystem, neqs; kwargs...)
Expand All @@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
isequal(nameof(sys1), nameof(sys2)) &&
isequal(get_eqs(sys1), get_eqs(sys2)) &&
isequal(get_noiseeqs(sys1), get_noiseeqs(sys2)) &&
isequal(get_is_scalar_noise(sys1), get_is_scalar_noise(sys2)) &&
_eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) &&
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
Expand Down Expand Up @@ -601,6 +615,9 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
SDEFunctionExpr{true}(sys, args...; kwargs...)
end


function scalar_noise end # defined in ../ext/MTKDiffEqNoiseProcess.jl

function DiffEqBase.SDEProblem{iip, specialize}(
sys::SDESystem, u0map = [], tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
Expand All @@ -616,16 +633,24 @@ function DiffEqBase.SDEProblem{iip, specialize}(
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

noiseeqs = get_noiseeqs(sys)
is_scalar_noise = get_is_scalar_noise(sys)
if noiseeqs isa AbstractVector
noise_rate_prototype = nothing
if is_scalar_noise
noise = scalar_noise()
else
noise = nothing
end
elseif sparsenoise
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
noise = nothing
else
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
noise = nothing
end

SDEProblem{iip}(f, u0, tspan, p; callback = cbs,
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
noise_rate_prototype = noise_rate_prototype, kwargs...)
end

Expand Down Expand Up @@ -693,8 +718,12 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

noiseeqs = get_noiseeqs(sys)
is_scalar_noise = get_is_scalar_noise(sys)
if noiseeqs isa AbstractVector
noise_rate_prototype = nothing
if is_scalar_noise
noise = scalar_noise()
end
elseif sparsenoise
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
Expand All @@ -708,7 +737,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
tspan = $tspan
p = $p
noise_rate_prototype = $noise_rate_prototype
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
noise = $noise
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise,
$(kwargs...))
end
!linenumbers ? Base.remove_linenums!(ex) : ex
Expand Down
17 changes: 8 additions & 9 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,20 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
@views copyto!(sorted_g_rows[i, :], g[g_row, :])
end
# Fix for https://github.com/SciML/ModelingToolkit.jl/issues/2490
noise_eqs = if isdiag(sorted_g_rows)
if isdiag(sorted_g_rows)
# If the noise matrix is diagonal, then we just give solver just takes a vector column of equations
# and it interprets that as diagonal noise.
diag(sorted_g_rows)
noise_eqs = diag(sorted_g_rows)
is_scalar_noise = false
elseif sorted_g_rows isa AbstractMatrix && size(sorted_g_rows, 2) == 1
##-------------------------------------------------------------------------------
## TODO: re-enable this code once we add a way to signal that the noise is scalar
# sorted_g_rows[:, 1]
##-------------------------------------------------------------------------------
sorted_g_rows
noise_eqs = sorted_g_rows[:, 1]
is_scalar_noise = true
else
sorted_g_rows
noise_eqs = sorted_g_rows
is_scalar_noise = false
end
return SDESystem(full_equations(ode_sys), noise_eqs,
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
name = nameof(ode_sys))
name = nameof(ode_sys), is_scalar_noise)
end
end
4 changes: 2 additions & 2 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ let
]
prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
# TODO: re-enable this when we support scalar noise
@test_broken solve(prob, SOSRI()).retcode == ReturnCode.Success
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
end

let # test to make sure that scalar noise always recieve the same kicks

Check warning on line 687 in test/sdesystem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"recieve" should be "receive".
Expand All @@ -692,7 +692,7 @@ let # test to make sure that scalar noise always recieve the same kicks

@mtkbuild de = System(eqs, t)
prob = SDEProblem(de, [x => 0, y => 0], (0.0, 10.0), [])
sol = solve(prob, ImplicitEM())
sol = solve(prob, SOSRI())
@test sol[end][1] == sol[end][2]
end

Expand Down

0 comments on commit 1c40b9e

Please sign in to comment.