Skip to content

Commit

Permalink
fix: properly handle rational functions in HomotopyContinuation
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 12, 2024
1 parent 552b039 commit e75d06f
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 23 deletions.
5 changes: 3 additions & 2 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ function MTK.HomotopyContinuationProblem(
return prob
end

function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing;
fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end
transformation = MTK.PolynomialTransformation(sys)
if transformation isa MTK.NotPolynomialError
return transformation
end
result = MTK.transform_system(sys, transformation)
result = MTK.transform_system(sys, transformation; fraction_cancel_fn)
if result isa MTK.NotPolynomialError
return result
end
Expand Down
63 changes: 44 additions & 19 deletions src/systems/nonlinear/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a
`PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot
be transformed.
"""
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation)
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation;
fraction_cancel_fn = simplify_fractions)
subrules = transformation.substitution_rules
dvs = unknowns(sys)
eqs = full_equations(sys)
Expand All @@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
return NotPolynomialError(
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)
end
num, den = handle_rational_polynomials(t, new_dvs)
num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn)
# make factors different elements, otherwise the nonzero factors artificially
# inflate the error of the zero factor.
if iscall(den) && operation(den) == *
Expand Down Expand Up @@ -492,43 +493,67 @@ $(TYPEDSIGNATURES)
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.
Keyword arguments:
- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns
a simplified symbolic quantity with common factors in the numerator and denominator are
cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to
`nothing` to improve performance on large polynomials at the cost of avoiding non-trivial
cancellation.
"""
function handle_rational_polynomials(x, wrt)
function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn)
n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn)
num, den = n1 * d2, d1 * n2
elseif (op == +) || (op == -)
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
if op == -
args[2] = -args[2]
end
for arg in args
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num = num * d + n * den
den *= d
end
elseif op == ^
base, pow = args
num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn)
num ^= pow
den ^= pow
elseif op == *
num = 1
den = 1
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num *= n
den *= d
end
else
return x, 1
error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.")
end

if fraction_cancel_fn !== nothing
expr = fraction_cancel_fn(num / den)
if iscall(expr) && operation(expr) == /
num, den = arguments(expr)
else
num, den = expr, 1
end
end

# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
end
if use_homotopy_continuation
prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...)
prob = safe_HomotopyContinuationProblem(
sys, u0map, parammap; check_length, kwargs...)
if prob isa HomotopyContinuationProblem
return prob
end
Expand Down
31 changes: 30 additions & 1 deletion test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface
using SymbolicUtils
import ModelingToolkit as MTK
using LinearAlgebra
using Test
Expand Down Expand Up @@ -34,6 +35,8 @@ import HomotopyContinuation
sol = solve(prob2; threading = false)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid)0.0 atol=1e-10

@test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem
end

struct Wrapper
Expand Down Expand Up @@ -217,7 +220,17 @@ end
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
@test_nowarn solve(prob; threading = false)
@test SciMLBase.successful_retcode(solve(prob; threading = false))
end

@testset "Rational function forced to common denominators" begin
@variables x = 1
@mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0)
sol = solve(prob; threading = false)
@test SciMLBase.successful_retcode(sol)
@test 1 / (1 + sol.u[1]) - sol.u[1]0.0 atol=1e-10
end
end

Expand All @@ -229,3 +242,19 @@ end
@test sol[x] 2.0
@test sol[y] sin(2.0)
end

@testset "`fraction_cancel_fn`" begin
@variables x = 1
@named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) /
(x - 4)^3])
sys = complete(sys)

@testset "`simplify_fractions`" begin
prob = HomotopyContinuationProblem(sys, [])
@test prob.denominator([0.0], parameter_values(prob)) [4.0]
end
@testset "`nothing`" begin
prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing)
@test sort(prob.denominator([0.0], parameter_values(prob))) [2.0, 4.0^3]
end
end

0 comments on commit e75d06f

Please sign in to comment.