Skip to content

Commit

Permalink
Support typed arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
contradict committed Aug 1, 2024
1 parent 8ab6bfe commit 914a209
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
Expr(:(::), a,
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
nothing))
push!(where_types, :($vartype <: $type))
if !isnothing(meta) && haskey(meta, VariableUnit)
push!(where_types, vartype)
else
push!(where_types, :($vartype <: $type))
end
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
end
if dict[varclass] isa Vector
Expand Down Expand Up @@ -624,10 +628,20 @@ function convert_units(varunits::DynamicQuantities.Quantity, value)
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where T
DynamicQuantities.ustrip.(DynamicQuantities.uconvert.(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where T
Unitful.ustrip.(varunits, value)
end


function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
vv, def, metadata_with_exprs = parse_variable_def!(
dict, mod, arg, varclass, kwargs, where_types)
Expand Down
11 changes: 11 additions & 0 deletions test/dq_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,14 @@ end

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]
11 changes: 11 additions & 0 deletions test/units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,14 @@ end

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

0 comments on commit 914a209

Please sign in to comment.