Skip to content

Commit

Permalink
feat: allow simplifying DAEs to index zero
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 27, 2024
1 parent eda23d4 commit b77ed41
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 13 deletions.
28 changes: 19 additions & 9 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
end

"""
computed_highest_diff_variables(structure)
computed_highest_diff_variables(structure; whitelisted_vars = ())
Computes which variables are the "highest-differentiated" for purposes of
pantelides. Ordinarily this is relatively straightforward. However, in our
Expand All @@ -83,12 +83,18 @@ case, there is one complicating condition:
This function takes care of these complications are returns a boolean array
for every variable, indicating whether it is considered "highest-differentiated".
For each index `i` in `whitelisted_vars`, the `i`th variable is included if it
is the highest differentiated variable even if it doesn't appear in the system.
"""
function computed_highest_diff_variables(structure)
function computed_highest_diff_variables(structure; whitelisted_vars = ())
@unpack graph, var_to_diff = structure

nvars = length(var_to_diff)
varwhitelist = falses(nvars)
for i in whitelisted_vars
varwhitelist[i] = true
end
for var in 1:nvars
if var_to_diff[var] === nothing && !varwhitelist[var]
# This variable is structurally highest-differentiated, but may not actually appear in the
Expand Down Expand Up @@ -125,7 +131,7 @@ end
Perform Pantelides algorithm.
"""
function pantelides!(
state::TransformationState; finalize = true, maxiters = 8000, kwargs...)
state::TransformationState; finalize = true, maxiters = 8000, whitelisted_vars = (), kwargs...)
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
neqs = nsrcs(graph)
nvars = nv(var_to_diff)
Expand All @@ -137,8 +143,7 @@ function pantelides!(
eq -> !isempty(𝑠neighbors(graph, eq)) && eq_to_diff[eq] === nothing,
1:neqs′)

varwhitelist = computed_highest_diff_variables(state.structure)

varwhitelist = computed_highest_diff_variables(state.structure; whitelisted_vars)
if nnonemptyeqs > count(varwhitelist)
throw(InvalidSystemException("System is structurally singular"))
end
Expand Down Expand Up @@ -206,14 +211,19 @@ function pantelides!(
end

"""
dae_index_lowering(sys::ODESystem; kwargs...) -> ODESystem
dae_index_lowering(sys::ODESystem; to_index_zero = false, kwargs...) -> ODESystem
Perform the Pantelides algorithm to transform a higher index DAE to an index 1
DAE. `kwargs` are forwarded to [`pantelides!`](@ref). End users are encouraged to call [`structural_simplify`](@ref)
instead, which calls this function internally.
instead, which calls this function internally. If `to_index_zero` is true, the DAE will be reduced to an index 1 DAE.
"""
function dae_index_lowering(sys::ODESystem; kwargs...)
function dae_index_lowering(sys::ODESystem; to_index_zero = false, kwargs...)
state = TearingState(sys)
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
if to_index_zero
newvars = ModelingToolkit.add_missing_differentials!(state)
else
newvars = ()
end
var_eq_matching = pantelides!(state; finalize = false, whitelisted_vars = newvars, kwargs...)
return invalidate_cache!(pantelides_reassemble(state, var_eq_matching))
end
10 changes: 8 additions & 2 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,13 @@ Perform index reduction and use the dummy derivative technique to ensure that
the system is balanced.
"""
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
mm = nothing, cse_hack = true, array_hack = true, to_index_zero = false, kwargs...)
if to_index_zero
newvars = ModelingToolkit.add_missing_differentials!(state)
else
newvars = ()
end

jac = let state = state
(eqs, vars) -> begin
symeqs = EquationsView(state)[eqs]
Expand All @@ -834,7 +840,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
p
end
end
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, whitelisted_vars = newvars,
kwargs...)
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
end
41 changes: 39 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,37 @@ function shift_discrete_system(ts::TearingState)
return ts
end

"""
$(TYPEDSIGNATURES)
For each variable in `ts.fullvars` which does not have a derivative in `ts.fullvars`
and is not the derivative of a variable in `ts.fullvars`, add its derivative to `ts`.
Returns the indexes of added differential variables.
"""
function add_missing_differentials!(ts::TearingState)
sys = ts.sys
D = Differential(get_iv(sys))
newvars = Int[]
for (i, v) in enumerate(ts.fullvars)
# ignore variables that have a derivative...
ts.structure.var_to_diff[i] === nothing || continue
# or are the derivative
invview(ts.structure.var_to_diff)[i] === nothing || continue
# add to fullvars
push!(ts.fullvars, D(v))
push!(newvars, length(ts.fullvars))
# update diffgraph
add_vertex!(ts.structure.var_to_diff)
add_edge!(ts.structure.var_to_diff, i, length(ts.fullvars))
# update bipartite graphs
add_vertex!(ts.structure.graph, DST)
if ts.structure.solvable_graph !== nothing
add_vertex!(ts.structure.solvable_graph, DST)
end
end
return newvars
end

using .BipartiteGraphs: Label, BipartiteAdjacencyList
struct SystemStructurePrintMatrix <:
AbstractMatrix{Union{Label, BipartiteAdjacencyList}}
Expand Down Expand Up @@ -676,6 +707,7 @@ end
function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
dummy_derivative = true,
to_index_zero = false,
kwargs...)
if fully_determined isa Bool
check_consistency &= fully_determined
Expand All @@ -699,9 +731,14 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
end
if fully_determined && dummy_derivative
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, kwargs...)
sys, state; simplify, mm, check_consistency, to_index_zero, kwargs...)
elseif fully_determined
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
if to_index_zero
newvars = add_missing_differentials!(state)
else
newvars = ()
end
var_eq_matching = pantelides!(state; finalize = false, whitelisted_vars = newvars, kwargs...)
sys = pantelides_reassemble(state, var_eq_matching)
state = TearingState(sys)
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
Expand Down

0 comments on commit b77ed41

Please sign in to comment.