Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for component array in @mtkmodel #2368

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions docs/src/basics/MTKModel_Connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ end
end
@structural_parameters begin
f = sin
N = 2
end
begin
v_var = 1.0
Expand All @@ -69,6 +70,11 @@ end
@extend ModelB(; p1)
@components begin
model_a = ModelA(; k_array)
model_array_a = [ModelA(; k = i) for i in 1:N]
model_array_b = for i in 1:N
k = i^2
ModelA(; k)
end
end
@equations begin
model_a.k ~ f(v)
Expand Down Expand Up @@ -146,6 +152,7 @@ julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
#### `@components` begin block

- Declare the subcomponents within `@components` begin block.
- Array of components can be declared with a for loop or a list comprehension.
- The arguments in these subcomponents are promoted as keyword arguments as `subcomponent_name__argname` with `nothing` as default value.
- Whenever components are created with `@named` macro, these can be accessed with `.` operator as `subcomponent_name.argname`
- In the above example, as `k` of `model_a` isn't listed while defining the sub-component in `ModelC`, its default value can't be modified by users. While `k_array` can be set as:
Expand Down Expand Up @@ -247,14 +254,13 @@ For example, the structure of `ModelC` is:
```julia
julia> ModelC.structure
Dict{Symbol, Any} with 9 entries:
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA]]
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)))
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :v=>Dict{Symbol, Union{Nothing, Symbol}}(:value=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin))
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :p1=>Dict(:value=>nothing))
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
:independent_variable => t
:constants => Dict{Symbol, Dict}(:c=>Dict(:value=>1))
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
:equations => Any["model_a.k ~ f(v)"]
```
Expand Down
7 changes: 5 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,7 @@
end
end

function _named_idxs(name::Symbol, idxs, call)
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
if call.head !== :->
throw(ArgumentError("Not an anonymous function"))
end
Expand All @@ -1413,7 +1413,10 @@
ex = Base.Cartesian.poplinenum(ex)
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
ex = Base.Cartesian.poplinenum(ex)
:($name = $map($sym -> $ex, $idxs))
:($name = map($sym -> begin
$extra_args
$ex

Check warning on line 1418 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1418

Added line #L1418 was not covered by tests
end, $idxs))
end

function single_named_expr(expr)
Expand Down
136 changes: 89 additions & 47 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
end
(m::Model)(args...; kw...) = m.f(args...; kw...)

Base.parentmodule(m::Model) = parentmodule(m.f)

Check warning on line 27 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L27

Added line #L27 was not covered by tests

for f in (:connector, :mtkmodel)
isconnector = f == :connector ? true : false
@eval begin
Expand All @@ -40,7 +42,7 @@
:kwargs => Dict{Symbol, Dict}(),
:structural_parameters => Dict{Symbol, Dict}()
)
comps = Symbol[]
comps = Union{Symbol, Expr}[]
ext = Ref{Any}(nothing)
eqs = Expr[]
icon = Ref{Union{String, URI}}()
Expand Down Expand Up @@ -745,7 +747,7 @@

### Parsing Components:

function component_args!(a, b, expr, varexpr, kwargs)
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
# Whenever `b` is a function call, skip the first arg aka the function name.
# Whenever it is a kwargs list, include it.
start = b.head == :call ? 2 : 1
Expand All @@ -754,73 +756,115 @@
arg isa LineNumberNode && continue
MLStyle.@match arg begin
x::Symbol || Expr(:kw, x) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
push!(varexpr.args, :((if $varname !== nothing
$_varname = $varname
elseif @isdefined $x

Check warning on line 763 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L762-L763

Added lines #L762 - L763 were not covered by tests
# Allow users to define a var in `structural_parameters` and set
# that as positional arg of subcomponents; it is useful for cases
# where it needs to be passed to multiple subcomponents.
$_varname = $x
end)))
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, varexpr, kwargs)
component_args!(a, arg, varexpr, kwargs)
end
Expr(:kw, x, y) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
if isnothing(index_name)
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
else
push!(varexpr.args,

Check warning on line 781 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L781

Added line #L781 was not covered by tests
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
end
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
_ => error("Could not parse $arg of component $a")
end
end
end

function _parse_components!(exprs, body, kwargs)
expr = Expr(:block)
model_name(name, range) = Symbol.(name, :_, collect(range))

Check warning on line 792 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L792

Added line #L792 was not covered by tests

function _parse_components!(body, kwargs)
local expr
varexpr = Expr(:block)
# push!(exprs, varexpr)
comps = Vector{Union{Symbol, Expr}}[]
comps = Vector{Union{Union{Expr, Symbol}, Expr}}[]
comp_names = []

for arg in body.args
arg isa LineNumberNode && continue
MLStyle.@match arg begin
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])
Base.remove_linenums!(body)
arg = body.args[end]

component_args!(a, b, expr, varexpr, kwargs)
MLStyle.@match arg begin
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
array_varexpr = Expr(:block)

arg.args[2] = b
push!(expr.args, arg)
push!(comp_names, a)
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
push!(comps, [a, b.args[1]])
end
push!(comp_names, :($a...))
push!(comps, [a, b.args[1], d])
b = deepcopy(b)

component_args!(a, b, array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
error("List comprehensions with conditional statements aren't supported.")

Check warning on line 816 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L816

Added line #L816 was not covered by tests
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
error("More than one index isn't supported while building component array")

Check warning on line 820 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L820

Added line #L820 was not covered by tests
end
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")

Check warning on line 824 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L824

Added line #L824 was not covered by tests
end
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
Base.remove_linenums!(b)
array_varexpr = Expr(:block)
push!(array_varexpr.args, b.args[1:(end - 1)]...)
push!(comp_names, :($a...))
push!(comps, [a, b.args[end].args[1], d])
b = deepcopy(b)

component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])

component_args!(a, b, varexpr, kwargs)

arg.args[2] = b
expr = :(@named $arg)
push!(comp_names, a)
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
push!(comps, [a, b.args[1]])
end
_ => error("Couldn't parse the component body: $arg")
end
_ => error("Couldn't parse the component body: $arg")
end

return comp_names, comps, expr, varexpr
end

function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
blk = Expr(:block)
push!(blk.args, varexpr)
push!(blk.args, :(@named begin
$(expr_vec.args...)
end))
push!(blk.args, expr_vec)
push!(blk.args, :($push!(systems, $(comp_names...))))
push!(ifexpr.args, blk)
end

function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
push!(ifexpr.args, condition)
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -836,7 +880,7 @@
push!(ifexpr.args, elseifexpr)
(comps...,)
else
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -861,25 +905,23 @@
Expr(:if, condition, x, y) => begin
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
end
Expr(:(=), a, b) => begin
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
:(begin
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
_ => begin

Check warning on line 909 in src/systems/model_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/model_parsing.jl#L909

Added line #L909 was not covered by tests
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
$arg
end),
kwargs)
push!(cs, comp_names...)
push!(dict[:components], comps...)
push!(exprs, varexpr, :(@named begin
$(expr_vec.args...)
end))
push!(exprs, varexpr, expr_vec)
end
_ => error("Couldn't parse the component body $compbody")
end
end
end

function _rename(compname, varname)
compname = Symbol(compname, :__, varname)
(compname, Symbol(:_, compname))
end

# Handle top level branching
Expand Down
69 changes: 69 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,72 @@ end
@named m = MyModel()
@variables x___(t)
@test isequal(x___, _b[])

@testset "Component array" begin
@mtkmodel SubComponent begin
@parameters begin
sc
end
end

@mtkmodel Component begin
@structural_parameters begin
N = 2
end
@components begin
comprehension = [SubComponent(sc = i) for i in 1:N]
written_out_for = for i in 1:N
sc = i + 1
SubComponent(; sc)
end
single_sub_component = SubComponent()
end
end

@named component = Component()
component = complete(component)

@test nameof.(ModelingToolkit.get_systems(component)) == [
:comprehension_1,
:comprehension_2,
:written_out_for_1,
:written_out_for_2,
:single_sub_component
]

@test getdefault(component.comprehension_1.sc) == 1
@test getdefault(component.comprehension_2.sc) == 2
@test getdefault(component.written_out_for_1.sc) == 2
@test getdefault(component.written_out_for_2.sc) == 3

@mtkmodel ConditionalComponent begin
@structural_parameters begin
N = 2
end
@components begin
if N == 2
if_comprehension = [SubComponent(sc = i) for i in 1:N]
elseif N == 3
elseif_comprehension = [SubComponent(sc = i) for i in 1:N]
else
else_comprehension = [SubComponent(sc = i) for i in 1:N]
end
end
end

@named if_component = ConditionalComponent()
@test nameof.(get_systems(if_component)) == [:if_comprehension_1, :if_comprehension_2]

@named elseif_component = ConditionalComponent(; N = 3)
@test nameof.(get_systems(elseif_component)) ==
[:elseif_comprehension_1, :elseif_comprehension_2, :elseif_comprehension_3]

@named else_component = ConditionalComponent(; N = 4)
@test nameof.(get_systems(else_component)) ==
[:else_comprehension_1, :else_comprehension_2,
:else_comprehension_3, :else_comprehension_4]
end

@testset "Parent module of Models" begin
@test parentmodule(MyMockModule.Ground) == MyMockModule
end
Loading