Skip to content

Commit

Permalink
Merge pull request #2929 from AayushSabharwal/as/mtkparams-methods
Browse files Browse the repository at this point in the history
fix: several MTKParameters fixes
  • Loading branch information
ChrisRackauckas authored Oct 23, 2024
2 parents 57dcc7e + 94f1c3d commit 1f2d943
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
20 changes: 9 additions & 11 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function MTKParameters(
end
p = Dict()
missing_params = Set()
pdeps = has_parameter_dependencies(sys) ? parameter_dependencies(sys) : nothing
pdeps = has_parameter_dependencies(sys) ? parameter_dependencies(sys) : []

for sym in all_ps
ttsym = default_toterm(sym)
Expand Down Expand Up @@ -92,7 +92,7 @@ function MTKParameters(
delete!(missing_params, ttsym)
end

if pdeps !== nothing
if !isempty(pdeps)
for eq in pdeps
sym = eq.lhs
expr = eq.rhs
Expand Down Expand Up @@ -279,10 +279,7 @@ function SciMLStructures.canonicalize(::SciMLStructures.Tunable, p::MTKParameter
arr = p.tunable
repack = let p = p
function (new_val)
if new_val !== p.tunable
copyto!(p.tunable, new_val)
end
return p
return SciMLStructures.replace(SciMLStructures.Tunable(), p, new_val)
end
end
return arr, repack, true
Expand All @@ -303,12 +300,9 @@ for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 1)
(Nonnumeric, :nonnumeric, 1)]
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
as_vector = buffer_to_arraypartition(p.$field)
repack = let as_vector = as_vector, p = p
repack = let p = p
function (new_val)
if new_val !== as_vector
update_tuple_of_buffers(new_val, p.$field)
end
p
return SciMLStructures.replace(($Portion)(), p, new_val)
end
end
return as_vector, repack, true
Expand Down Expand Up @@ -678,6 +672,10 @@ end
return len
end

Base.size(ps::MTKParameters) = (length(ps),)

Base.IndexStyle(::Type{T}) where {T <: MTKParameters} = IndexLinear()

Base.getindex(p::MTKParameters, pind::ParameterIndex) = parameter_values(p, pind)

Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) = set_parameter!(p, val, pind)
Expand Down
2 changes: 1 addition & 1 deletion test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ for (portion, values) in [(Tunable(), [1.0, 5.0, 6.0, 7.0])
SciMLStructures.replace!(portion, ps, ones(length(buffer)))
# make sure it is in-place
@test all(isone, canonicalize(portion, ps)[1])
repack(zeros(length(buffer)))
global ps = repack(zeros(length(buffer)))
@test all(iszero, canonicalize(portion, ps)[1])
end

Expand Down

0 comments on commit 1f2d943

Please sign in to comment.