Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: turn tunables portion into a Vector{T} #2908

Merged
merged 3 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, kwargs...)
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
108 changes: 94 additions & 14 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,70 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys))
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
if dvs !== nothing
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
end
end
uind = 1
else
uind = 0
end
# tunables are scalarized and concatenated, so we need to have assignments
# for the non-scalarized versions
array_tunables = Dict{Any, Tuple{AbstractArray{Int}, Tuple{Vararg{Int}}}}()
# Other parameters may be scalarized arrays but used in the vector form
other_array_parameters = Dict{Any, Any}()

if ps isa Tuple && eltype(ps) <: AbstractArray
ps = Iterators.flatten(ps)
end
for p in ps
p = unwrap(p)
if iscall(p) && operation(p) == getindex
p = arguments(p)[1]
end
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
scal = collect(p)
# all scalarized variables are in `ps`
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
idx isa Int && continue
if idx isa ParameterIndex
if idx.portion != SciMLStructures.Tunable()
continue
end
idxs = vec(idx.idx)
sz = size(idx.idx)
else
# idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
idxs = map(x -> x.idx, idxs)
end
if !all(x -> x isa Int, idxs)
other_array_parameters[p] = scal
continue
end

sz = size(idxs)
if vec(idxs) == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
idxs = idxs[begin]:-1:idxs[end]
end
idxs = vec(idxs)
end
array_tunables[p] = (idxs, sz)
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
Expand All @@ -244,7 +299,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand All @@ -256,7 +317,13 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[1].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 1].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand All @@ -267,7 +334,14 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
expr.args,
[],
Let(
[k :(view($(expr.args[2].name), $v)) for (k, v) in array_vars],
vcat(
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(reshape(view($(expr.args[uind + 2].name), $idxs), $sz))
for (k, (idxs, sz)) in array_tunables],
[k Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
),
expr.body,
false
)
Expand Down Expand Up @@ -455,15 +529,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
return unwrap(sym) in 1:length(parameter_symbols(sys))
end
return any(isequal(sym), parameter_symbols(sys)) ||
hasname(sym) && is_parameter(sys, getname(sym))
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return is_parameter(ic, sym)
end

named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
named_parameters = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
return any(isequal(sym), named_parameters) ||
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
count(isequal(sym),
Expand Down Expand Up @@ -499,7 +576,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
return sym
end
idx = findfirst(isequal(sym), parameter_symbols(sys))
if idx === nothing && hasname(sym)
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
idx = parameter_index(sys, getname(sym))
end
return idx
Expand All @@ -515,13 +592,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
return idx
end
end
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
pnames = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
idx = findfirst(isequal(sym), pnames)
if idx !== nothing
return idx
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
end
return nothing
end
Expand Down
11 changes: 6 additions & 5 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
condit = substitute(condit, cmap)
end
expr = build_function(
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
kwargs...)
if expression == Val{true}
return expr
Expand Down Expand Up @@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = map(sym -> unknownind[sym], update_vars)
elseif isparameter(first(lhss)) && alleq
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
update_inds = map(update_vars) do sym
pind = parameter_index(sys, sym)
discrete_linear_index(ic, pind)
return parameter_index(sys, sym)
end
else
psind = Dict(reverse(en) for en in enumerate(ps))
Expand All @@ -428,6 +427,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
update_inds = outputidxs
end

_ps = ps
ps = reorder_parameters(sys, ps)
if checkvars
u = map(x -> time_varying_as_func(value(x), sys), dvs)
Expand All @@ -440,7 +440,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
integ = gensym(:MTKIntegrator)
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar),
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps = _ps),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand Down
12 changes: 8 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end

function generate_tgrad(
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
simplify = false, kwargs...)
simplify = false, wrap_code = identity, kwargs...)
tgrad = calculate_tgrad(sys, simplify = simplify)
pre = get_preprocess_constants(tgrad)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -97,29 +97,33 @@ function generate_tgrad(
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps)
return build_function(tgrad,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys);
simplify = false, sparse = false, kwargs...)
simplify = false, sparse = false, wrap_code = identity, kwargs...)
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
pre = get_preprocess_constants(jac)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps)
else
(ps,)
end
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps)
return build_function(jac,
dvs,
p...,
get_iv(sys);
postprocess_fbody = pre,
wrap_code,
kwargs...)
end

Expand Down Expand Up @@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
if implicit_dae
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
else
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs),
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps),
kwargs...)
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts;
if inputs !== nothing
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
_ps = ps
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
Expand All @@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts;
end
pre = get_postprocess_fbody(sys)

array_wrapper = if param_only
wrap_array_vars(sys, ts; ps = _ps, dvs = nothing)
else
wrap_array_vars(sys, ts; ps = _ps)
end
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
false))) |> array_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
iip_fn = build_function(ts,
args...;
postprocess_fbody = pre,
wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs),
wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs),
expression = Val{true})[2]
if !expression
iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false)
end

function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...)
exprs = [eq.rhs for eq in equations(sys)]
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs)
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
end

function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
Expand Down
Loading
Loading