From 31edc8fdfa74cab76d4955874d13f7293c4e163b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 18 Dec 2024 12:12:59 +0530 Subject: [PATCH] feat: better inbounds handling and propagation for `ODEProblem` and observed functions --- src/systems/abstractsystem.jl | 18 +++++++++++++----- src/systems/diffeqs/abstractodesystem.jl | 9 ++++++++- src/systems/diffeqs/odesystem.jl | 10 ++++++++-- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 168260ae69..c6aec762b6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -226,6 +226,13 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end +function wrap_inbounds(isscalar) + function wrapper(expr) + Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end)) + end + return isscalar ? wrapper : (wrapper, wrapper) +end + function wrap_parameter_dependencies(sys::AbstractSystem, isscalar) wrap_assignments(isscalar, [eq.lhs ← eq.rhs for eq in parameter_dependencies(sys)]) end @@ -785,7 +792,7 @@ end SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true function SymbolicIndexingInterface.observed( - sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__) + sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing if sym isa Symbol _sym = get(ic.symbol_to_variable, sym, nothing) @@ -808,7 +815,7 @@ function SymbolicIndexingInterface.observed( end end end - _fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module) + _fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds) if is_time_dependent(sys) return _fn @@ -1671,11 +1678,12 @@ struct ObservedFunctionCache{S} steady_state::Bool eval_expression::Bool eval_module::Module + checkbounds::Bool end function ObservedFunctionCache( - sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__) - return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module) + sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true) + return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds) end # This is hit because ensemble problems do a deepcopy @@ -1694,7 +1702,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...) obs = get!(ofc.dict, value(obsvar)) do SymbolicIndexingInterface.observed( ofc.sys, obsvar; eval_expression = ofc.eval_expression, - eval_module = ofc.eval_module) + eval_module = ofc.eval_module, checkbounds = ofc.checkbounds) end if ofc.steady_state obs = let fn = obs diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index f4e29346ff..b683ff5cfd 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -137,6 +137,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), else (ps,) end + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘ wrap_parameter_dependencies(sys, false) return build_function(jac, @@ -208,6 +211,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps)) t = get_iv(sys) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end + if isdde build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs..., wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘ @@ -439,7 +446,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, ArrayInterface.restructure(u0 .* u0', M) end - observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 2b0bd8c8d7..63bbd43a68 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts; oop_mtkp_wrapper = mtkparams_wrapper end + if !checkbounds + inbounds_wrapper = wrap_inbounds(false) + else + inbounds_wrapper = (identity, identity) + end + # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` return_value = if isscalar @@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts; oop_fn = Func(args, [], pre(Let(obsexprs, return_value, - false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr + false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module) if !isscalar iip_fn = build_function(ts, args...; postprocess_fbody = pre, - wrap_code = mtkparams_wrapper .∘ array_wrapper .∘ + wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘ wrap_assignments(isscalar, obsexprs), expression = Val{true})[2] if !expression