From ccc6935378e52a5221ba7abf4bead6a62a641a80 Mon Sep 17 00:00:00 2001 From: Venkateshprasad <32921645+ven-k@users.noreply.github.com> Date: Mon, 25 Sep 2023 16:10:57 +0530 Subject: [PATCH] feat: support symbolic arrays in `@mtkmodel`s --- src/ModelingToolkit.jl | 2 +- src/systems/model_parsing.jl | 46 ++++++++++++++++++++---------------- src/utils.jl | 2 +- test/model_parsing.jl | 15 +++++++++++- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index de63a67837..99bf0b015b 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -56,7 +56,7 @@ using PrecompileTools, Reexport using Symbolics: _parse_vars, value, @derivatives, get_variables, exprs_occur_in, solve_for, build_expr, unwrap, wrap, VariableSource, getname, variable, Connection, connect, - NAMESPACE_SEPARATOR + NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, tosymbol, lower_varname, diff2term, var_from_nested_derivative, diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 299793c236..7de82dd7fa 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -83,7 +83,7 @@ function _model_macro(mod, name, expr, isconnector) :($name = $Model($f, $dict, $isconnector)) end -function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing) +function parse_variable_def!(dict, mod, arg, varclass, kwargs; def = nothing, indices = 1:1) metatypes = [(:connection_type, VariableConnectType), (:description, VariableDescription), (:unit, VariableUnit), @@ -104,20 +104,20 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing) MLStyle.@match arg begin a::Symbol => begin push!(kwargs, Expr(:kw, a, nothing)) - var = generate_var!(dict, a, varclass) + var = generate_var!(dict, a, varclass; indices) dict[:kwargs][getname(var)] = def - (var, nothing) + (var, def) end Expr(:call, a, b) => begin push!(kwargs, Expr(:kw, a, nothing)) - var = generate_var!(dict, a, b, varclass) + var = generate_var!(dict, a, b, varclass; indices) dict[:kwargs][getname(var)] = def - (var, nothing) + (var, def) end Expr(:(=), a, b) => begin Base.remove_linenums!(b) def, meta = parse_default(mod, b) - var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def) + var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def) dict[varclass][getname(var)][:default] = def if meta !== nothing for (type, key) in metatypes @@ -142,31 +142,32 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing) end var = set_var_metadata(var, meta) end - (set_var_metadata(var, meta), def) + (var, def) + end + Expr(:ref, a, b) => begin + parse_variable_def!(dict, mod, a, varclass, kwargs; def, indices = eval(b)) end _ => error("$arg cannot be parsed") end end -function generate_var(a, varclass) - var = Symbolics.variable(a) +function generate_var(a, varclass; indices = 1:1) + var = lastindex(indices) > 1 ? first(@variables $a[indices]) : Symbolics.variable(a) if varclass == :parameters var = toparam(var) end var end -function generate_var!(dict, a, varclass) - #var = generate_var(Symbol("#", a), varclass) - var = generate_var(a, varclass) +function generate_var!(dict, a, varclass; indices::UnitRange = 1:1) vd = get!(dict, varclass) do Dict{Symbol, Dict{Symbol, Any}}() end vd[a] = Dict{Symbol, Any}() - var + generate_var(a, varclass; indices) end -function generate_var!(dict, a, b, varclass) +function generate_var!(dict, a, b, varclass; indices::UnitRange = 1:1) iv = generate_var(b, :variables) prev_iv = get!(dict, :independent_variable) do iv @@ -176,7 +177,11 @@ function generate_var!(dict, a, b, varclass) Dict{Symbol, Dict{Symbol, Any}}() end vd[a] = Dict{Symbol, Any}() - var = Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv) + var = if lastindex(indices) > 1 + first(@variables $a(iv)[indices]) + else + Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv) + end if varclass == :parameters var = toparam(var) end @@ -215,7 +220,7 @@ end function set_var_metadata(a, ms) for (m, v) in ms - a = setmetadata(a, m, v) + a = wrap(set_scalar_metadata(unwrap(a), m, v)) end a end @@ -433,11 +438,12 @@ end function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs) vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs) - v = Num(vv) - name = getname(v) - push!(vs, name) + name = getname(vv) push!(expr.args, - :($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name))) + :($name = $name === nothing ? + $setdefault($vv, $def) : + $setdefault($vv, $name))) + vv isa Num ? push!(vs, name) : push!(vs, :($name...)) end function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs) diff --git a/src/utils.jl b/src/utils.jl index ab17bd3d8d..f66ed48772 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -224,7 +224,7 @@ function getdefaulttype(v) def === nothing ? Float64 : typeof(def) end function setdefault(v, val) - val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val)) + val === nothing ? v : wrap(setdefaultval(unwrap(v), value(val))) end function process_variables!(var_to_name, defs, vars) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 701e785928..5b432ae99a 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -175,15 +175,20 @@ resistor = getproperty(rc, :resistor; namespace = false) @mtkmodel MockModel begin @parameters begin a + a2[1:2] b(t) + b2(t)[1:2] cval jval kval c(t) = cval + jval d = 2 + d2[1:2] = 2 e, [description = "e"] + e2[1:2], [description = "e2"] f = 3, [description = "f"] h(t), [description = "h(t)"] + h2(t)[1:2], [description = "h2(t)"] i(t) = 4, [description = "i(t)"] j(t) = jval, [description = "j(t)"] k = kval, [description = "k"] @@ -195,7 +200,13 @@ resistor = getproperty(rc, :resistor; namespace = false) end kval = 5 - @named model = MockModel(; kval, cval = 1, func = identity) + @named model = MockModel(; b2 = 3, kval, cval = 1, func = identity) + + @test lastindex(parameters(model)) == 23 + @test all(lastindex.([model.a2, model.b2, model.d2, model.e2, model.h2]) .== 2) + @test all(getdefault.([model.b2...]) .== 3) + @test all(getdescription.([model.e2...]) .== "e2") + @test all(getdescription.([model.h2...]) .== "h2(t)") @test hasmetadata(model.e, VariableDescription) @test hasmetadata(model.f, VariableDescription) @@ -211,6 +222,8 @@ resistor = getproperty(rc, :resistor; namespace = false) @test_throws KeyError getdefault(model.e) @test getdefault(model.f) == 3 @test getdefault(model.i) == 4 + @test getdefault(model.b2[1]) == 3 + @test getdefault(model.b2[2]) == 3 @test isequal(getdefault(model.j), model.jval) @test isequal(getdefault(model.k), model.kval) end