Skip to content

Commit

Permalink
Merge pull request #3249 from AayushSabharwal/as/hc-everywhere
Browse files Browse the repository at this point in the history
feat: use `HomotopyContinuationProblem` in `NonlinearProblem` if possible
  • Loading branch information
ChrisRackauckas authored Dec 9, 2024
2 parents ab5747f + f428df4 commit 5306a7a
Show file tree
Hide file tree
Showing 5 changed files with 609 additions and 405 deletions.
361 changes: 27 additions & 334 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,217 +11,6 @@ using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache,

const MTK = ModelingToolkit

function contains_variable(x, wrt)
any(y -> occursin(y, x), wrt)
end

"""
Possible reasons why a term is not polynomial
"""
MTK.EnumX.@enumx NonPolynomialReason begin
NonIntegerExponent
ExponentContainsUnknowns
BaseNotPolynomial
UnrecognizedOperation
end

function display_reason(reason::NonPolynomialReason.T, sym)
if reason == NonPolynomialReason.NonIntegerExponent
pow = arguments(sym)[2]
"In $sym: Exponent $pow is not an integer"
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
pow = arguments(sym)[2]
"In $sym: Exponent $pow contains unknowns of the system"
elseif reason == NonPolynomialReason.BaseNotPolynomial
base = arguments(sym)[1]
"In $sym: Base $base is not a polynomial in the unknowns"
elseif reason == NonPolynomialReason.UnrecognizedOperation
op = operation(sym)
"""
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
`*, /, +, -, ^`.
"""
else
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
end
end

mutable struct PolynomialData
non_polynomial_terms::Vector{BasicSymbolic}
reasons::Vector{NonPolynomialReason.T}
has_parametric_exponent::Bool
end

PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)

abstract type PolynomialTransformationError <: Exception end

struct MultivarTerm <: PolynomialTransformationError
term::Any
vars::Any
end

function Base.showerror(io::IO, err::MultivarTerm)
println(io,
"Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).")
end

struct MultipleTermsOfSameVar <: PolynomialTransformationError
terms::Any
var::Any
end

function Base.showerror(io::IO, err::MultipleTermsOfSameVar)
println(io,
"Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).")
end

struct SymbolicSolveFailure <: PolynomialTransformationError
term::Any
var::Any
end

function Base.showerror(io::IO, err::SymbolicSolveFailure)
println(io,
"Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).")
end

struct NemoNotLoaded <: PolynomialTransformationError end

function Base.showerror(io::IO, err::NemoNotLoaded)
println(io,
"ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.")
end

struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError
vars::Any
end

function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly)
println(io,
"Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.")
end

struct NotPolynomialError <: Exception
transformation_err::Union{PolynomialTransformationError, Nothing}
eq::Vector{Equation}
data::Vector{PolynomialData}
end

function Base.showerror(io::IO, err::NotPolynomialError)
if err.transformation_err !== nothing
Base.showerror(io, err.transformation_err)
end
for (eq, data) in zip(err.eq, err.data)
if isempty(data.non_polynomial_terms)
continue
end
println(io,
"Equation $(eq) is not a polynomial in the unknowns for the following reasons:")
for (term, reason) in zip(data.non_polynomial_terms, data.reasons)
println(io, display_reason(reason, term))
end
end
end

function is_polynomial!(data, y, wrt)
process_polynomial!(data, y, wrt)
isempty(data.reasons)
end

"""
$(TYPEDSIGNATURES)
Return information about the polynmial `x` with respect to variables in `wrt`,
writing said information to `data`.
"""
function process_polynomial!(data::PolynomialData, x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return true
iscall(x) || return true
contains_variable(x, wrt) || return true
any(isequal(x), wrt) && return true

if operation(x) in (*, +, -, /)
# `map` because `all` will early exit, but we want to search
# through everything to get all the non-polynomial terms
return all(map(y -> is_polynomial!(data, y, wrt), arguments(x)))
end
if operation(x) == (^)
b, p = arguments(x)
is_pow_integer = symtype(p) <: Integer
if !is_pow_integer
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
end
if symbolic_type(p) != NotSymbolic()
data.has_parametric_exponent = true
end

exponent_has_unknowns = contains_variable(p, wrt)
if exponent_has_unknowns
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
end
base_polynomial = is_polynomial!(data, b, wrt)
return base_polynomial && !exponent_has_unknowns && is_pow_integer
end
push!(data.non_polynomial_terms, x)
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
return false
end

"""
$(TYPEDSIGNATURES)
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.
"""
function handle_rational_polynomials(x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
den *= d
end
else
return x, 1
end
# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
return num / den, 1
end
return num, den
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -289,12 +78,6 @@ end

SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p

struct PolynomialTransformationData
new_var::BasicSymbolic
term::BasicSymbolic
inv_term::Vector
end

"""
$(TYPEDSIGNATURES)
Expand All @@ -312,128 +95,37 @@ Keyword arguments:
All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
"""
function MTK.HomotopyContinuationProblem(
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
sys::NonlinearSystem, u0map, parammap = nothing; kwargs...)
prob = MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap; kwargs...)
prob isa MTK.HomotopyContinuationProblem || throw(prob)
return prob
end

function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end

dvs = unknowns(sys)
# we need to consider `full_equations` because observed also should be
# polynomials (if used in equations) and we don't know if observed is used
# in denominator.
# This is not the most efficient, and would be improved significantly with
# CSE/hashconsing.
eqs = full_equations(sys)

polydata = map(eqs) do eq
data = PolynomialData()
process_polynomial!(data, eq.lhs, dvs)
process_polynomial!(data, eq.rhs, dvs)
data
transformation = MTK.PolynomialTransformation(sys)
if transformation isa MTK.NotPolynomialError
return transformation
end

has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)

all_non_poly_terms = mapreduce(d -> d.non_polynomial_terms, vcat, polydata)
unique!(all_non_poly_terms)

var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}()

is_poly = true
transformation_err = nothing
for t in all_non_poly_terms
# if the term involves multiple unknowns, we can't invert it
dvs_in_term = map(x -> occursin(x, t), dvs)
if count(dvs_in_term) > 1
transformation_err = MultivarTerm(t, dvs[dvs_in_term])
is_poly = false
break
end
# we already have a substitution solving for `var`
var = dvs[findfirst(dvs_in_term)]
if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t)
transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var)
is_poly = false
break
end
# we want to solve `term - new_var` for `var`
new_var = gensym(Symbol(var))
new_var = unwrap(only(@variables $new_var))
invterm = Symbolics.ia_solve(
t - new_var, var; complex_roots = false, periodic_roots = false, warns = false)
# if we can't invert it, quit
if invterm === nothing || isempty(invterm)
transformation_err = SymbolicSolveFailure(t, var)
is_poly = false
break
end
# `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
# this just evaluates the constant expressions
invterm = Symbolics.substitute.(invterm, (Dict(),))
# RootsOf implies Symbolics couldn't solve the inner polynomial because
# `Nemo` wasn't loaded.
if any(x -> MTK.iscall(x) && MTK.operation(x) == Symbolics.RootsOf, invterm)
transformation_err = NemoNotLoaded()
is_poly = false
break
end
var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm)
end

if !is_poly
throw(NotPolynomialError(transformation_err, eqs, polydata))
end

subrules = Dict()
combinations = Vector[]
new_dvs = []
for x in dvs
if haskey(var_to_nonpoly, x)
_data = var_to_nonpoly[x]
subrules[_data.term] = _data.new_var
push!(combinations, _data.inv_term)
push!(new_dvs, _data.new_var)
else
push!(combinations, [x])
push!(new_dvs, x)
end
end
all_solutions = collect.(collect(Iterators.product(combinations...)))

denoms = []
eqs2 = map(eqs) do eq
t = eq.rhs - eq.lhs
t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs))
# the substituted variable occurs outside the substituted term
poly_and_nonpoly = map(dvs) do x
haskey(var_to_nonpoly, x) && occursin(x, t)
end
if any(poly_and_nonpoly)
throw(NotPolynomialError(
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata))
end

num, den = handle_rational_polynomials(t, new_dvs)
# make factors different elements, otherwise the nonzero factors artificially
# inflate the error of the zero factor.
if iscall(den) && operation(den) == *
for arg in arguments(den)
# ignore constant factors
symbolic_type(arg) == NotSymbolic() && continue
push!(denoms, abs(arg))
end
elseif symbolic_type(den) != NotSymbolic()
push!(denoms, abs(den))
end
return 0 ~ num
result = MTK.transform_system(sys, transformation)
if result isa MTK.NotPolynomialError
return result
end
MTK.HomotopyContinuationProblem(sys, transformation, result, u0map, parammap; kwargs...)
end

sys2 = MTK.@set sys.eqs = eqs2
MTK.@set! sys2.unknowns = new_dvs
# remove observed equations to avoid adding them in codegen
MTK.@set! sys2.observed = Equation[]
MTK.@set! sys2.substitutions = nothing
function MTK.HomotopyContinuationProblem(
sys::MTK.NonlinearSystem, transformation::MTK.PolynomialTransformation,
result::MTK.PolynomialTransformationResult, u0map,
parammap = nothing; eval_expression = false,
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
sys2 = result.sys
denoms = result.denominators
polydata = transformation.polydata
new_dvs = transformation.new_dvs
all_solutions = transformation.all_solutions

_, u0, p = MTK.process_SciMLProblem(
MTK.EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
Expand All @@ -443,10 +135,11 @@ function MTK.HomotopyContinuationProblem(
unpack_solution = MTK.build_explicit_observed_function(sys2, all_solutions)

hvars = symbolics_to_hc.(new_dvs)
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(new_dvs))

obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)

has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)
if has_parametric_exponents
if warn_parametric_exponent
@warn """
Expand Down
Loading

0 comments on commit 5306a7a

Please sign in to comment.