Skip to content

Commit

Permalink
Merge pull request #2233 from SciML/myb/new_domain
Browse files Browse the repository at this point in the history
Add `domain_connect`
  • Loading branch information
YingboMa authored Sep 16, 2023
2 parents a1369f3 + e4b9235 commit 1344c1d
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 131 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ export SteadyStateProblem, SteadyStateProblemExpr
export JumpProblem, DiscreteProblem
export NonlinearSystem, OptimizationSystem, ConstraintsSystem
export alias_elimination, flatten
export connect, @connector, Connection, Flow, Stream, instream
export connect, domain_connect, @connector, Connection, Flow, Stream, instream
export @component, @mtkmodel
export isinput, isoutput, getbounds, hasbounds, isdisturbance, istunable, getdist, hasdist,
tunable_parameters, isirreducible, getdescription, hasdescription, isbinaryvar,
Expand Down
196 changes: 68 additions & 128 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
function domain_connect(sys1, sys2, syss...)
syss = (sys1, sys2, syss...)
length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!")
Equation(Connection(:domain), Connection(syss)) # the RHS are connected systems
end

function get_connection_type(s)
s = unwrap(s)
if istree(s) && operation(s) === getindex
Expand Down Expand Up @@ -260,23 +266,24 @@ end

function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
connectionsets = ConnectionSet[]
sys = generate_connection_set!(connectionsets, sys, find, replace)
domain_free_connectionsets = filter(connectionsets) do cset
!any(s -> is_domain_connector(s.sys.sys), cset.set)
end
domain_csets = ConnectionSet[]
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
csets = merge(connectionsets)
domain_csets = merge([csets; domain_csets], true)

sys, (merge(domain_free_connectionsets), connectionsets)
sys, (csets, domain_csets)
end

function generate_connection_set!(connectionsets, sys::AbstractSystem, find, replace,
namespace = nothing)
function generate_connection_set!(connectionsets, domain_csets,
sys::AbstractSystem, find, replace, namespace = nothing)
subsys = get_systems(sys)

isouter = generate_isouter(sys)
eqs′ = get_eqs(sys)
eqs = Equation[]

cts = [] # connections
domain_cts = [] # connections
extra_states = []
for eq in eqs′
lhs = eq.lhs
Expand All @@ -292,8 +299,14 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
else
if lhs isa Number || lhs isa Symbolic
push!(eqs, eq) # split connections and equations
elseif lhs isa Connection
if get_systems(lhs) === :domain
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
else
push!(cts, get_systems(rhs))
end
else
push!(cts, get_systems(rhs))
error("$eq is not a legal equation!")
end
end
end
Expand All @@ -302,6 +315,7 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
T = ConnectionElement
for s in subsys
isconnector(s) || continue
is_domain_connector(s) && continue
for v in states(s)
Flow === get_connection_type(v) || continue
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
Expand All @@ -316,24 +330,42 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
if !isempty(extra_states)
@set! sys.states = [get_states(sys); extra_states]
end
@set! sys.systems = map(s -> generate_connection_set!(connectionsets, s, find, replace,
@set! sys.systems = map(s -> generate_connection_set!(connectionsets, domain_csets, s,
find, replace,
renamespace(namespace, s)),
subsys)
@set! sys.eqs = eqs
end

function Base.merge(csets::AbstractVector{<:ConnectionSet})
function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
csets, merged = partial_merge(csets, allouter)
while merged
csets, merged = partial_merge(csets)
end
csets
end

function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
mcsets = ConnectionSet[]
ele2idx = Dict{ConnectionElement, Int}()
cacheset = Set{ConnectionElement}()
for cset in csets
merged = false
for (j, cset) in enumerate(csets)
if allouter
cset = ConnectionSet(map(cset.set) do e
@set! e.isouter = true
end)
end
idx = nothing
for e in cset.set
idx = get(ele2idx, e, nothing)
idx !== nothing && break
if idx !== nothing
merged = true
break
end
end
if idx === nothing
push!(mcsets, cset)
push!(mcsets, copy(cset))
for e in cset.set
ele2idx[e] = length(mcsets)
end
Expand All @@ -352,79 +384,7 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet})
empty!(cacheset)
end
end
mcsets
end

struct SystemDomainGraph{T, C <: AbstractVector{<:ConnectionSet}} <:
Graphs.AbstractGraph{Int}
ts::T
lineqs::BitSet
var2idx::Dict{Any, Int}
id2cset::Vector{NTuple{2, Int}}
cset2id::Vector{Vector{Int}}
csets::C
sys2id::Dict{Symbol, Int}
outne::Vector{Union{Nothing, Vector{Int}}}
end

Graphs.nv(g::SystemDomainGraph) = length(g.id2cset)
function Graphs.outneighbors(g::SystemDomainGraph, n::Int)
i, j = g.id2cset[n]
ids = copy(g.cset2id[i])
@unpack ts, lineqs, var2idx = g
@unpack fullvars, structure = ts
@unpack graph = structure
visited = BitSet(n)
for s in g.csets[i].set
s.sys.namespace === nothing && continue
sys = s.sys.sys
is_domain_connector(sys) && continue
vidx = get(var2idx, states(s.sys.namespace, states(sys, s.v)), 0)
iszero(vidx) && continue
ies = 𝑑neighbors(graph, vidx)
for ie in ies
ie in lineqs || continue
for iv in 𝑠neighbors(graph, ie)
iv == vidx && continue
fv = ts.fullvars[iv]
vtype = get_connection_type(fv)
vtype === Flow || continue
n′ = get(g.sys2id, getname(fv), nothing)
n′ === nothing && continue
n′ in visited && continue
push!(visited, n′)
append!(ids, g.cset2id[g.id2cset[n′][1]])
end
end
end
ids
end
function rooted_system_domain_graph!(ts, csets::AbstractVector{<:ConnectionSet})
id2cset = NTuple{2, Int}[]
cset2id = Vector{Int}[]
sys2id = Dict{Symbol, Int}()
roots = BitSet()
for (i, c) in enumerate(csets)
cset2id′ = Int[]
for (j, s) in enumerate(c.set)
ij = (i, j)
push!(id2cset, ij)
if !haskey(sys2id, nameof(s))
n = length(id2cset)
sys2id[nameof(s)] = n
else
n = sys2id[nameof(s)]
end
push!(cset2id′, n)
is_domain_connector(s.sys.sys) && push!(roots, n)
end
push!(cset2id, cset2id′)
end
outne = Vector{Union{Nothing, Vector{Int}}}(undef, length(id2cset))
mm = linear_subsys_adjmat!(ts)
lineqs = BitSet(mm.nzrows)
var2idx = Dict{Any, Int}(reverse(en) for en in enumerate(ts.fullvars))
SystemDomainGraph(ts, lineqs, var2idx, id2cset, cset2id, csets, sys2id, outne), roots
mcsets, merged
end

function generate_connection_equations_and_stream_connections(csets::AbstractVector{
Expand Down Expand Up @@ -458,48 +418,28 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec
end

function domain_defaults(sys, domain_csets)
csets = merge(domain_csets)
g, roots = rooted_system_domain_graph!(TearingState(sys), csets)
# a simple way to make `_g` bidirectional
simple_g = SimpleGraph(nv(g))
for v in 1:nv(g), n in neighbors(g, v)
add_edge!(simple_g, v => n)
end
domain_csets = []
root_ijs = Set(g.id2cset[r] for r in roots)
for r in roots
nh = neighborhood(simple_g, r, Inf)
sources_idxs = intersect(nh, roots)
# TODO: error reporting when length(sources_idxs) > 1
length(sources_idxs) > 1 && error()
i′, j′ = g.id2cset[r]
source = csets[i′].set[j′]
domain = source => []
push!(domain_csets, domain)
# get unique cset indices that `r` is (implicitly) connected to.
idxs = BitSet(g.id2cset[i][1] for i in nh)
for i in idxs
for (j, ele) in enumerate(csets[i].set)
(i, j) == (i′, j′) && continue
if (i, j) in root_ijs
error("Domain source $(nameof(source)) and $(nameof(ele)) are connected!")
end
push!(domain[2], ele)
end
end
end

def = Dict()
for (s, mods) in domain_csets
s_def = defaults(s.sys.sys)
for m in mods
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
for p in parameters(m.sys.namespace)
d_p = get(ns_s_def, p, nothing)
if d_p !== nothing
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
parameters(s.sys.sys,
d_p))
for c in domain_csets
cset = c.set
idx = findfirst(s -> is_domain_connector(s.sys.sys), cset)
idx === nothing && continue
s = cset[idx]
root = s.sys
s_def = defaults(root.sys)
for (j, m) in enumerate(cset)
if j == idx
continue
elseif is_domain_connector(m.sys.sys)
error("Domain sources $(nameof(root)) and $(nameof(m)) are connected!")
else
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
for p in parameters(m.sys.namespace)
d_p = get(ns_s_def, p, nothing)
if d_p !== nothing
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
parameters(s.sys.sys,
d_p))
end
end
end
end
Expand Down Expand Up @@ -656,7 +596,7 @@ function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSy
s_inners = (s for s in cset if !s.isouter)
s_outers = (s for s in cset if s.isouter)
for (q, oscq) in enumerate(s_outers)
sq += sum(s -> max(-states(s, fv), 0), s_inners)
sq += sum(s -> max(-states(s, fv), 0), s_inners, init = 0)
for (k, s) in enumerate(s_outers)
k == q && continue
f = states(s.sys.sys, fv)
Expand Down
Loading

0 comments on commit 1344c1d

Please sign in to comment.