Skip to content

Commit

Permalink
Merge pull request #2279 from ErikQQY/qqy/refactor_sde
Browse files Browse the repository at this point in the history
Refactor SDEProblem constructor
  • Loading branch information
ChrisRackauckas authored Sep 22, 2023
2 parents 9376247 + e371f0f commit e23075f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ PrecompileTools = "1"
RecursiveArrayTools = "2.3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "1.76.1"
SciMLBase = "2.0.1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspa
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
end

SDEProblem{iip}(f, f.g, u0, tspan, p; callback = cbs,
SDEProblem{iip}(f, u0, tspan, p; callback = cbs,
noise_rate_prototype = noise_rate_prototype, kwargs...)
end

Expand Down Expand Up @@ -648,7 +648,7 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
tspan = $tspan
p = $p
noise_rate_prototype = $noise_rate_prototype
SDEProblem(f, f.g, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
$(kwargs...))
end
!linenumbers ? striplines(ex) : ex
Expand Down
6 changes: 3 additions & 3 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ f = eval(generate_diffusion_function(de)[1])
@test f(ones(3), rand(3), nothing) == 0.1ones(3)

f = SDEFunction(de)
prob = SDEProblem(SDEFunction(de), f.g, [1.0, 0.0, 0.0], (0.0, 100.0), (10.0, 26.0, 2.33))
prob = SDEProblem(SDEFunction(de), [1.0, 0.0, 0.0], (0.0, 100.0), (10.0, 26.0, 2.33))
sol = solve(prob, SRIW1(), seed = 1)

probexpr = SDEProblem(SDEFunction(de), f.g, [1.0, 0.0, 0.0], (0.0, 100.0),
probexpr = SDEProblem(SDEFunction(de), [1.0, 0.0, 0.0], (0.0, 100.0),
(10.0, 26.0, 2.33))
solexpr = solve(eval(probexpr), SRIW1(), seed = 1)

Expand All @@ -55,7 +55,7 @@ f(du, [1, 2, 3.0], [0.1, 0.2, 0.3], nothing)
0.2 0.3 0.01*3]

f = SDEFunction(de)
prob = SDEProblem(SDEFunction(de), f.g, [1.0, 0.0, 0.0], (0.0, 100.0), (10.0, 26.0, 2.33),
prob = SDEProblem(SDEFunction(de), [1.0, 0.0, 0.0], (0.0, 100.0), (10.0, 26.0, 2.33),
noise_rate_prototype = zeros(3, 3))
sol = solve(prob, EM(), dt = 0.001)

Expand Down

0 comments on commit e23075f

Please sign in to comment.