Skip to content

Commit

Permalink
feat: support symbolic arrays in @mtkmodels
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k committed Sep 25, 2023
1 parent 1aa4806 commit a00a00c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
58 changes: 35 additions & 23 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function _model_macro(mod, name, expr, isconnector)
:($name = $Model($f, $dict, $isconnector))
end

function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
function parse_variable_def!(dict, mod, arg, varclass, kwargs; def = nothing, indices = 1:1)

Check warning on line 86 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L86

Added line #L86 was not covered by tests
metatypes = [(:connection_type, VariableConnectType),
(:description, VariableDescription),
(:unit, VariableUnit),
Expand All @@ -104,20 +104,21 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
MLStyle.@match arg begin
a::Symbol => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, varclass)
var = generate_var!(dict, a, varclass; indices)

Check warning on line 107 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L107

Added line #L107 was not covered by tests
dict[:kwargs][getname(var)] = def
(var, nothing)
(var, def)

Check warning on line 109 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L109

Added line #L109 was not covered by tests
end
Expr(:call, a, b) => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, b, varclass)
var = generate_var!(dict, a, b, varclass; indices)

Check warning on line 113 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L113

Added line #L113 was not covered by tests
dict[:kwargs][getname(var)] = def
(var, nothing)
(var, def)

Check warning on line 115 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L115

Added line #L115 was not covered by tests
end
Expr(:(=), a, b) => begin
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def)
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def)
@info 128 var def meta

Check warning on line 121 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
dict[varclass][getname(var)][:default] = def
if meta !== nothing
for (type, key) in metatypes
Expand All @@ -142,31 +143,33 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
end
var = set_var_metadata(var, meta)
end
(set_var_metadata(var, meta), def)
(var, def)

Check warning on line 146 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L146

Added line #L146 was not covered by tests
end
Expr(:ref, a, b) => begin
parse_variable_def!(dict, mod, a, varclass, kwargs; def, indices = eval(b))

Check warning on line 149 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L148-L149

Added lines #L148 - L149 were not covered by tests
end
_ => error("$arg cannot be parsed")
end
end

function generate_var(a, varclass)
var = Symbolics.variable(a)

function generate_var(a, varclass; indices = 1:1)
var = length(indices) > 1 ? (@variables $a[indices]) : Symbolics.variable(a)

Check warning on line 157 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
if varclass == :parameters
var = toparam(var)
end
var
var isa Vector ? var[1] : var

Check warning on line 161 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L161

Added line #L161 was not covered by tests
end

function generate_var!(dict, a, varclass)
#var = generate_var(Symbol("#", a), varclass)
var = generate_var(a, varclass)
function generate_var!(dict, a, varclass; indices::UnitRange = 1:1)

Check warning on line 164 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L164

Added line #L164 was not covered by tests
vd = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
vd[a] = Dict{Symbol, Any}()
var
generate_var(a, varclass; indices)

Check warning on line 169 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L169

Added line #L169 was not covered by tests
end

function generate_var!(dict, a, b, varclass)
function generate_var!(dict, a, b, varclass; indices::UnitRange = 1:1)

Check warning on line 172 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L172

Added line #L172 was not covered by tests
iv = generate_var(b, :variables)
prev_iv = get!(dict, :independent_variable) do
iv
Expand All @@ -176,11 +179,15 @@ function generate_var!(dict, a, b, varclass)
Dict{Symbol, Dict{Symbol, Any}}()
end
vd[a] = Dict{Symbol, Any}()
var = Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
var = if length(indices) > 1
@variables $a(iv)[indices]

Check warning on line 183 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
else
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)

Check warning on line 185 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L185

Added line #L185 was not covered by tests
end
if varclass == :parameters
var = toparam(var)
end
var
var isa Vector ? var[1] : var

Check warning on line 190 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L190

Added line #L190 was not covered by tests
end

function parse_default(mod, a)
Expand Down Expand Up @@ -215,7 +222,7 @@ end

function set_var_metadata(a, ms)
for (m, v) in ms
a = setmetadata(a, m, v)
a = a isa Symbolics.Arr ? collect(setmetadata.(a, m, v)) : setmetadata(a, m, v)

Check warning on line 225 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L225

Added line #L225 was not covered by tests
end
a
end
Expand Down Expand Up @@ -406,11 +413,16 @@ end

function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
v = Num(vv)
name = getname(v)
push!(vs, name)
push!(expr.args,
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name)))
if !(vv isa Union{Symbolics.Arr, Vector})
vv = setdefault(vv, def)
name = getname(Num(vv))
push!(expr.args, :($name = $name === nothing ? $vv : $setdefault($vv, $name)))
push!(vs, name)

Check warning on line 420 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L416-L420

Added lines #L416 - L420 were not covered by tests
else
name = vv isa Vector ? getname(vv[1]) : getname(vv)
push!(expr.args, :($name = $name === nothing ? $setdefault.($vv, $def) : $setdefault.($vv, $name)))
push!(vs, :($name...))

Check warning on line 424 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L422-L424

Added lines #L422 - L424 were not covered by tests
end
end

function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
Expand Down
13 changes: 12 additions & 1 deletion test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,19 @@ end
@mtkmodel MockModel begin
@parameters begin
a
a2[1:2]
b(t)
b2(t)[1:2]
cval
jval
kval
c(t) = cval + jval
d = 2
e, [description = "e"]
e2[1:2], [description = "e2"]
f = 3, [description = "f"]
h(t), [description = "h(t)"]
h2(t)[1:2], [description = "h(t)"]
i(t) = 4, [description = "i(t)"]
j(t) = jval, [description = "j(t)"]
k = kval, [description = "k"]
Expand All @@ -193,7 +197,12 @@ end
end

kval = 5
@named model = MockModel(; kval, cval = 1, func = identity)
@named model = MockModel(; b2 = 3, kval, cval = 1, func = identity)

@test lastindex(parameters(model)) == 21
@test lastindex(model.b2) == 2
@test getdefault.(model.b2) .== 3
@test getmetadata(model.e2[1], VariableDescription) == getmetadata(model.e2[2], VariableDescription) == "e2"

@test hasmetadata(model.e, VariableDescription)
@test hasmetadata(model.f, VariableDescription)
Expand All @@ -209,6 +218,8 @@ end
@test_throws KeyError getdefault(model.e)
@test getdefault(model.f) == 3
@test getdefault(model.i) == 4
@test getdefault(model.b2[1]) == 3
@test getdefault(model.b2[2]) == 3
@test isequal(getdefault(model.j), model.jval)
@test isequal(getdefault(model.k), model.kval)
end
Expand Down

0 comments on commit a00a00c

Please sign in to comment.