Skip to content

Commit

Permalink
feat: better inbounds handling and propagtion in generated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 18, 2024
1 parent 31f7a54 commit 5c8926e
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 17 deletions.
21 changes: 16 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
if wrap_code === nothing
wrap_code = isscalar ? identity : (identity, identity)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
if postprocess_fbody === nothing
postprocess_fbody = pre
Expand Down Expand Up @@ -226,6 +229,13 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_inbounds(isscalar)
function wrapper(expr)
Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end))
end
return isscalar ? wrapper : (wrapper, wrapper)
end

function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
end
Expand Down Expand Up @@ -785,7 +795,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +818,7 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1681,12 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1694,7 +1705,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
8 changes: 7 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,14 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
cmap = map(x -> x => getdefault(x), cs)
condit = substitute(condit, cmap)
end
if !get(kwargs, :checkbounds, false)
inbounds_wrapper = wrap_inbounds(!(condit isa AbstractArray))
else
inbounds_wrapper = condit isa AbstractArray ? (identity, identity) : identity
end
expr = build_function(
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘
wrap_code = condition_header(sys) .∘ inbounds_wrapper .∘
wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
kwargs...)
Expand Down Expand Up @@ -671,6 +676,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
t = get_iv(sys)
integ = gensym(:MTKIntegrator)
pre = get_preprocess_constants(rhss)
inbounds_wrapper = get(kwargs, :checkbounds, false) ? (identity, identity) : wrap_inbounds(false)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = callback_save_header(sys, cb) .∘
add_integrator_header(sys, integ, outvar) .∘
Expand Down
14 changes: 12 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ function generate_tgrad(
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps) .∘
wrap_parameter_dependencies(sys, !(tgrad isa AbstractArray))
return build_function(tgrad,
Expand All @@ -137,6 +140,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(jac,
Expand Down Expand Up @@ -208,6 +214,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)

if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end

if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
Expand Down Expand Up @@ -439,7 +449,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down Expand Up @@ -531,7 +541,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts;
oop_mtkp_wrapper = mtkparams_wrapper
end

if !checkbounds
inbounds_wrapper = wrap_inbounds(false)
else
inbounds_wrapper = (identity, identity)
end

# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
return_value = if isscalar
Expand All @@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts;
oop_fn = Func(args, [],
pre(Let(obsexprs,
return_value,
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

SDEFunction{iip, specialize}(f, g,
sys = sys,
Expand Down
5 changes: 4 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ end
function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
exprs = [eq.rhs for eq in equations(sys)]
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
wrap_parameter_dependencies(sys, false)
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
Expand Down Expand Up @@ -327,7 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = ObservedFunctionCache(sys)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
4 changes: 2 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down Expand Up @@ -504,7 +504,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
Expand Down
16 changes: 14 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ function generate_jacobian(
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(
Expand All @@ -251,6 +254,9 @@ function generate_hessian(
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
pre = get_preprocess_constants(hess)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...)
Expand All @@ -266,6 +272,9 @@ function generate_function(
dvs′ = only(dvs)
end
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘
wrap_parameter_dependencies(sys, scalar)
p = reorder_parameters(sys, value.(ps))
Expand Down Expand Up @@ -342,7 +351,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

if length(dvs) == length(equations(sys))
resid_prototype = nothing
Expand Down Expand Up @@ -383,7 +392,7 @@ function SciMLBase.IntervalNonlinearFunction(
f(u, p) = f_oop(u, p)
f(u, p::MTKParameters) = f_oop(u, p...)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys)
end
Expand Down Expand Up @@ -579,6 +588,9 @@ function SCCNonlinearFunction{iip}(
cmap, cs = get_cmap(sys)
cmap_assignments = [eq.lhs eq.rhs for eq in cmap]
rhss = [eq.rhs - eq.lhs for eq in _eqs]
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_assignments(false, cmap_assignments) .∘
(wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘
wrap_parameter_dependencies(sys, false) .∘
Expand Down
11 changes: 10 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys),
grad = calculate_gradient(sys)
pre = get_preprocess_constants(grad)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, !(grad isa AbstractArray))
return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code,
Expand All @@ -219,6 +222,9 @@ function generate_hessian(
end
pre = get_preprocess_constants(hess)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code,
Expand All @@ -235,6 +241,9 @@ function generate_function(sys::OptimizationSystem, vs = unknowns(sys),
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, !(eqs isa AbstractArray))
return build_function(eqs, vs, p...; wrap_code,
Expand Down Expand Up @@ -419,7 +428,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
Expand Down

0 comments on commit 5c8926e

Please sign in to comment.