Skip to content

Commit

Permalink
Merge pull request #3235 from AayushSabharwal/as/init-fully-determined
Browse files Browse the repository at this point in the history
feat: simplify initialization systems with `fully_determined=true` if possible
  • Loading branch information
ChrisRackauckas authored Nov 26, 2024
2 parents 42d4d63 + 2b7a6b6 commit f120310
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 21 deletions.
57 changes: 41 additions & 16 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,41 @@ end
###
### Structural check
###
function check_consistency(state::TransformationState, orig_inputs)

"""
$(TYPEDSIGNATURES)
Check if the `state` represents a singular system, and return the unmatched variables.
"""
function singular_check(state::TransformationState)
@unpack graph, var_to_diff = state.structure
fullvars = get_fullvars(state)
# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
return unassigned_var
end

"""
$(TYPEDSIGNATURES)
Check the consistency of `state`, given the inputs `orig_inputs`. If `nothrow == false`,
throws an error if the system is under-/over-determined or singular. In this case, if the
function returns it will return `true`. If `nothrow == true`, it will return `false`
instead of throwing an error. The singular case will print a warning.
"""
function check_consistency(state::TransformationState, orig_inputs; nothrow = false)
fullvars = get_fullvars(state)
neqs = n_concrete_eqs(state)
@unpack graph, var_to_diff = state.structure
Expand All @@ -72,6 +106,7 @@ function check_consistency(state::TransformationState, orig_inputs)
is_balanced = n_highest_vars == neqs

if neqs > 0 && !is_balanced
nothrow && return false
varwhitelist = var_to_diff .== nothing
var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned
# Just use `error_reporting` to do conditional
Expand All @@ -85,22 +120,12 @@ function check_consistency(state::TransformationState, orig_inputs)
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
end

# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
unassigned_var = singular_check(state)

if !isempty(unassigned_var) || !is_balanced
if nothrow
return false
end
io = IOBuffer()
Base.print_array(io, unassigned_var)
unassigned_var_str = String(take!(io))
Expand All @@ -110,7 +135,7 @@ function check_consistency(state::TransformationState, orig_inputs)
throw(InvalidSystemException(errmsg))
end

return nothing
return true
end

###
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3335,7 +3335,8 @@ function parse_variable(sys::AbstractSystem, str::AbstractString)
# I'd write a regex to validate `str`, but https://xkcd.com/1171/
str = strip(str)
derivative_level = 0
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) && endswith(str, ")")
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) &&
endswith(str, ")")
if cond1
derivative_level += 1
str = _string_view_inner(str, 2, 1)
Expand Down
15 changes: 14 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
check_length = true,
warn_initialize_determined = true,
initialization_eqs = [],
fully_determined = false,
fully_determined = nothing,
check_units = true,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
Expand All @@ -1313,6 +1313,19 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
end

ts = get_tearing_state(isys)
if warn_initialize_determined &&
(unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars))
errmsg = """
The initialization system is structurally singular. Guess values may \
significantly affect the initial values of the ODE. The problematic variables \
are $unassigned_vars.
Note that the identification of problematic variables is a best-effort heuristic.
"""
@warn errmsg
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])

# TODO: throw on uninitialized arrays
Expand Down
2 changes: 1 addition & 1 deletion src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ function process_SciMLProblem(
constructor, sys::AbstractSystem, u0map, pmap; build_initializeprob = true,
implicit_dae = false, t = nothing, guesses = AnyDict(),
warn_initialize_determined = true, initialization_eqs = [],
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing,
check_initialization_units = false, tofloat = true, use_union = false,
u0_constructor = identity, du0map = nothing, check_length = true,
symbolic_u0 = false, warn_cyclic_dependency = false,
Expand Down
9 changes: 7 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,11 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
dummy_derivative = true,
kwargs...)
check_consistency &= fully_determined
if fully_determined isa Bool
check_consistency &= fully_determined
else
check_consistency = true
end
has_io = io !== nothing
orig_inputs = Set()
if has_io
Expand All @@ -690,7 +694,8 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
end
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
if check_consistency
ModelingToolkit.check_consistency(state, orig_inputs)
fully_determined = ModelingToolkit.check_consistency(
state, orig_inputs; nothrow = fully_determined === nothing)
end
if fully_determined && dummy_derivative
sys = ModelingToolkit.dummy_derivative(
Expand Down
11 changes: 11 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,14 @@ end

@test_nowarn remake(prob, p = prob.p)
end

@testset "Singular initialization prints a warning" begin
@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)
@test_warn ["structurally singular", "initialization", "Guess", "heuristic"] ODEProblem(
pend, [x => 1, y => 0], (0.0, 1.5), [g => 1], guesses ==> 1])
end

0 comments on commit f120310

Please sign in to comment.