From c848028e70c6599c9708608278b9f0b8a9dc3b00 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Jul 2024 11:50:07 +0530 Subject: [PATCH] refactor: turn tunables portion into a Vector{T} --- src/inputoutput.jl | 2 +- src/systems/abstractsystem.jl | 41 +++- src/systems/callbacks.jl | 10 +- src/systems/diffeqs/abstractodesystem.jl | 12 +- src/systems/diffeqs/odesystem.jl | 10 +- .../discrete_system/discrete_system.jl | 6 +- src/systems/index_cache.jl | 66 +++-- src/systems/jumps/jumpsystem.jl | 1 + src/systems/nonlinear/nonlinearsystem.jl | 16 +- .../optimization/constraints_system.jl | 14 +- .../optimization/optimizationsystem.jl | 13 +- src/systems/parameter_buffer.jl | 227 +++++++++--------- test/jumpsystem.jl | 4 +- test/split_parameters.jl | 23 +- 14 files changed, 261 insertions(+), 184 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 6ef36cddcc..7a89d69820 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -243,7 +243,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu end process = get_postprocess_fbody(sys) f = build_function(rhss, args...; postprocess_fbody = process, - expression = Val{true}, kwargs...) + expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...) f = eval_or_rgf.(f; eval_expression, eval_module) (; f, dvs, ps, io_sys = sys) end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 11292752cc..f1b1c72dd0 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -223,15 +223,29 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end -function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys)) +function wrap_array_vars( + sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys)) isscalar = !(exprs isa AbstractArray) array_vars = Dict{Any, AbstractArray{Int}}() - for (j, x) in enumerate(dvs) - if iscall(x) && operation(x) == getindex - arg = arguments(x)[1] - inds = get!(() -> Int[], array_vars, arg) - push!(inds, j) + if dvs !== nothing + for (j, x) in enumerate(dvs) + if iscall(x) && operation(x) == getindex + arg = arguments(x)[1] + inds = get!(() -> Int[], array_vars, arg) + push!(inds, j) + end end + uind = 1 + else + uind = 0 + end + array_pars = Dict{Any, AbstractArray{Int}}() + for p in ps + idx = parameter_index(sys, p) + idx isa ParameterIndex || continue + idx.portion isa SciMLStructures.Tunable || continue + idx.idx isa AbstractArray || continue + array_pars[p] = idx.idx end for (k, inds) in array_vars if inds == (inds′ = inds[1]:inds[end]) @@ -244,7 +258,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys)) expr.args, [], Let( - [k ← :(view($(expr.args[1].name), $v)) for (k, v) in array_vars], + vcat( + [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], + [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars] + ), expr.body, false ) @@ -256,7 +273,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys)) expr.args, [], Let( - [k ← :(view($(expr.args[1].name), $v)) for (k, v) in array_vars], + vcat( + [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], + [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars] + ), expr.body, false ) @@ -267,7 +287,10 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys)) expr.args, [], Let( - [k ← :(view($(expr.args[2].name), $v)) for (k, v) in array_vars], + vcat( + [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_vars], + [k ← :(view($(expr.args[uind+2].name), $v)) for (k, v) in array_pars] + ), expr.body, false ) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 3fe1f7f006..91d27dc741 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -354,7 +354,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; condit = substitute(condit, cmap) end expr = build_function( - condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys), + condit, u, t, p...; expression = Val{true}, + wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps), kwargs...) if expression == Val{true} return expr @@ -411,10 +412,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin update_inds = map(sym -> unknownind[sym], update_vars) elseif isparameter(first(lhss)) && alleq if has_index_cache(sys) && get_index_cache(sys) !== nothing - ic = get_index_cache(sys) update_inds = map(update_vars) do sym - pind = parameter_index(sys, sym) - discrete_linear_index(ic, pind) + return parameter_index(sys, sym) end else psind = Dict(reverse(en) for en in enumerate(ps)) @@ -440,7 +439,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin integ = gensym(:MTKIntegrator) pre = get_preprocess_constants(rhss) rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true}, - wrap_code = add_integrator_header(sys, integ, outvar), + wrap_code = add_integrator_header(sys, integ, outvar) .∘ + wrap_array_vars(sys, rhss; dvs, ps), outputidxs = update_inds, postprocess_fbody = pre, kwargs...) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 5f69266f7e..1e3aff7106 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -87,7 +87,7 @@ end function generate_tgrad( sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys); - simplify = false, kwargs...) + simplify = false, wrap_code = identity, kwargs...) tgrad = calculate_tgrad(sys, simplify = simplify) pre = get_preprocess_constants(tgrad) p = if has_index_cache(sys) && get_index_cache(sys) !== nothing @@ -97,17 +97,19 @@ function generate_tgrad( else (ps,) end + wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps) return build_function(tgrad, dvs, p..., get_iv(sys); postprocess_fbody = pre, + wrap_code, kwargs...) end function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys); - simplify = false, sparse = false, kwargs...) + simplify = false, sparse = false, wrap_code = identity, kwargs...) jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse) pre = get_preprocess_constants(jac) p = if has_index_cache(sys) && get_index_cache(sys) !== nothing @@ -115,11 +117,13 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), else (ps,) end + wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) return build_function(jac, dvs, p..., get_iv(sys); postprocess_fbody = pre, + wrap_code, kwargs...) end @@ -188,12 +192,12 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), if implicit_dae build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre, states = sol_states, - wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs), + wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps), kwargs...) else build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states, - wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs), + wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps), kwargs...) end end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e28f1ece3b..264c00590b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -485,6 +485,7 @@ function build_explicit_observed_function(sys, ts; if inputs !== nothing ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list end + _ps = ps if ps isa Tuple ps = DestructuredArgs.(ps, inbounds = !checkbounds) elseif has_index_cache(sys) && get_index_cache(sys) !== nothing @@ -505,19 +506,24 @@ function build_explicit_observed_function(sys, ts; end pre = get_postprocess_fbody(sys) + array_wrapper = if param_only + wrap_array_vars(sys, ts; ps = _ps, dvs = nothing) + else + wrap_array_vars(sys, ts; ps = _ps) + end # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` oop_fn = Func(args, [], pre(Let(obsexprs, isscalar ? ts[1] : MakeArray(ts, output_type), - false))) |> wrap_array_vars(sys, ts)[1] |> toexpr + false))) |> array_wrapper[1] |> toexpr oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module) if !isscalar iip_fn = build_function(ts, args...; postprocess_fbody = pre, - wrap_code = wrap_array_vars(sys, ts) .∘ wrap_assignments(isscalar, obsexprs), + wrap_code = array_wrapper .∘ wrap_assignments(isscalar, obsexprs), expression = Val{true})[2] if !expression iip_fn = eval_or_rgf(iip_fn; eval_expression, eval_module) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 86d23acea0..0245f28421 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -218,8 +218,10 @@ function flatten(sys::DiscreteSystem, noeqs = false) end function generate_function( - sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...) - generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...) + sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); wrap_code = identity, kwargs...) + exprs = [eq.rhs for eq in equations(sys)] + wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) + generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...) end function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap; diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 899bba4aa5..6979cb26d3 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -24,17 +24,19 @@ ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false) const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} const UnknownIndexMap = Dict{ BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} +const TunableIndexMap = Dict{BasicSymbolic, + Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} struct IndexCache unknown_idx::UnknownIndexMap discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}} - tunable_idx::ParamIndexMap + tunable_idx::TunableIndexMap constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap observed_syms::Set{BasicSymbolic} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} - tunable_buffer_sizes::Vector{BufferTemplate} + tunable_buffer_size::BufferTemplate constant_buffer_sizes::Vector{BufferTemplate} dependent_buffer_sizes::Vector{BufferTemplate} nonnumeric_buffer_sizes::Vector{BufferTemplate} @@ -75,7 +77,7 @@ function IndexCache(sys::AbstractSystem) end end - observed_syms = Set{Union{Symbol, BasicSymbolic}}() + observed_syms = Set{BasicSymbolic}() for eq in observed(sys) if symbolic_type(eq.lhs) != NotSymbolic() sym = eq.lhs @@ -236,7 +238,10 @@ function IndexCache(sys::AbstractSystem) haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown() + if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown && + (ctype == Real || ctype <: AbstractFloat || + ctype <: AbstractArray{Real} || + ctype <: AbstractArray{<:AbstractFloat}) tunable_buffers else constant_buffers @@ -292,11 +297,30 @@ function IndexCache(sys::AbstractSystem) return idxs, buffer_sizes end - tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers) const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers) + tunable_idxs = TunableIndexMap() + tunable_buffer_size = 0 + for (i, (_, buf)) in enumerate(tunable_buffers) + for (j, p) in enumerate(buf) + idx = if size(p) == () + tunable_buffer_size + 1 + else + reshape( + (tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p)) + end + tunable_buffer_size += length(p) + tunable_idxs[p] = idx + tunable_idxs[default_toterm(p)] = idx + if hasname(p) && (!iscall(p) || operation(p) !== getindex) + symbol_to_variable[getname(p)] = p + symbol_to_variable[getname(default_toterm(p))] = p + end + end + end + for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs), observed_syms, independent_variable_symbols(sys))) @@ -314,7 +338,7 @@ function IndexCache(sys::AbstractSystem) nonnumeric_idxs, observed_syms, disc_buffer_sizes, - tunable_buffer_sizes, + BufferTemplate(Real, tunable_buffer_size), const_buffer_sizes, dependent_buffer_sizes, nonnumeric_buffer_sizes, @@ -410,20 +434,6 @@ function check_index_map(idxmap, sym) end end -function discrete_linear_index(ic::IndexCache, idx::ParameterIndex) - idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected") - ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0) - for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1) - ind += sum(temp.length for temp in clockbuftemps; init = 0) - end - ind += sum( - temp.length - for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1); - init = 0) - ind += idx.idx[3] - return ind -end - function reorder_parameters(sys::AbstractSystem, ps; kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing reorder_parameters(get_index_cache(sys), ps; kwargs...) @@ -436,8 +446,12 @@ end function reorder_parameters(ic::IndexCache, ps; drop_missing = false) isempty(ps) && return () - param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.tunable_buffer_sizes) + param_buf = if ic.tunable_buffer_size.length == 0 + () + else + (BasicSymbolic[unwrap(variable(:DEF)) + for _ in 1:(ic.tunable_buffer_size.length)],) + end disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in Iterators.flatten(ic.discrete_buffer_sizes)) const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] @@ -453,8 +467,12 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) i, j, k = ic.discrete_idx[p] disc_buf[(i - 1) * disc_offset + j][k] = p elseif haskey(ic.tunable_idx, p) - i, j = ic.tunable_idx[p] - param_buf[i][j] = p + i = ic.tunable_idx[p] + if i isa Int + param_buf[1][i] = unwrap(p) + else + param_buf[1][i] = unwrap.(collect(p)) + end elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] const_buf[i][j] = p diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 536252fec4..da75b7dfd6 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -203,6 +203,7 @@ function generate_rate_function(js::JumpSystem, rate) p = reorder_parameters(js, full_parameters(js)) rf = build_function(rate, unknowns(js), p..., get_iv(js), + wrap_code = wrap_array_vars(js, rate; dvs = unknowns(js), ps = parameters(js)), expression = Val{true}) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index f854b28737..46b9fbbf76 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -211,12 +211,13 @@ end function generate_jacobian( sys::NonlinearSystem, vs = unknowns(sys), ps = full_parameters(sys); - sparse = false, simplify = false, kwargs...) + sparse = false, simplify = false, wrap_code = identity, kwargs...) jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify) pre, sol_states = get_substitutions_and_solved_unknowns(sys) p = reorder_parameters(sys, ps) + wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) return build_function( - jac, vs, p...; postprocess_fbody = pre, states = sol_states, kwargs...) + jac, vs, p...; postprocess_fbody = pre, states = sol_states, wrap_code, kwargs...) end function calculate_hessian(sys::NonlinearSystem; sparse = false, simplify = false) @@ -233,22 +234,23 @@ end function generate_hessian( sys::NonlinearSystem, vs = unknowns(sys), ps = full_parameters(sys); - sparse = false, simplify = false, kwargs...) + sparse = false, simplify = false, wrap_code = identity, kwargs...) hess = calculate_hessian(sys, sparse = sparse, simplify = simplify) pre = get_preprocess_constants(hess) p = reorder_parameters(sys, ps) - return build_function(hess, vs, p...; postprocess_fbody = pre, kwargs...) + wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) + return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...) end function generate_function( sys::NonlinearSystem, dvs = unknowns(sys), ps = full_parameters(sys); - kwargs...) + wrap_code = identity, kwargs...) rhss = [deq.rhs for deq in equations(sys)] pre, sol_states = get_substitutions_and_solved_unknowns(sys) - + wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) p = reorder_parameters(sys, value.(ps)) return build_function(rhss, value.(dvs), p...; postprocess_fbody = pre, - states = sol_states, kwargs...) + states = sol_states, wrap_code, kwargs...) end function jacobian_sparsity(sys::NonlinearSystem) diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl index afb5416aa5..38a3869502 100644 --- a/src/systems/optimization/constraints_system.jl +++ b/src/systems/optimization/constraints_system.jl @@ -166,10 +166,11 @@ end function generate_jacobian( sys::ConstraintsSystem, vs = unknowns(sys), ps = full_parameters(sys); - sparse = false, simplify = false, kwargs...) + sparse = false, simplify = false, wrap_code = identity, kwargs...) jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify) p = reorder_parameters(sys, ps) - return build_function(jac, vs, p...; kwargs...) + wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) + return build_function(jac, vs, p...; wrap_code, kwargs...) end function calculate_hessian(sys::ConstraintsSystem; sparse = false, simplify = false) @@ -185,20 +186,23 @@ end function generate_hessian( sys::ConstraintsSystem, vs = unknowns(sys), ps = full_parameters(sys); - sparse = false, simplify = false, kwargs...) + sparse = false, simplify = false, wrap_code = identity, kwargs...) hess = calculate_hessian(sys, sparse = sparse, simplify = simplify) p = reorder_parameters(sys, ps) - return build_function(hess, vs, p...; kwargs...) + wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) + return build_function(hess, vs, p...; wrap_code, kwargs...) end function generate_function(sys::ConstraintsSystem, dvs = unknowns(sys), ps = full_parameters(sys); + wrap_code = identity, kwargs...) lhss = generate_canonical_form_lhss(sys) pre, sol_states = get_substitutions_and_solved_unknowns(sys) p = reorder_parameters(sys, value.(ps)) + wrap_code = wrap_code .∘ wrap_array_vars(sys, lhss; dvs, ps) func = build_function(lhss, value.(dvs), p...; postprocess_fbody = pre, - states = sol_states, kwargs...) + states = sol_states, wrap_code, kwargs...) cstr = constraints(sys) lcons = fill(-Inf, length(cstr)) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 21e82f15cc..6ef4646b6b 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -133,11 +133,13 @@ end function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys), ps = full_parameters(sys); + wrap_code = identity, kwargs...) grad = calculate_gradient(sys) pre = get_preprocess_constants(grad) p = reorder_parameters(sys, ps) - return build_function(grad, vs, p...; postprocess_fbody = pre, + wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) + return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...) end @@ -147,7 +149,7 @@ end function generate_hessian( sys::OptimizationSystem, vs = unknowns(sys), ps = full_parameters(sys); - sparse = false, kwargs...) + sparse = false, wrap_code = identity, kwargs...) if sparse hess = sparsehessian(objective(sys), unknowns(sys)) else @@ -155,12 +157,14 @@ function generate_hessian( end pre = get_preprocess_constants(hess) p = reorder_parameters(sys, ps) - return build_function(hess, vs, p...; postprocess_fbody = pre, + wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) + return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...) end function generate_function(sys::OptimizationSystem, vs = unknowns(sys), ps = full_parameters(sys); + wrap_code = identity, kwargs...) eqs = subs_constants(objective(sys)) p = if has_index_cache(sys) @@ -168,7 +172,8 @@ function generate_function(sys::OptimizationSystem, vs = unknowns(sys), else (ps,) end - return build_function(eqs, vs, p...; + wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) + return build_function(eqs, vs, p...; wrap_code, kwargs...) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 43ccdb7e56..60c20c96d6 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -16,7 +16,7 @@ end function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, use_union = false, eval_expression = false, eval_module = @__MODULE__) - ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing + ic::IndexCache = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else error("Cannot create MTKParameters if system does not have index_cache") @@ -98,8 +98,8 @@ function MTKParameters( end end - tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) - for temp in ic.tunable_buffer_sizes) + tunable_buffer = Vector{ic.tunable_buffer_size.type}( + undef, ic.tunable_buffer_size.length) disc_buffer = SizedArray{Tuple{length(ic.discrete_buffer_sizes)}}([Tuple(Vector{temp.type}( undef, temp.length) @@ -114,8 +114,8 @@ function MTKParameters( function set_value(sym, val) done = true if haskey(ic.tunable_idx, sym) - i, j = ic.tunable_idx[sym] - tunable_buffer[i][j] = val + idx = ic.tunable_idx[sym] + tunable_buffer[idx] = val elseif haskey(ic.discrete_idx, sym) i, j, k = ic.discrete_idx[sym] disc_buffer[i][j][k] = val @@ -157,7 +157,10 @@ function MTKParameters( end end end - tunable_buffer = narrow_buffer_type.(tunable_buffer) + tunable_buffer = narrow_buffer_type(tunable_buffer) + if isempty(tunable_buffer) + tunable_buffer = Float64[] + end disc_buffer = broadcast.(narrow_buffer_type, disc_buffer) const_buffer = narrow_buffer_type.(const_buffer) # Don't narrow nonnumeric types @@ -172,7 +175,7 @@ function MTKParameters( end dep_exprs = identity.(dep_exprs) psyms = reorder_parameters(ic, full_parameters(sys)) - update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true}) + update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true}, wrap_code = wrap_array_vars(sys, dep_exprs; dvs = nothing)) update_function_oop, update_function_iip = eval_or_rgf.( update_fn_exprs; eval_expression, eval_module) @@ -269,10 +272,42 @@ SciMLStructures.isscimlstructure(::MTKParameters) = true SciMLStructures.ismutablescimlstructure(::MTKParameters) = true -for (Portion, field, recurse) in [(SciMLStructures.Tunable, :tunable, 1) - (SciMLStructures.Discrete, :discrete, 2) - (SciMLStructures.Constants, :constant, 1) - (Nonnumeric, :nonnumeric, 1)] +function SciMLStructures.canonicalize(::SciMLStructures.Tunable, p::MTKParameters) + arr = p.tunable + repack = let p = p + function (new_val) + if new_val !== p.tunable + copyto!(p.tunable, new_val) + end + if p.dependent_update_iip !== nothing + p.dependent_update_iip(ArrayPartition(p.dependent), p...) + end + return p + end + end + return arr, repack, true +end + +function SciMLStructures.replace(::SciMLStructures.Tunable, p::MTKParameters, newvals) + @set! p.tunable = newvals + if p.dependent_update_oop !== nothing + raw = p.dependent_update_oop(p...) + @set! p.dependent = split_into_buffers(raw, p.dependent, Val(false)) + end + return p +end + +function SciMLStructures.replace!(::SciMLStructures.Tunable, p::MTKParameters, newvals) + copyto!(p.tunable, newvals) + if p.dependent_update_iip !== nothing + p.dependent_update_iip(ArrayPartition(p.dependent), p...) + end + return nothing +end + +for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 2) + (SciMLStructures.Constants, :constant, 1) + (Nonnumeric, :nonnumeric, 1)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let as_vector = as_vector, p = p @@ -308,7 +343,7 @@ for (Portion, field, recurse) in [(SciMLStructures.Tunable, :tunable, 1) end function Base.copy(p::MTKParameters) - tunable = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.tunable) + tunable = copy(p.tunable) discrete = typeof(p.discrete)([Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in clockbuf) for clockbuf in p.discrete]) constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant) @@ -327,6 +362,9 @@ end function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) @unpack portion, idx = pind + if portion isa SciMLStructures.Tunable + return idx isa Int ? p.tunable[idx] : view(p.tunable, idx) + end i, j, k... = idx if portion isa SciMLStructures.Tunable return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...] @@ -347,45 +385,43 @@ end function SymbolicIndexingInterface.set_parameter!( p::MTKParameters, val, idx::ParameterIndex) @unpack portion, idx, validate_size = idx - i, j, k... = idx if portion isa SciMLStructures.Tunable - if isempty(k) - if validate_size && size(val) !== size(p.tunable[i][j]) - throw(InvalidParameterSizeException(size(p.tunable[i][j]), size(val))) - end - p.tunable[i][j] = val - else - p.tunable[i][j][k...] = val + if validate_size && size(val) !== size(idx) + throw(InvalidParameterSizeException(size(idx), size(val))) end - elseif portion isa SciMLStructures.Discrete - k, l... = k - if isempty(l) - if validate_size && size(val) !== size(p.discrete[i][j][k]) - throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) + p.tunable[idx] = val + else + i, j, k... = idx + if portion isa SciMLStructures.Discrete + k, l... = k + if isempty(l) + if validate_size && size(val) !== size(p.discrete[i][j][k]) + throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) + end + p.discrete[i][j][k] = val + else + p.discrete[i][j][k][l...] = val end - p.discrete[i][j][k] = val - else - p.discrete[i][j][k][l...] = val - end - elseif portion isa SciMLStructures.Constants - if isempty(k) - if validate_size && size(val) !== size(p.constant[i][j]) - throw(InvalidParameterSizeException(size(p.constant[i][j]), size(val))) + elseif portion isa SciMLStructures.Constants + if isempty(k) + if validate_size && size(val) !== size(p.constant[i][j]) + throw(InvalidParameterSizeException(size(p.constant[i][j]), size(val))) + end + p.constant[i][j] = val + else + p.constant[i][j][k...] = val + end + elseif portion === DEPENDENT_PORTION + error("Cannot set value of dependent parameter") + elseif portion === NONNUMERIC_PORTION + if isempty(k) + p.nonnumeric[i][j] = val + else + p.nonnumeric[i][j][k...] = val end - p.constant[i][j] = val - else - p.constant[i][j][k...] = val - end - elseif portion === DEPENDENT_PORTION - error("Cannot set value of dependent parameter") - elseif portion === NONNUMERIC_PORTION - if isempty(k) - p.nonnumeric[i][j] = val else - p.nonnumeric[i][j][k...] = val + error("Unhandled portion $portion") end - else - error("Unhandled portion $portion") end if p.dependent_update_iip !== nothing p.dependent_update_iip(ArrayPartition(p.dependent), p...) @@ -395,41 +431,39 @@ end function _set_parameter_unchecked!( p::MTKParameters, val, idx::ParameterIndex; update_dependent = true) @unpack portion, idx = idx - i, j, k... = idx if portion isa SciMLStructures.Tunable - if isempty(k) - p.tunable[i][j] = val - else - p.tunable[i][j][k...] = val - end - elseif portion isa SciMLStructures.Discrete - k, l... = k - if isempty(l) - p.discrete[i][j][k] = val - else - p.discrete[i][j][k][l...] = val - end - elseif portion isa SciMLStructures.Constants - if isempty(k) - p.constant[i][j] = val - else - p.constant[i][j][k...] = val - end - elseif portion === DEPENDENT_PORTION - if isempty(k) - p.dependent[i][j] = val - else - p.dependent[i][j][k...] = val - end - update_dependent = false - elseif portion === NONNUMERIC_PORTION - if isempty(k) - p.nonnumeric[i][j] = val + p.tunable[idx] = val + else + i, j, k... = idx + if portion isa SciMLStructures.Discrete + k, l... = k + if isempty(l) + p.discrete[i][j][k] = val + else + p.discrete[i][j][k][l...] = val + end + elseif portion isa SciMLStructures.Constants + if isempty(k) + p.constant[i][j] = val + else + p.constant[i][j][k...] = val + end + elseif portion === DEPENDENT_PORTION + if isempty(k) + p.dependent[i][j] = val + else + p.dependent[i][j][k...] = val + end + update_dependent = false + elseif portion === NONNUMERIC_PORTION + if isempty(k) + p.nonnumeric[i][j] = val + else + p.nonnumeric[i][j][k...] = val + end else - p.nonnumeric[i][j][k...] = val + error("Unhandled portion $portion") end - else - error("Unhandled portion $portion") end update_dependent && p.dependent_update_iip !== nothing && p.dependent_update_iip(ArrayPartition(p.dependent), p...) @@ -508,8 +542,7 @@ function indp_to_system(indp) end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict) - newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf)) - for buf in oldbuf.tunable) + newbuf = @set oldbuf.tunable = Vector{Any}(undef, length(oldbuf.tunable)) @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([Tuple(Vector{Any}(undef, length(buf)) for buf in clockbuf) @@ -553,7 +586,7 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va end end - @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.( + @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs( oldbuf.tunable, newbuf.tunable) @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([narrow_buffer_type_and_fallback_undefs.( oldclockbuf, @@ -644,8 +677,8 @@ _num_subarrays(v::Tuple) = length(v) function Base.getindex(buf::MTKParameters, i) i_orig = i if !isempty(buf.tunable) - i <= _num_subarrays(buf.tunable) && return _subarrays(buf.tunable)[i] - i -= _num_subarrays(buf.tunable) + i == 1 && return buf.tunable + i -= 1 end if !isempty(buf.discrete) for clockbuf in buf.discrete @@ -667,37 +700,13 @@ function Base.getindex(buf::MTKParameters, i) end throw(BoundsError(buf, i_orig)) end -function Base.setindex!(p::MTKParameters, val, i) - function _helper(buf) - done = false - for v in buf - if i <= length(v) - v[i] = val - done = true - else - i -= length(v) - end - end - done - end - _helper(p.tunable) || _helper(Iterators.flatten(p.discrete)) || _helper(p.constant) || - _helper(p.nonnumeric) || throw(BoundsError(p, i)) - if p.dependent_update_iip !== nothing - p.dependent_update_iip(ArrayPartition(p.dependent), p...) - end -end -function Base.getindex(p::MTKParameters, pind::ParameterIndex) - parameter_values(p, pind) -end +Base.getindex(p::MTKParameters, pind::ParameterIndex) = parameter_values(p, pind) -function Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) - SymbolicIndexingInterface.set_parameter!(p, val, pind) -end +Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) = set_parameter!(p, val, pind) function Base.iterate(buf::MTKParameters, state = 1) - total_len = 0 - total_len += _num_subarrays(buf.tunable) + total_len = Int(!isempty(buf.tunable)) # for tunables for clockbuf in buf.discrete total_len += _num_subarrays(clockbuf) end diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl index 827dc6a01b..11c9fc1cd9 100644 --- a/test/jumpsystem.jl +++ b/test/jumpsystem.jl @@ -194,8 +194,8 @@ jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false), rng = pcondit(u, t, integrator) = t == 1000.0 function paffect!(integrator) - integrator.p[1] = 0.0 - integrator.p[2] = 1.0 + integrator.ps[k1] = 0.0 + integrator.ps[k2] = 1.0 reset_aggregated_jumps!(integrator) end sol = solve(jprob, SSAStepper(), tstops = [1000.0], diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 01011828ab..31a41376e8 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -194,23 +194,26 @@ connections = [[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4] S = get_sensitivity(closed_loop, :u) @testset "Indexing MTKParameters with ParameterIndex" begin - ps = MTKParameters(([1.0, 2.0], [3, 4]), + ps = MTKParameters(collect(1.0:10.0), SizedVector{2}([([true, false], [[1 2; 3 4]]), ([false, true], [[2 4; 6 8]])]), ([5, 6],), ([7.0, 8.0],), (["hi", "bye"], [:lie, :die]), nothing, nothing) - @test ps[ParameterIndex(Tunable(), (1, 2))] === 2.0 - @test ps[ParameterIndex(Tunable(), (2, 2))] === 4 - @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 4 + @test ps[ParameterIndex(Tunable(), 1)] == 1.0 + @test ps[ParameterIndex(Tunable(), 2:4)] == collect(2.0:4.0) + @test ps[ParameterIndex(Tunable(), reshape(4:7, 2, 2))] == reshape(4.0:7.0, 2, 2) + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] == 4 @test ps[ParameterIndex(Discrete(), (2, 2, 1))] == [2 4; 6 8] - @test ps[ParameterIndex(Constants(), (1, 1))] === 5 - @test ps[ParameterIndex(DEPENDENT_PORTION, (1, 1))] === 7.0 - @test ps[ParameterIndex(NONNUMERIC_PORTION, (2, 2))] === :die + @test ps[ParameterIndex(Constants(), (1, 1))] == 5 + @test ps[ParameterIndex(DEPENDENT_PORTION, (1, 1))] == 7.0 + @test ps[ParameterIndex(NONNUMERIC_PORTION, (2, 2))] == :die - ps[ParameterIndex(Tunable(), (1, 2))] = 3.0 + ps[ParameterIndex(Tunable(), 1)] = 1.5 + ps[ParameterIndex(Tunable(), 2:4)] = [2.5, 3.5, 4.5] + ps[ParameterIndex(Tunable(), reshape(5:8, 2, 2))] = [5.5 7.5; 6.5 8.5] ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] = 5 - @test ps[ParameterIndex(Tunable(), (1, 2))] === 3.0 - @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 5 + @test ps[ParameterIndex(Tunable(), 1:8)] == collect(1.0:8.0) .+ 0.5 + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] == 5 end