From b5dc2ea54b95364c13e17a991033f4a188ce69e8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 22 Jul 2024 14:20:36 +0530 Subject: [PATCH] fix: create specialized `isdiag` for symbolics in noise matrix --- src/systems/diffeqs/sdesystem.jl | 9 ++++++++- test/sdesystem.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 5efdc2cb33..489a2e75db 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -244,13 +244,20 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem) all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) end +function __num_isdiag(mat) + for i in axes(mat, 1), j in axes(mat, 2) + i == j || isequal(mat[i, j], 0) || return false + end + return true +end + function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys), ps = full_parameters(sys); isdde = false, kwargs...) eqs = get_noiseeqs(sys) if isdde eqs = delay_to_function(sys, eqs) end - if eqs isa AbstractMatrix && isdiag(eqs) + if eqs isa AbstractMatrix && __num_isdiag(eqs) eqs = diag(eqs) end u = map(x -> time_varying_as_func(value(x), sys), dvs) diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 55481b5514..d71848bf4b 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -722,3 +722,29 @@ let # test that diagonal noise is correctly handled # SOSRI only works for diagonal and scalar noise @test solve(prob, SOSRI()).retcode == ReturnCode.Success end + +@testset "Non-diagonal noise check" begin + @parameters σ ρ β + @variables x(t) y(t) z(t) + @brownian a b c + eqs = [D(x) ~ σ * (y - x) + 0.1a * x + 0.1b * y, + D(y) ~ x * (ρ - z) - y + 0.1b * y, + D(z) ~ x * y - β * z + 0.1c * z] + @mtkbuild de = System(eqs, t) + + u0map = [ + x => 1.0, + y => 0.0, + z => 0.0 + ] + + parammap = [ + σ => 10.0, + β => 26.0, + ρ => 2.33 + ] + + prob = SDEProblem(de, u0map, (0.0, 100.0), parammap) + # SOSRI only works for diagonal and scalar noise + @test solve(prob, ImplicitEM()).retcode == ReturnCode.Success +end