Skip to content

Commit

Permalink
Merge pull request #2908 from AayushSabharwal/as/better-tunables
Browse files Browse the repository at this point in the history
refactor: turn tunables portion into a Vector{T}
  • Loading branch information
ChrisRackauckas authored Aug 2, 2024
2 parents e64c479 + 7291dc8 commit fbc2f40
Show file tree
Hide file tree
Showing 18 changed files with 359 additions and 205 deletions.
2 changes: 1 addition & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, kwargs...)
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
108 changes: 94 additions & 14 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,70 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
if dvs !== nothing
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
end
end
uind = 1
else
uind = 0
end
# tunables are scalarized and concatenated, so we need to have assignments
# for the non-scalarized versions
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
# Other parameters may be scalarized arrays but used in the vector form
other_array_parameters = Dict{Any, Any}()

if ps isa Tuple && eltype(ps) <: AbstractArray
ps = Iterators.flatten(ps)
end
for p in ps
p = unwrap(p)
if iscall(p) && operation(p) == getindex
p = arguments(p)[1]
end
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
scal = collect(p)
# all scalarized variables are in `ps`
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
idx isa Int && continue
if idx isa ParameterIndex
if idx.portion != SciMLStructures.Tunable()
continue
end
idxs = vec(idx.idx)
sz = size(idx.idx)
else
# idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
idxs = map(x -> x.idx, idxs)
end
if !all(x -> x isa Int, idxs)
other_array_parameters[p] = scal
continue
end

sz = size(idxs)
if vec(idxs) == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
idxs = idxs[begin]:-1:idxs[end]
end
idxs = vec(idxs)
end
array_tunables[p] = (idxs, sz)
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
Expand All @@ -244,7 +299,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand All @@ -256,7 +317,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand All @@ -267,7 +334,14 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand Down Expand Up @@ -455,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
return unwrap(sym) in 1:length(parameter_symbols(sys))
end
return any(isequal(sym), parameter_symbols(sys)) ||
hasname(sym) && is_parameter(sys, getname(sym))
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return is_parameter(ic, sym)
end

named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
named_parameters = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
return any(isequal(sym), named_parameters) ||
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
count(isequal(sym),
Expand Down Expand Up @@ -499,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
return sym
end
idx = findfirst(isequal(sym), parameter_symbols(sys))
if idx === nothing && hasname(sym)
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
idx = parameter_index(sys, getname(sym))
end
return idx
Expand All @@ -515,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
return idx
end
end
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
pnames = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
idx = findfirst(isequal(sym), pnames)
if idx !== nothing
return idx
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
end
return nothing
end
Expand Down
11 changes: 6 additions & 5 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
condit = substitute(condit, cmap)
end
expr = build_function(
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
kwargs...)
if expression == Val{true}
return expr
Expand Down Expand Up @@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = map(sym -> unknownind[sym], update_vars)
elseif isparameter(first(lhss)) && alleq
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
update_inds = map(update_vars) do sym
pind = parameter_index(sys, sym)
discrete_linear_index(ic, pind)
return parameter_index(sys, sym)
end
else
psind = Dict(reverse(en) for en in enumerate(ps))
Expand All @@ -428,6 +427,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = outputidxs
end

_ps = ps
ps = reorder_parameters(sys, ps)
if checkvars
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand All @@ -440,7 +440,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
integ = gensym(:MTKIntegrator)
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar),
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps = _ps),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand Down
12 changes: 8 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end

function generate_tgrad(
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
simplify = false, kwargs...)
simplify = false, wrap_code = identity, kwargs...)
tgrad = calculate_tgrad(sys, simplify = simplify)
pre = get_preprocess_constants(tgrad)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -97,29 +97,33 @@ function generate_tgrad(
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps)
return build_function(tgrad,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys);
simplify = false, sparse = false, kwargs...)
simplify = false, sparse = false, wrap_code = identity, kwargs...)
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
pre = get_preprocess_constants(jac)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps)
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps)
return build_function(jac,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

Expand Down Expand Up @@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
if implicit_dae
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
else
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts;
if inputs !== nothing
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
_ps = ps
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts;
end
pre = get_postprocess_fbody(sys)

array_wrapper = if param_only
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
else
wrap_array_vars(sys, ts; ps = _ps)
end
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
false))) |> array_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false)
end

function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...)
exprs = [eq.rhs for eq in equations(sys)]
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs)
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
end

function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
Expand Down
Loading

0 comments on commit fbc2f40

Please sign in to comment.