From 585de359502db53e50609d656e97d3696e2d6300 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 25 Nov 2024 15:05:31 +0530 Subject: [PATCH] fix: fix SDEs with noise dependent on observed variables --- .../symbolics_tearing.jl | 8 ++++++++ src/systems/systems.jl | 2 ++ test/sdesystem.jl | 19 +++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index a854acb9b1..8161a9572b 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -88,6 +88,14 @@ function tearing_sub(expr, dict, s) s ? simplify(expr) : expr end +function tearing_substitute_expr(sys::AbstractSystem, expr; simplify = false) + empty_substitutions(sys) && return expr + substitutions = get_substitutions(sys) + @unpack subs = substitutions + solved = Dict(eq.lhs => eq.rhs for eq in subs) + return tearing_sub(expr, solved, simplify) +end + """ $(TYPEDSIGNATURES) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 848980605f..4205a4d207 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -152,6 +152,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal noise_eqs = sorted_g_rows is_scalar_noise = false end + + noise_eqs = StructuralTransformations.tearing_substitute_expr(ode_sys, noise_eqs) return SDESystem(full_equations(ode_sys), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); name = nameof(ode_sys), is_scalar_noise) diff --git a/test/sdesystem.jl b/test/sdesystem.jl index c258a4142b..78d7a0418b 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -780,3 +780,22 @@ end prob = @test_nowarn SDEProblem(sys, nothing, (0.0, 1.0)) @test_nowarn solve(prob, ImplicitEM()) end + +@testset "Issue#3212: Noise dependent on observed" begin + sts = @variables begin + x(t) = 1.0 + input(t) + [input = true] + end + ps = @parameters a = 2 + @brownian η + + eqs = [D(x) ~ -a * x + (input + 1) * η + input ~ 0.0] + + sys = System(eqs, t, sts, ps; name = :name) + sys = structural_simplify(sys) + @test ModelingToolkit.get_noiseeqs(sys) ≈ [1.0] + prob = SDEProblem(sys, [], (0.0, 1.0), []) + @test_nowarn solve(prob, RKMil()) +end