From 5c8926e935dae0a022a7d5f7384a81d2b93a5d81 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 18 Dec 2024 12:12:59 +0530 Subject: [PATCH] feat: better inbounds handling and propagtion in generated functions --- src/systems/abstractsystem.jl | 21 ++++++++++++++----- src/systems/callbacks.jl | 8 ++++++- src/systems/diffeqs/abstractodesystem.jl | 14 +++++++++++-- src/systems/diffeqs/odesystem.jl | 10 +++++++-- src/systems/diffeqs/sdesystem.jl | 2 +- .../discrete_system/discrete_system.jl | 5 ++++- src/systems/jumps/jumpsystem.jl | 4 ++-- src/systems/nonlinear/nonlinearsystem.jl | 16 ++++++++++++-- .../optimization/optimizationsystem.jl | 11 +++++++++- 9 files changed, 74 insertions(+), 17 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 168260ae69..fcbb304d60 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -172,6 +172,9 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys if wrap_code === nothing wrap_code = isscalar ? identity : (identity, identity) end + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs) if postprocess_fbody === nothing postprocess_fbody = pre @@ -226,6 +229,13 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end +function wrap_inbounds(isscalar) + function wrapper(expr) + Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end)) + end + return isscalar ? wrapper : (wrapper, wrapper) +end + function wrap_parameter_dependencies(sys::AbstractSystem, isscalar) wrap_assignments(isscalar, [eq.lhs ← eq.rhs for eq in parameter_dependencies(sys)]) end @@ -785,7 +795,7 @@ end SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true function SymbolicIndexingInterface.observed( - sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__) + sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing if sym isa Symbol _sym = get(ic.symbol_to_variable, sym, nothing) @@ -808,7 +818,7 @@ function SymbolicIndexingInterface.observed( end end end - _fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module) + _fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds) if is_time_dependent(sys) return _fn @@ -1671,11 +1681,12 @@ struct ObservedFunctionCache{S} steady_state::Bool eval_expression::Bool eval_module::Module + checkbounds::Bool end function ObservedFunctionCache( - sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__) - return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module) + sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true) + return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds) end # This is hit because ensemble problems do a deepcopy @@ -1694,7 +1705,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...) obs = get!(ofc.dict, value(obsvar)) do SymbolicIndexingInterface.observed( ofc.sys, obsvar; eval_expression = ofc.eval_expression, - eval_module = ofc.eval_module) + eval_module = ofc.eval_module, checkbounds = ofc.checkbounds) end if ofc.steady_state obs = let fn = obs diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index eaf31a9c5f..d3009089d8 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -583,9 +583,14 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; cmap = map(x -> x => getdefault(x), cs) condit = substitute(condit, cmap) end + if !get(kwargs, :checkbounds, false) + inbounds_wrapper = wrap_inbounds(!(condit isa AbstractArray)) + else + inbounds_wrapper = condit isa AbstractArray ? (identity, identity) : identity + end expr = build_function( condit, u, t, p...; expression = Val{true}, - wrap_code = condition_header(sys) .∘ + wrap_code = condition_header(sys) .∘ inbounds_wrapper .∘ wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘ wrap_parameter_dependencies(sys, !(condit isa AbstractArray)), kwargs...) @@ -671,6 +676,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no t = get_iv(sys) integ = gensym(:MTKIntegrator) pre = get_preprocess_constants(rhss) + inbounds_wrapper = get(kwargs, :checkbounds, false) ? (identity, identity) : wrap_inbounds(false) rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true}, wrap_code = callback_save_header(sys, cb) .∘ add_integrator_header(sys, integ, outvar) .∘ diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index f4e29346ff..4f55d5beb1 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -116,6 +116,9 @@ function generate_tgrad( else (ps,) end + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps) .∘ wrap_parameter_dependencies(sys, !(tgrad isa AbstractArray)) return build_function(tgrad, @@ -137,6 +140,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), else (ps,) end + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘ wrap_parameter_dependencies(sys, false) return build_function(jac, @@ -208,6 +214,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps)) t = get_iv(sys) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end + if isdde build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs..., wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘ @@ -439,7 +449,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, ArrayInterface.restructure(u0 .* u0', M) end - observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) @@ -531,7 +541,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) _jac = nothing end - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) jac_prototype = if sparse uElType = u0 === nothing ? Float64 : eltype(u0) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 2b0bd8c8d7..63bbd43a68 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts; oop_mtkp_wrapper = mtkparams_wrapper end + if !checkbounds + inbounds_wrapper = wrap_inbounds(false) + else + inbounds_wrapper = (identity, identity) + end + # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` return_value = if isscalar @@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts; oop_fn = Func(args, [], pre(Let(obsexprs, return_value, - false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr + false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module) if !isscalar iip_fn = build_function(ts, args...; postprocess_fbody = pre, - wrap_code = mtkparams_wrapper .∘ array_wrapper .∘ + wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘ wrap_assignments(isscalar, obsexprs), expression = Val{true})[2] if !expression diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index ac47f4c45c..88d98dcb07 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -589,7 +589,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( M = calculate_massmatrix(sys) _M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M) - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) SDEFunction{iip, specialize}(f, g, sys = sys, diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 3e220998cb..8e48fc9369 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -234,6 +234,9 @@ end function generate_function( sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...) exprs = [eq.rhs for eq in equations(sys)] + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘ wrap_parameter_dependencies(sys, false) generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...) @@ -327,7 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}( f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t)) end - observedfun = ObservedFunctionCache(sys) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) DiscreteFunction{iip, specialize}(f; sys = sys, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index e5e17fb5f9..02784718a9 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -406,7 +406,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false) f = DiffEqBase.DISCRETE_INPLACE_DEFAULT - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun) DiscreteProblem(df, u0, tspan, p; kwargs...) @@ -504,7 +504,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false) f = (du, u, p, t) -> (du .= 0; nothing) - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) df = ODEFunction(f; sys, observed = observedfun) return ODEProblem(df, u0, tspan, p; kwargs...) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 3cb68853aa..5e42e35e23 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -227,6 +227,9 @@ function generate_jacobian( jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify) pre, sol_states = get_substitutions_and_solved_unknowns(sys) p = reorder_parameters(sys, ps) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) .∘ wrap_parameter_dependencies(sys, false) return build_function( @@ -251,6 +254,9 @@ function generate_hessian( hess = calculate_hessian(sys, sparse = sparse, simplify = simplify) pre = get_preprocess_constants(hess) p = reorder_parameters(sys, ps) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘ wrap_parameter_dependencies(sys, false) return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...) @@ -266,6 +272,9 @@ function generate_function( dvs′ = only(dvs) end pre, sol_states = get_substitutions_and_solved_unknowns(sys) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘ wrap_parameter_dependencies(sys, scalar) p = reorder_parameters(sys, value.(ps)) @@ -342,7 +351,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s _jac = nothing end - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) if length(dvs) == length(equations(sys)) resid_prototype = nothing @@ -383,7 +392,7 @@ function SciMLBase.IntervalNonlinearFunction( f(u, p) = f_oop(u, p) f(u, p::MTKParameters) = f_oop(u, p...) - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false)) IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys) end @@ -579,6 +588,9 @@ function SCCNonlinearFunction{iip}( cmap, cs = get_cmap(sys) cmap_assignments = [eq.lhs ← eq.rhs for eq in cmap] rhss = [eq.rhs - eq.lhs for eq in _eqs] + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_assignments(false, cmap_assignments) .∘ (wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘ wrap_parameter_dependencies(sys, false) .∘ diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 43e9294dd3..96ca5ee255 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -199,6 +199,9 @@ function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys), grad = calculate_gradient(sys) pre = get_preprocess_constants(grad) p = reorder_parameters(sys, ps) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) .∘ wrap_parameter_dependencies(sys, !(grad isa AbstractArray)) return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code, @@ -219,6 +222,9 @@ function generate_hessian( end pre = get_preprocess_constants(hess) p = reorder_parameters(sys, ps) + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘ wrap_parameter_dependencies(sys, false) return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, @@ -235,6 +241,9 @@ function generate_function(sys::OptimizationSystem, vs = unknowns(sys), else (ps,) end + if !get(kwargs, :checkbounds, false) + wrap_code = wrap_code .∘ wrap_inbounds(false) + end wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) .∘ wrap_parameter_dependencies(sys, !(eqs isa AbstractArray)) return build_function(eqs, vs, p...; wrap_code, @@ -419,7 +428,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, hess_prototype = nothing end - observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) + observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds) if length(cstr) > 0 @named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)