Skip to content

Commit

Permalink
Merge pull request #2931 from isaacsas/auto_alg_jsys_support
Browse files Browse the repository at this point in the history
update JumpSystem for auto-alg support
  • Loading branch information
ChrisRackauckas authored Aug 7, 2024
2 parents adb6a84 + 4ef25cb commit 3f7ad46
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
InteractiveUtils = "1"
JuliaFormatter = "1.0.47"
JumpProcesses = "9.1"
JumpProcesses = "9.13.1"
LabelledArrays = "1.3"
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
Libdl = "1"
Expand Down
6 changes: 4 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ jprob = JumpProblem(complete(js), dprob, Direct())
sol = solve(jprob, SSAStepper())
```
"""
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing,
function JumpProcesses.JumpProblem(js::JumpSystem, prob,
aggregator = JumpProcesses.NullAggregator(); callback = nothing,
eval_expression = false, eval_module = @__MODULE__, kwargs...)
if !iscomplete(js)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `JumpProblem`")
Expand All @@ -448,7 +449,8 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback =
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)

if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) ||
(aggregator isa JumpProcesses.NullAggregator)
jdeps = asgraph(js)
vdeps = variable_dependencies(js)
vtoj = jdeps.badjlist
Expand Down
29 changes: 26 additions & 3 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,21 @@ parammap = [β => 0.1 / 1000, γ => 0.01]
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
Nsims = 30000
function getmean(jprob, Nsims)
function getmean(jprob, Nsims; use_stepper = true)
m = 0.0
for i in 1:Nsims
sol = solve(jprob, SSAStepper())
sol = use_stepper ? solve(jprob, SSAStepper()) : solve(jprob)
m += sol[end, end]
end
m / Nsims
end
m = getmean(jprob, Nsims)

# test auto-alg selection works
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)
mb = getmean(jprobb, Nsims; use_stepper = false)
@test abs(m - mb) / m < 0.01

@variables S2(t)
obs = [S2 ~ 2 * S]
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
Expand All @@ -89,7 +94,6 @@ sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
@test all(2 .* sol[S] .== sol[S2])

# test save_positions is working

jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
sol = solve(jprob, SSAStepper(), saveat = 1.0)
@test all((sol.t) .== collect(0.0:tspan[2]))
Expand Down Expand Up @@ -270,3 +274,22 @@ affect = [X ~ X - 1]

j1 = ConstantRateJump(k, [X ~ X - 1])
@test_nowarn @mtkbuild js1 = JumpSystem([j1], t, [X], [k])

# test correct autosolver is selected, which implies appropriate dep graphs are available
let
@parameters k
@variables X(t)
rate = k
affect = [X ~ X - 1]
j1 = ConstantRateJump(k, [X ~ X - 1])

Nv = [1, JumpProcesses.USE_DIRECT_THRESHOLD + 1, JumpProcesses.USE_RSSA_THRESHOLD + 1]
algtypes = [Direct, RSSA, RSSACR]
for (N, algtype) in zip(Nv, algtypes)
@named jsys = JumpSystem([deepcopy(j1) for _ in 1:N], t, [X], [k])
jsys = complete(jsys)
dprob = DiscreteProblem(jsys, [X => 10], (0.0, 10.0), [k => 1])
jprob = JumpProblem(jsys, dprob)
@test jprob.aggregator isa algtype
end
end

0 comments on commit 3f7ad46

Please sign in to comment.