Skip to content

Commit

Permalink
refactor: turn tunables portion into a Vector{T}
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 30, 2024
1 parent a7f6cd2 commit c848028
Show file tree
Hide file tree
Showing 14 changed files with 261 additions and 184 deletions.
2 changes: 1 addition & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, kwargs...)
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
41 changes: 32 additions & 9 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,29 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
if dvs !== nothing
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
end
end
uind = 1
else
uind = 0
end
array_pars = Dict{Any, AbstractArray{Int}}()
for p in ps
idx = parameter_index(sys, p)
idx isa ParameterIndex || continue
idx.portion isa SciMLStructures.Tunable || continue
idx.idx isa AbstractArray || continue
array_pars[p] = idx.idx
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
Expand All @@ -244,7 +258,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars]
),
expr.body,
false
)
Expand All @@ -256,7 +273,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars]
),
expr.body,
false
)
Expand All @@ -267,7 +287,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+2].name), $v)) for (k, v) in array_pars]
),
expr.body,
false
)
Expand Down
10 changes: 5 additions & 5 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
condit = substitute(condit, cmap)
end
expr = build_function(
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
kwargs...)
if expression == Val{true}
return expr
Expand Down Expand Up @@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = map(sym -> unknownind[sym], update_vars)
elseif isparameter(first(lhss)) && alleq
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
update_inds = map(update_vars) do sym
pind = parameter_index(sys, sym)
discrete_linear_index(ic, pind)
return parameter_index(sys, sym)
end
else
psind = Dict(reverse(en) for en in enumerate(ps))
Expand All @@ -440,7 +439,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
integ = gensym(:MTKIntegrator)
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar),
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand Down
12 changes: 8 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end

function generate_tgrad(
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
simplify = false, kwargs...)
simplify = false, wrap_code = identity, kwargs...)
tgrad = calculate_tgrad(sys, simplify = simplify)
pre = get_preprocess_constants(tgrad)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -97,29 +97,33 @@ function generate_tgrad(
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps)
return build_function(tgrad,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys);
simplify = false, sparse = false, kwargs...)
simplify = false, sparse = false, wrap_code = identity, kwargs...)
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
pre = get_preprocess_constants(jac)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps)
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps)
return build_function(jac,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

Expand Down Expand Up @@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
if implicit_dae
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
else
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts;
if inputs !== nothing
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
_ps = ps
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts;
end
pre = get_postprocess_fbody(sys)

array_wrapper = if param_only
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
else
wrap_array_vars(sys, ts; ps = _ps)
end
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
false))) |> array_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false)
end

function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...)
exprs = [eq.rhs for eq in equations(sys)]
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs)
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
end

function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
Expand Down
66 changes: 42 additions & 24 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
const UnknownIndexMap = Dict{
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
const TunableIndexMap = Dict{BasicSymbolic,
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}

struct IndexCache
unknown_idx::UnknownIndexMap
discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}}
tunable_idx::ParamIndexMap
tunable_idx::TunableIndexMap
constant_idx::ParamIndexMap
dependent_idx::ParamIndexMap
nonnumeric_idx::ParamIndexMap
observed_syms::Set{BasicSymbolic}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_sizes::Vector{BufferTemplate}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
dependent_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
Expand Down Expand Up @@ -75,7 +77,7 @@ function IndexCache(sys::AbstractSystem)
end
end

observed_syms = Set{Union{Symbol, BasicSymbolic}}()
observed_syms = Set{BasicSymbolic}()
for eq in observed(sys)
if symbolic_type(eq.lhs) != NotSymbolic()
sym = eq.lhs
Expand Down Expand Up @@ -236,7 +238,10 @@ function IndexCache(sys::AbstractSystem)
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
insert_by_type!(
if ctype <: Real || ctype <: AbstractArray{<:Real}
if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown()
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown &&
(ctype == Real || ctype <: AbstractFloat ||
ctype <: AbstractArray{Real} ||
ctype <: AbstractArray{<:AbstractFloat})
tunable_buffers
else
constant_buffers
Expand Down Expand Up @@ -292,11 +297,30 @@ function IndexCache(sys::AbstractSystem)
return idxs, buffer_sizes
end

tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)

tunable_idxs = TunableIndexMap()
tunable_buffer_size = 0
for (i, (_, buf)) in enumerate(tunable_buffers)
for (j, p) in enumerate(buf)
idx = if size(p) == ()
tunable_buffer_size + 1
else
reshape(
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
end
tunable_buffer_size += length(p)
tunable_idxs[p] = idx
tunable_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end
end

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs),
observed_syms, independent_variable_symbols(sys)))
Expand All @@ -314,7 +338,7 @@ function IndexCache(sys::AbstractSystem)
nonnumeric_idxs,
observed_syms,
disc_buffer_sizes,
tunable_buffer_sizes,
BufferTemplate(Real, tunable_buffer_size),
const_buffer_sizes,
dependent_buffer_sizes,
nonnumeric_buffer_sizes,
Expand Down Expand Up @@ -410,20 +434,6 @@ function check_index_map(idxmap, sym)
end
end

function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)
for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1)
ind += sum(temp.length for temp in clockbuftemps; init = 0)
end
ind += sum(
temp.length
for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1);
init = 0)
ind += idx.idx[3]
return ind
end

function reorder_parameters(sys::AbstractSystem, ps; kwargs...)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps; kwargs...)
Expand All @@ -436,8 +446,12 @@ end

function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
isempty(ps) && return ()
param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
for temp in ic.tunable_buffer_sizes)
param_buf = if ic.tunable_buffer_size.length == 0
()
else
(BasicSymbolic[unwrap(variable(:DEF))
for _ in 1:(ic.tunable_buffer_size.length)],)
end
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
for temp in Iterators.flatten(ic.discrete_buffer_sizes))
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
Expand All @@ -453,8 +467,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
i, j, k = ic.discrete_idx[p]
disc_buf[(i - 1) * disc_offset + j][k] = p
elseif haskey(ic.tunable_idx, p)
i, j = ic.tunable_idx[p]
param_buf[i][j] = p
i = ic.tunable_idx[p]
if i isa Int
param_buf[1][i] = unwrap(p)
else
param_buf[1][i] = unwrap.(collect(p))
end
elseif haskey(ic.constant_idx, p)
i, j = ic.constant_idx[p]
const_buf[i][j] = p
Expand Down
1 change: 1 addition & 0 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ function generate_rate_function(js::JumpSystem, rate)
p = reorder_parameters(js, full_parameters(js))
rf = build_function(rate, unknowns(js), p...,
get_iv(js),
wrap_code = wrap_array_vars(js, rate; dvs = unknowns(js), ps = parameters(js)),
expression = Val{true})
end

Expand Down
Loading

0 comments on commit c848028

Please sign in to comment.