Skip to content

Commit

Permalink
feat: support symbolic arrays in @mtkmodels
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k committed Oct 10, 2023
1 parent d4239a7 commit 5074956
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ using PrecompileTools, Reexport
using Symbolics: _parse_vars, value, @derivatives, get_variables,
exprs_occur_in, solve_for, build_expr, unwrap, wrap,
VariableSource, getname, variable, Connection, connect,
NAMESPACE_SEPARATOR
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
tosymbol, lower_varname, diff2term, var_from_nested_derivative,
Expand Down
53 changes: 29 additions & 24 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function _model_macro(mod, name, expr, isconnector)
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
GUIMetadata(GlobalRef(mod, name))

sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];

Check warning on line 70 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L70

Added line #L70 was not covered by tests
name, systems = [$(comps...)], gui_metadata = $gui_metadata))

if ext[] === nothing
Expand All @@ -83,7 +83,7 @@ function _model_macro(mod, name, expr, isconnector)
:($name = $Model($f, $dict, $isconnector))
end

function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
function parse_variable_def!(dict, mod, arg, varclass, kwargs; def = nothing, indices = 1:1)

Check warning on line 86 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L86

Added line #L86 was not covered by tests
metatypes = [(:connection_type, VariableConnectType),
(:description, VariableDescription),
(:unit, VariableUnit),
Expand All @@ -104,20 +104,20 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
MLStyle.@match arg begin
a::Symbol => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, varclass)
var = generate_var!(dict, a, varclass; indices)

Check warning on line 107 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L107

Added line #L107 was not covered by tests
dict[:kwargs][getname(var)] = def
(var, nothing)
(var, def)

Check warning on line 109 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L109

Added line #L109 was not covered by tests
end
Expr(:call, a, b) => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, b, varclass)
var = generate_var!(dict, a, b, varclass; indices)

Check warning on line 113 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L113

Added line #L113 was not covered by tests
dict[:kwargs][getname(var)] = def
(var, nothing)
(var, def)

Check warning on line 115 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L115

Added line #L115 was not covered by tests
end
Expr(:(=), a, b) => begin
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def)
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def)

Check warning on line 120 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L120

Added line #L120 was not covered by tests
dict[varclass][getname(var)][:default] = def
if meta !== nothing
for (type, key) in metatypes
Expand All @@ -142,31 +142,33 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
end
var = set_var_metadata(var, meta)
end
(set_var_metadata(var, meta), def)
(var, def)

Check warning on line 145 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L145

Added line #L145 was not covered by tests
end
Expr(:ref, a, b) => begin
parse_variable_def!(dict, mod, a, varclass, kwargs; def, indices = eval(b))

Check warning on line 148 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L147-L148

Added lines #L147 - L148 were not covered by tests
end
_ => error("$arg cannot be parsed")
end
end

function generate_var(a, varclass)
var = Symbolics.variable(a)

function generate_var(a, varclass; indices = 1:1)
var = length(indices) > 1 ? (@variables $a[indices]) : Symbolics.variable(a)

Check warning on line 156 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L155-L156

Added lines #L155 - L156 were not covered by tests
if varclass == :parameters
var = toparam(var)
end
var
var isa Vector ? var[1] : var

Check warning on line 160 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L160

Added line #L160 was not covered by tests
end

function generate_var!(dict, a, varclass)
#var = generate_var(Symbol("#", a), varclass)
var = generate_var(a, varclass)
function generate_var!(dict, a, varclass; indices::UnitRange = 1:1)

Check warning on line 163 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L163

Added line #L163 was not covered by tests
vd = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
vd[a] = Dict{Symbol, Any}()
var
generate_var(a, varclass; indices)

Check warning on line 168 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L168

Added line #L168 was not covered by tests
end

function generate_var!(dict, a, b, varclass)
function generate_var!(dict, a, b, varclass; indices::UnitRange = 1:1)

Check warning on line 171 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L171

Added line #L171 was not covered by tests
iv = generate_var(b, :variables)
prev_iv = get!(dict, :independent_variable) do
iv
Expand All @@ -176,11 +178,15 @@ function generate_var!(dict, a, b, varclass)
Dict{Symbol, Dict{Symbol, Any}}()
end
vd[a] = Dict{Symbol, Any}()
var = Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
var = if length(indices) > 1
@variables $a(iv)[indices]

Check warning on line 182 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
else
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)

Check warning on line 184 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L184

Added line #L184 was not covered by tests
end
if varclass == :parameters
var = toparam(var)
end
var
var isa Vector ? var[1] : var

Check warning on line 189 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L189

Added line #L189 was not covered by tests
end

function parse_default(mod, a)
Expand Down Expand Up @@ -215,7 +221,7 @@ end

function set_var_metadata(a, ms)
for (m, v) in ms
a = setmetadata(a, m, v)
a = wrap(set_scalar_metadata(unwrap(a), m, v))

Check warning on line 224 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L224

Added line #L224 was not covered by tests
end
a
end
Expand Down Expand Up @@ -433,11 +439,10 @@ end

function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
v = Num(vv)
name = getname(v)
push!(vs, name)
push!(expr.args,
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name)))
vv = wrap(setdefaultval(unwrap(vv), def))
name = getname(vv)
push!(expr.args, :($name = $name === nothing ? $vv : $wrap($setdefaultval($unwrap($vv), $name))))
vv isa Num ? push!(vs, name) : push!(vs, :($name...))

Check warning on line 445 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L442-L445

Added lines #L442 - L445 were not covered by tests
end

function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
Expand Down
15 changes: 14 additions & 1 deletion test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,20 @@ resistor = getproperty(rc, :resistor; namespace = false)
@mtkmodel MockModel begin
@parameters begin
a
a2[1:2]
b(t)
b2(t)[1:2]
cval
jval
kval
c(t) = cval + jval
d = 2
d2[1:2] = 2
e, [description = "e"]
e2[1:2], [description = "e2"]
f = 3, [description = "f"]
h(t), [description = "h(t)"]
h2(t)[1:2], [description = "h2(t)"]
i(t) = 4, [description = "i(t)"]
j(t) = jval, [description = "j(t)"]
k = kval, [description = "k"]
Expand All @@ -195,7 +200,13 @@ resistor = getproperty(rc, :resistor; namespace = false)
end

kval = 5
@named model = MockModel(; kval, cval = 1, func = identity)
@named model = MockModel(; b2 = 3, kval, cval = 1, func = identity)

@test lastindex(parameters(model)) == 23
@test all(lastindex.([model.a2, model.b2, model.d2, model.e2, model.h2]) .== 2)
@test all(getdefault.([model.b2...]) .== 3)
@test all(getdescription.([model.e2...]) .== "e2")
@test all(getdescription.([model.h2...]) .== "h2(t)")

@test hasmetadata(model.e, VariableDescription)
@test hasmetadata(model.f, VariableDescription)
Expand All @@ -211,6 +222,8 @@ resistor = getproperty(rc, :resistor; namespace = false)
@test_throws KeyError getdefault(model.e)
@test getdefault(model.f) == 3
@test getdefault(model.i) == 4
@test getdefault(model.b2[1]) == 3
@test getdefault(model.b2[2]) == 3
@test isequal(getdefault(model.j), model.jval)
@test isequal(getdefault(model.k), model.kval)
end
Expand Down

0 comments on commit 5074956

Please sign in to comment.