Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better inbounds handling and propagation for generated functions #3277

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
if wrap_code === nothing
wrap_code = isscalar ? identity : (identity, identity)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
if postprocess_fbody === nothing
postprocess_fbody = pre
Expand Down Expand Up @@ -226,6 +229,13 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_inbounds(isscalar)
function wrapper(expr)
Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end))
end
return isscalar ? wrapper : (wrapper, wrapper)
end

function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
end
Expand Down Expand Up @@ -785,7 +795,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +818,7 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1681,12 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1694,7 +1705,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
8 changes: 7 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,14 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
cmap = map(x -> x => getdefault(x), cs)
condit = substitute(condit, cmap)
end
if !get(kwargs, :checkbounds, false)
inbounds_wrapper = wrap_inbounds(!(condit isa AbstractArray))
else
inbounds_wrapper = condit isa AbstractArray ? (identity, identity) : identity
end
expr = build_function(
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘
wrap_code = condition_header(sys) .∘ inbounds_wrapper .∘
wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
kwargs...)
Expand Down Expand Up @@ -671,6 +676,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
t = get_iv(sys)
integ = gensym(:MTKIntegrator)
pre = get_preprocess_constants(rhss)
inbounds_wrapper = get(kwargs, :checkbounds, false) ? (identity, identity) : wrap_inbounds(false)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = callback_save_header(sys, cb) .∘
add_integrator_header(sys, integ, outvar) .∘
Expand Down
14 changes: 12 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ function generate_tgrad(
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps) .∘
wrap_parameter_dependencies(sys, !(tgrad isa AbstractArray))
return build_function(tgrad,
Expand All @@ -137,6 +140,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(jac,
Expand Down Expand Up @@ -208,6 +214,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)

if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end

if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
Expand Down Expand Up @@ -439,7 +449,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down Expand Up @@ -531,7 +541,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts;
oop_mtkp_wrapper = mtkparams_wrapper
end

if !checkbounds
inbounds_wrapper = wrap_inbounds(false)
else
inbounds_wrapper = (identity, identity)
end

# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
return_value = if isscalar
Expand All @@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts;
oop_fn = Func(args, [],
pre(Let(obsexprs,
return_value,
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

SDEFunction{iip, specialize}(f, g,
sys = sys,
Expand Down
5 changes: 4 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ end
function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
exprs = [eq.rhs for eq in equations(sys)]
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
wrap_parameter_dependencies(sys, false)
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
Expand Down Expand Up @@ -327,7 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = ObservedFunctionCache(sys)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
4 changes: 2 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down Expand Up @@ -504,7 +504,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
Expand Down
16 changes: 14 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ function generate_jacobian(
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(
Expand All @@ -251,6 +254,9 @@ function generate_hessian(
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
pre = get_preprocess_constants(hess)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...)
Expand All @@ -266,6 +272,9 @@ function generate_function(
dvs′ = only(dvs)
end
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘
wrap_parameter_dependencies(sys, scalar)
p = reorder_parameters(sys, value.(ps))
Expand Down Expand Up @@ -342,7 +351,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

if length(dvs) == length(equations(sys))
resid_prototype = nothing
Expand Down Expand Up @@ -383,7 +392,7 @@ function SciMLBase.IntervalNonlinearFunction(
f(u, p) = f_oop(u, p)
f(u, p::MTKParameters) = f_oop(u, p...)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys)
end
Expand Down Expand Up @@ -579,6 +588,9 @@ function SCCNonlinearFunction{iip}(
cmap, cs = get_cmap(sys)
cmap_assignments = [eq.lhs eq.rhs for eq in cmap]
rhss = [eq.rhs - eq.lhs for eq in _eqs]
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_assignments(false, cmap_assignments) .∘
(wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘
wrap_parameter_dependencies(sys, false) .∘
Expand Down
11 changes: 10 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys),
grad = calculate_gradient(sys)
pre = get_preprocess_constants(grad)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, !(grad isa AbstractArray))
return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code,
Expand All @@ -219,6 +222,9 @@ function generate_hessian(
end
pre = get_preprocess_constants(hess)
p = reorder_parameters(sys, ps)
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, false)
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code,
Expand All @@ -235,6 +241,9 @@ function generate_function(sys::OptimizationSystem, vs = unknowns(sys),
else
(ps,)
end
if !get(kwargs, :checkbounds, false)
wrap_code = wrap_code .∘ wrap_inbounds(false)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) .∘
wrap_parameter_dependencies(sys, !(eqs isa AbstractArray))
return build_function(eqs, vs, p...; wrap_code,
Expand Down Expand Up @@ -419,7 +428,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
Expand Down
Loading