Skip to content

Commit

Permalink
Merge pull request #2368 from ven-k/vkb/component-arrays
Browse files Browse the repository at this point in the history
Add support for component array in `@mtkmodel`
  • Loading branch information
ChrisRackauckas authored Mar 4, 2024
2 parents 2a28c4d + e324633 commit a64aad8
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 54 deletions.
16 changes: 11 additions & 5 deletions docs/src/basics/MTKModel_Connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ end
end
@structural_parameters begin
f = sin
N = 2
end
begin
v_var = 1.0
Expand All @@ -69,6 +70,11 @@ end
@extend ModelB(; p1)
@components begin
model_a = ModelA(; k_array)
model_array_a = [ModelA(; k = i) for i in 1:N]
model_array_b = for i in 1:N
k = i^2
ModelA(; k)
end
end
@equations begin
model_a.k ~ f(v)
Expand Down Expand Up @@ -146,6 +152,7 @@ julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
#### `@components` begin block

- Declare the subcomponents within `@components` begin block.
- Array of components can be declared with a for loop or a list comprehension.
- The arguments in these subcomponents are promoted as keyword arguments as `subcomponent_name__argname` with `nothing` as default value.
- Whenever components are created with `@named` macro, these can be accessed with `.` operator as `subcomponent_name.argname`
- In the above example, as `k` of `model_a` isn't listed while defining the sub-component in `ModelC`, its default value can't be modified by users. While `k_array` can be set as:
Expand Down Expand Up @@ -247,14 +254,13 @@ For example, the structure of `ModelC` is:
```julia
julia> ModelC.structure
Dict{Symbol, Any} with 9 entries:
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA]]
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)))
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :v=>Dict{Symbol, Union{Nothing, Symbol}}(:value=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin))
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
:independent_variable => t
:constants => Dict{Symbol, Dict}(:c=>Dict(:value=>1))
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
:equations => Any["model_a.k ~ f(v)"]
```
Expand Down
7 changes: 5 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ function _named(name, call, runtime = false)
end
end

function _named_idxs(name::Symbol, idxs, call)
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
if call.head !== :->
throw(ArgumentError("Not an anonymous function"))
end
Expand All @@ -1402,7 +1402,10 @@ function _named_idxs(name::Symbol, idxs, call)
ex = Base.Cartesian.poplinenum(ex)
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
ex = Base.Cartesian.poplinenum(ex)
:($name = $map($sym -> $ex, $idxs))
:($name = map($sym -> begin
$extra_args
$ex
end, $idxs))
end

function single_named_expr(expr)
Expand Down
136 changes: 89 additions & 47 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ struct Model{F, S}
end
(m::Model)(args...; kw...) = m.f(args...; kw...)

Base.parentmodule(m::Model) = parentmodule(m.f)

for f in (:connector, :mtkmodel)
isconnector = f == :connector ? true : false
@eval begin
Expand All @@ -40,7 +42,7 @@ function _model_macro(mod, name, expr, isconnector)
:kwargs => Dict{Symbol, Dict}(),
:structural_parameters => Dict{Symbol, Dict}()
)
comps = Symbol[]
comps = Union{Symbol, Expr}[]
ext = Ref{Any}(nothing)
eqs = Expr[]
icon = Ref{Union{String, URI}}()
Expand Down Expand Up @@ -745,7 +747,7 @@ end

### Parsing Components:

function component_args!(a, b, expr, varexpr, kwargs)
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
# Whenever `b` is a function call, skip the first arg aka the function name.
# Whenever it is a kwargs list, include it.
start = b.head == :call ? 2 : 1
Expand All @@ -754,73 +756,115 @@ function component_args!(a, b, expr, varexpr, kwargs)
arg isa LineNumberNode && continue
MLStyle.@match arg begin
x::Symbol || Expr(:kw, x) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
push!(varexpr.args, :((if $varname !== nothing
$_varname = $varname
elseif @isdefined $x
# Allow users to define a var in `structural_parameters` and set
# that as positional arg of subcomponents; it is useful for cases
# where it needs to be passed to multiple subcomponents.
$_varname = $x
end)))
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, varexpr, kwargs)
component_args!(a, arg, varexpr, kwargs)
end
Expr(:kw, x, y) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
if isnothing(index_name)
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
else
push!(varexpr.args,
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
end
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
_ => error("Could not parse $arg of component $a")
end
end
end

function _parse_components!(exprs, body, kwargs)
expr = Expr(:block)
model_name(name, range) = Symbol.(name, :_, collect(range))

function _parse_components!(body, kwargs)
local expr
varexpr = Expr(:block)
# push!(exprs, varexpr)
comps = Vector{Union{Symbol, Expr}}[]
comps = Vector{Union{Union{Expr, Symbol}, Expr}}[]
comp_names = []

for arg in body.args
arg isa LineNumberNode && continue
MLStyle.@match arg begin
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])
Base.remove_linenums!(body)
arg = body.args[end]

component_args!(a, b, expr, varexpr, kwargs)
MLStyle.@match arg begin
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
array_varexpr = Expr(:block)

arg.args[2] = b
push!(expr.args, arg)
push!(comp_names, a)
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
push!(comps, [a, b.args[1]])
end
push!(comp_names, :($a...))
push!(comps, [a, b.args[1], d])
b = deepcopy(b)

component_args!(a, b, array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
error("List comprehensions with conditional statements aren't supported.")
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
error("More than one index isn't supported while building component array")
end
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
Base.remove_linenums!(b)
array_varexpr = Expr(:block)
push!(array_varexpr.args, b.args[1:(end - 1)]...)
push!(comp_names, :($a...))
push!(comps, [a, b.args[end].args[1], d])
b = deepcopy(b)

component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])

component_args!(a, b, varexpr, kwargs)

arg.args[2] = b
expr = :(@named $arg)
push!(comp_names, a)
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
push!(comps, [a, b.args[1]])
end
_ => error("Couldn't parse the component body: $arg")
end
_ => error("Couldn't parse the component body: $arg")
end

return comp_names, comps, expr, varexpr
end

function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
blk = Expr(:block)
push!(blk.args, varexpr)
push!(blk.args, :(@named begin
$(expr_vec.args...)
end))
push!(blk.args, expr_vec)
push!(blk.args, :($push!(systems, $(comp_names...))))
push!(ifexpr.args, blk)
end

function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
push!(ifexpr.args, condition)
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -836,7 +880,7 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
push!(ifexpr.args, elseifexpr)
(comps...,)
else
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -861,25 +905,23 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
Expr(:if, condition, x, y) => begin
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
end
Expr(:(=), a, b) => begin
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
:(begin
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
_ => begin
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
$arg
end),
kwargs)
push!(cs, comp_names...)
push!(dict[:components], comps...)
push!(exprs, varexpr, :(@named begin
$(expr_vec.args...)
end))
push!(exprs, varexpr, expr_vec)
end
_ => error("Couldn't parse the component body $compbody")
end
end
end

function _rename(compname, varname)
compname = Symbol(compname, :__, varname)
(compname, Symbol(:_, compname))
end

# Handle top level branching
Expand Down
69 changes: 69 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,72 @@ end
@named m = MyModel()
@variables x___(t)
@test isequal(x___, _b[])

@testset "Component array" begin
@mtkmodel SubComponent begin
@parameters begin
sc
end
end

@mtkmodel Component begin
@structural_parameters begin
N = 2
end
@components begin
comprehension = [SubComponent(sc = i) for i in 1:N]
written_out_for = for i in 1:N
sc = i + 1
SubComponent(; sc)
end
single_sub_component = SubComponent()
end
end

@named component = Component()
component = complete(component)

@test nameof.(ModelingToolkit.get_systems(component)) == [
:comprehension_1,
:comprehension_2,
:written_out_for_1,
:written_out_for_2,
:single_sub_component
]

@test getdefault(component.comprehension_1.sc) == 1
@test getdefault(component.comprehension_2.sc) == 2
@test getdefault(component.written_out_for_1.sc) == 2
@test getdefault(component.written_out_for_2.sc) == 3

@mtkmodel ConditionalComponent begin
@structural_parameters begin
N = 2
end
@components begin
if N == 2
if_comprehension = [SubComponent(sc = i) for i in 1:N]
elseif N == 3
elseif_comprehension = [SubComponent(sc = i) for i in 1:N]
else
else_comprehension = [SubComponent(sc = i) for i in 1:N]
end
end
end

@named if_component = ConditionalComponent()
@test nameof.(get_systems(if_component)) == [:if_comprehension_1, :if_comprehension_2]

@named elseif_component = ConditionalComponent(; N = 3)
@test nameof.(get_systems(elseif_component)) ==
[:elseif_comprehension_1, :elseif_comprehension_2, :elseif_comprehension_3]

@named else_component = ConditionalComponent(; N = 4)
@test nameof.(get_systems(else_component)) ==
[:else_comprehension_1, :else_comprehension_2,
:else_comprehension_3, :else_comprehension_4]
end

@testset "Parent module of Models" begin
@test parentmodule(MyMockModule.Ground) == MyMockModule
end

0 comments on commit a64aad8

Please sign in to comment.