Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: simplify initialization systems with fully_determined=true if possible #3235

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,41 @@ end
###
### Structural check
###
function check_consistency(state::TransformationState, orig_inputs)

"""
$(TYPEDSIGNATURES)

Check if the `state` represents a singular system, and return the unmatched variables.
"""
function singular_check(state::TransformationState)
@unpack graph, var_to_diff = state.structure
fullvars = get_fullvars(state)
# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
return unassigned_var
end

"""
$(TYPEDSIGNATURES)

Check the consistency of `state`, given the inputs `orig_inputs`. If `nothrow == false`,
throws an error if the system is under-/over-determined or singular. In this case, if the
function returns it will return `true`. If `nothrow == true`, it will return `false`
instead of throwing an error. The singular case will print a warning.
"""
function check_consistency(state::TransformationState, orig_inputs; nothrow = false)
fullvars = get_fullvars(state)
neqs = n_concrete_eqs(state)
@unpack graph, var_to_diff = state.structure
Expand All @@ -72,6 +106,7 @@ function check_consistency(state::TransformationState, orig_inputs)
is_balanced = n_highest_vars == neqs

if neqs > 0 && !is_balanced
nothrow && return false
varwhitelist = var_to_diff .== nothing
var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned
# Just use `error_reporting` to do conditional
Expand All @@ -85,22 +120,12 @@ function check_consistency(state::TransformationState, orig_inputs)
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
end

# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
unassigned_var = singular_check(state)

if !isempty(unassigned_var) || !is_balanced
if nothrow
return false
end
io = IOBuffer()
Base.print_array(io, unassigned_var)
unassigned_var_str = String(take!(io))
Expand All @@ -110,7 +135,7 @@ function check_consistency(state::TransformationState, orig_inputs)
throw(InvalidSystemException(errmsg))
end

return nothing
return true
end

###
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3335,7 +3335,8 @@ function parse_variable(sys::AbstractSystem, str::AbstractString)
# I'd write a regex to validate `str`, but https://xkcd.com/1171/
str = strip(str)
derivative_level = 0
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) && endswith(str, ")")
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) &&
endswith(str, ")")
if cond1
derivative_level += 1
str = _string_view_inner(str, 2, 1)
Expand Down
15 changes: 14 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
check_length = true,
warn_initialize_determined = true,
initialization_eqs = [],
fully_determined = false,
fully_determined = nothing,
check_units = true,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
Expand All @@ -1313,6 +1313,19 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
end

ts = get_tearing_state(isys)
if warn_initialize_determined &&
(unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars))
errmsg = """
The initialization system is structurally singular. Guess values may \
significantly affect the initial values of the ODE. The problematic variables \
are $unassigned_vars.

Note that the identification of problematic variables is a best-effort heuristic.
"""
@warn errmsg
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])

# TODO: throw on uninitialized arrays
Expand Down
2 changes: 1 addition & 1 deletion src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ function process_SciMLProblem(
constructor, sys::AbstractSystem, u0map, pmap; build_initializeprob = true,
implicit_dae = false, t = nothing, guesses = AnyDict(),
warn_initialize_determined = true, initialization_eqs = [],
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing,
check_initialization_units = false, tofloat = true, use_union = false,
u0_constructor = identity, du0map = nothing, check_length = true,
symbolic_u0 = false, warn_cyclic_dependency = false,
Expand Down
9 changes: 7 additions & 2 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,11 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
dummy_derivative = true,
kwargs...)
check_consistency &= fully_determined
if fully_determined isa Bool
check_consistency &= fully_determined
else
check_consistency = true
end
has_io = io !== nothing
orig_inputs = Set()
if has_io
Expand All @@ -690,7 +694,8 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
end
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
if check_consistency
ModelingToolkit.check_consistency(state, orig_inputs)
fully_determined = ModelingToolkit.check_consistency(
state, orig_inputs; nothrow = fully_determined === nothing)
end
if fully_determined && dummy_derivative
sys = ModelingToolkit.dummy_derivative(
Expand Down
11 changes: 11 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,14 @@ end

@test_nowarn remake(prob, p = prob.p)
end

@testset "Singular initialization prints a warning" begin
@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)
@test_warn ["structurally singular", "initialization", "Guess", "heuristic"] ODEProblem(
pend, [x => 1, y => 0], (0.0, 1.5), [g => 1], guesses = [λ => 1])
end
Loading