Skip to content

Commit

Permalink
feat: better inbounds handling and propagation for ODEProblem and o…
Browse files Browse the repository at this point in the history
…bserved functions
  • Loading branch information
AayushSabharwal committed Dec 18, 2024
1 parent 31f7a54 commit 31edc8f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
18 changes: 13 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_inbounds(isscalar)
function wrapper(expr)
Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end))
end
return isscalar ? wrapper : (wrapper, wrapper)
end

function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
end
Expand Down Expand Up @@ -785,7 +792,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +815,7 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1678,12 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1694,7 +1702,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
9 changes: 8 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(jac,
Expand Down Expand Up @@ -208,6 +211,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)

if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end

if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
Expand Down Expand Up @@ -439,7 +446,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts;
oop_mtkp_wrapper = mtkparams_wrapper
end

if !checkbounds
inbounds_wrapper = wrap_inbounds(false)
else
inbounds_wrapper = (identity, identity)
end

# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
return_value = if isscalar
Expand All @@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts;
oop_fn = Func(args, [],
pre(Let(obsexprs,
return_value,
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
Expand Down

0 comments on commit 31edc8f

Please sign in to comment.