diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index 1168d1ad24..a72ea7ef31 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -733,6 +733,8 @@ end Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b) Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a) +# This definition is required for `induced_subgraph` to work +(::Type{<:DiCMOBiGraph})(n::Integer) = SimpleDiGraph(n) # Condensation Graphs abstract type AbstractCondensationGraph <: AbstractGraph{Int} end diff --git a/src/structural_transformation/bipartite_tearing/modia_tearing.jl b/src/structural_transformation/bipartite_tearing/modia_tearing.jl index e8304752e8..f8be5ccd7b 100644 --- a/src/structural_transformation/bipartite_tearing/modia_tearing.jl +++ b/src/structural_transformation/bipartite_tearing/modia_tearing.jl @@ -83,7 +83,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, max(length(var_eq_matching), maximum(x -> x isa Int ? x : 0, var_eq_matching))) full_var_eq_matching = copy(var_eq_matching) - var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching) + var_sccs = find_var_sccs(graph, var_eq_matching) vargraph = DiCMOBiGraph{true}(graph) ict = IncrementalCycleTracker(vargraph; dir = :in) @@ -111,5 +111,5 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, empty!(ieqs) empty!(filtered_vars) end - return var_eq_matching, full_var_eq_matching + return var_eq_matching, full_var_eq_matching, var_sccs end diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index bae8816b50..aef149dd93 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -299,59 +299,60 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja (n_dummys = length(dummy_derivatives)) @warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)." end - dummy_derivatives_set = BitSet(dummy_derivatives) - is_not_present_non_rec = let graph = graph - v -> isempty(đť‘‘neighbors(graph, v)) + ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives)) + if log + ret + else + ret[1] end +end - is_not_present = let var_to_diff = var_to_diff - v -> while true - # if a higher derivative is present, then it's present - is_not_present_non_rec(v) || return false - v = var_to_diff[v] - v === nothing && return true - end +function is_present(structure, v)::Bool + @unpack var_to_diff, graph = structure + while true + # if a higher derivative is present, then it's present + isempty(đť‘‘neighbors(graph, v)) || return true + v = var_to_diff[v] + v === nothing && return false end +end - # Derivatives that are either in the dummy derivatives set or ended up not - # participating in the system at all are not considered differential - is_some_diff = let dummy_derivatives_set = dummy_derivatives_set - v -> !(v in dummy_derivatives_set) && !is_not_present(v) - end +# Derivatives that are either in the dummy derivatives set or ended up not +# participating in the system at all are not considered differential +function is_some_diff(structure, dummy_derivatives, v)::Bool + !(v in dummy_derivatives) && is_present(structure, v) +end - # We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with - # actually differentiated variables. - isdiffed = let diff_to_var = diff_to_var - v -> diff_to_var[v] !== nothing && is_some_diff(v) - end +# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with +# actually differentiated variables. +function isdiffed((structure, dummy_derivatives), v)::Bool + @unpack var_to_diff, graph = structure + diff_to_var = invview(var_to_diff) + diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v) +end +function tearing_with_dummy_derivatives(structure, dummy_derivatives) + @unpack var_to_diff = structure # We can eliminate variables that are not a selected state (differential # variables). Selected states are differentiated variables that are not # dummy derivatives. - can_eliminate = let var_to_diff = var_to_diff - v -> begin - dv = var_to_diff[v] - dv === nothing && return true - is_some_diff(dv) || return true - return false + can_eliminate = falses(length(var_to_diff)) + for (v, dv) in enumerate(var_to_diff) + dv = var_to_diff[v] + if dv === nothing || !is_some_diff(structure, dummy_derivatives, dv) + can_eliminate[v] = true end end - - var_eq_matching, full_var_eq_matching = tear_graph_modia(structure, isdiffed, + var_eq_matching, full_var_eq_matching, var_sccs = tear_graph_modia(structure, + Base.Fix1(isdiffed, (structure, dummy_derivatives)), Union{Unassigned, SelectedState}; - varfilter = can_eliminate) + varfilter = Base.Fix1(getindex, can_eliminate)) for v in eachindex(var_eq_matching) - is_not_present(v) && continue + is_present(structure, v) || continue dv = var_to_diff[v] - (dv === nothing || !is_some_diff(dv)) && continue + (dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)) && continue var_eq_matching[v] = SelectedState() end - - if log - candidates = can_eliminate.(1:ndsts(graph)) - return var_eq_matching, full_var_eq_matching, candidates - else - return var_eq_matching - end + return var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate end