From b5be3e44e0cbca20761bc0eb387b19c394f9ca70 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Jul 2024 17:29:38 +0530 Subject: [PATCH] fixup! refactor: turn tunables portion into a Vector{T} --- src/systems/abstractsystem.jl | 12 ++++++++---- src/systems/index_cache.jl | 8 ++++---- src/systems/parameter_buffer.jl | 10 ++++++---- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f1b1c72dd0..2022f5d5d0 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -260,7 +260,8 @@ function wrap_array_vars( Let( vcat( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars] + [k ← :(view($(expr.args[uind + 1].name), $v)) + for (k, v) in array_pars] ), expr.body, false @@ -275,7 +276,8 @@ function wrap_array_vars( Let( vcat( [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars] + [k ← :(view($(expr.args[uind + 1].name), $v)) + for (k, v) in array_pars] ), expr.body, false @@ -288,8 +290,10 @@ function wrap_array_vars( [], Let( vcat( - [k ← :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_vars], - [k ← :(view($(expr.args[uind+2].name), $v)) for (k, v) in array_pars] + [k ← :(view($(expr.args[uind + 1].name), $v)) + for (k, v) in array_vars], + [k ← :(view($(expr.args[uind + 2].name), $v)) + for (k, v) in array_pars] ), expr.body, false diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 6979cb26d3..ac82d8257f 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -239,9 +239,9 @@ function IndexCache(sys::AbstractSystem) insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown && - (ctype == Real || ctype <: AbstractFloat || - ctype <: AbstractArray{Real} || - ctype <: AbstractArray{<:AbstractFloat}) + (ctype == Real || ctype <: AbstractFloat || + ctype <: AbstractArray{Real} || + ctype <: AbstractArray{<:AbstractFloat}) tunable_buffers else constant_buffers @@ -450,7 +450,7 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) () else (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.tunable_buffer_size.length)],) + for _ in 1:(ic.tunable_buffer_size.length)],) end disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in Iterators.flatten(ic.discrete_buffer_sizes)) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 60c20c96d6..b7187bcd5b 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -175,7 +175,8 @@ function MTKParameters( end dep_exprs = identity.(dep_exprs) psyms = reorder_parameters(ic, full_parameters(sys)) - update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true}, wrap_code = wrap_array_vars(sys, dep_exprs; dvs = nothing)) + update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true}, + wrap_code = wrap_array_vars(sys, dep_exprs; dvs = nothing)) update_function_oop, update_function_iip = eval_or_rgf.( update_fn_exprs; eval_expression, eval_module) @@ -306,8 +307,8 @@ function SciMLStructures.replace!(::SciMLStructures.Tunable, p::MTKParameters, n end for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 2) - (SciMLStructures.Constants, :constant, 1) - (Nonnumeric, :nonnumeric, 1)] + (SciMLStructures.Constants, :constant, 1) + (Nonnumeric, :nonnumeric, 1)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let as_vector = as_vector, p = p @@ -396,7 +397,8 @@ function SymbolicIndexingInterface.set_parameter!( k, l... = k if isempty(l) if validate_size && size(val) !== size(p.discrete[i][j][k]) - throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) + throw(InvalidParameterSizeException( + size(p.discrete[i][j][k]), size(val))) end p.discrete[i][j][k] = val else