From e75d06fb2f680ccda550f6ad37cc875497119736 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 14:35:16 +0530 Subject: [PATCH] fix: properly handle rational functions in HomotopyContinuation --- ext/MTKHomotopyContinuationExt.jl | 5 +- .../nonlinear/homotopy_continuation.jl | 63 +++++++++++++------ src/systems/nonlinear/nonlinearsystem.jl | 3 +- test/extensions/homotopy_continuation.jl | 31 ++++++++- 4 files changed, 79 insertions(+), 23 deletions(-) diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl index c4a090d9a8..8f17c05b18 100644 --- a/ext/MTKHomotopyContinuationExt.jl +++ b/ext/MTKHomotopyContinuationExt.jl @@ -101,7 +101,8 @@ function MTK.HomotopyContinuationProblem( return prob end -function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...) +function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; + fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...) if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") end @@ -109,7 +110,7 @@ function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; k if transformation isa MTK.NotPolynomialError return transformation end - result = MTK.transform_system(sys, transformation) + result = MTK.transform_system(sys, transformation; fraction_cancel_fn) if result isa MTK.NotPolynomialError return result end diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 03aeed1edf..00c6c7f1b0 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a `PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot be transformed. """ -function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation) +function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation; + fraction_cancel_fn = simplify_fractions) subrules = transformation.substitution_rules dvs = unknowns(sys) eqs = full_equations(sys) @@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf return NotPolynomialError( VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata) end - num, den = handle_rational_polynomials(t, new_dvs) + num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn) # make factors different elements, otherwise the nonzero factors artificially # inflate the error of the zero factor. if iscall(den) && operation(den) == * @@ -492,43 +493,67 @@ $(TYPEDSIGNATURES) Given a `x`, a polynomial in variables in `wrt` which may contain rational functions, express `x` as a single rational function with polynomial `num` and denominator `den`. Return `(num, den)`. + +Keyword arguments: +- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns + a simplified symbolic quantity with common factors in the numerator and denominator are + cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to + `nothing` to improve performance on large polynomials at the cost of avoiding non-trivial + cancellation. """ -function handle_rational_polynomials(x, wrt) +function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions) x = unwrap(x) symbolic_type(x) == NotSymbolic() && return x, 1 iscall(x) || return x, 1 contains_variable(x, wrt) || return x, 1 any(isequal(x), wrt) && return x, 1 - # simplify_fractions cancels out some common factors - # and expands (a / b)^c to a^c / b^c, so we only need - # to handle these cases - x = simplify_fractions(x) op = operation(x) args = arguments(x) if op == / # numerator and denominator are trivial num, den = args - # but also search for rational functions in numerator - n, d = handle_rational_polynomials(num, wrt) - num, den = n, den * d - elseif op == + + n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn) + n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn) + num, den = n1 * d2, d1 * n2 + elseif (op == +) || (op == -) num = 0 den = 1 - - # we don't need to do common denominator - # because we don't care about cases where denominator - # is zero. The expression is zero when all the numerators - # are zero. + if op == - + args[2] = -args[2] + end + for arg in args + n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn) + num = num * d + n * den + den *= d + end + elseif op == ^ + base, pow = args + num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn) + num ^= pow + den ^= pow + elseif op == * + num = 1 + den = 1 for arg in args - n, d = handle_rational_polynomials(arg, wrt) - num += n + n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn) + num *= n den *= d end else - return x, 1 + error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.") + end + + if fraction_cancel_fn !== nothing + expr = fraction_cancel_fn(num / den) + if iscall(expr) && operation(expr) == / + num, den = arguments(expr) + else + num, den = expr, 1 + end end + # if the denominator isn't a polynomial in `wrt`, better to not include it # to reduce the size of the gcd polynomial if !contains_variable(den, wrt) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 2c788688bd..2751f4f5cd 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -501,7 +501,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`") end if use_homotopy_continuation - prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...) + prob = safe_HomotopyContinuationProblem( + sys, u0map, parammap; check_length, kwargs...) if prob isa HomotopyContinuationProblem return prob end diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index 3f4a3fbc71..9e15ea857e 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -1,4 +1,5 @@ using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface +using SymbolicUtils import ModelingToolkit as MTK using LinearAlgebra using Test @@ -34,6 +35,8 @@ import HomotopyContinuation sol = solve(prob2; threading = false) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid)≈0.0 atol=1e-10 + + @test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem end struct Wrapper @@ -217,7 +220,17 @@ end @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)]) prob = HomotopyContinuationProblem(sys, []) @test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0) - @test_nowarn solve(prob; threading = false) + @test SciMLBase.successful_retcode(solve(prob; threading = false)) + end + + @testset "Rational function forced to common denominators" begin + @variables x = 1 + @mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x]) + prob = HomotopyContinuationProblem(sys, []) + @test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0) + sol = solve(prob; threading = false) + @test SciMLBase.successful_retcode(sol) + @test 1 / (1 + sol.u[1]) - sol.u[1]≈0.0 atol=1e-10 end end @@ -229,3 +242,19 @@ end @test sol[x] ≈ √2.0 @test sol[y] ≈ sin(√2.0) end + +@testset "`fraction_cancel_fn`" begin + @variables x = 1 + @named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) / + (x - 4)^3]) + sys = complete(sys) + + @testset "`simplify_fractions`" begin + prob = HomotopyContinuationProblem(sys, []) + @test prob.denominator([0.0], parameter_values(prob)) ≈ [4.0] + end + @testset "`nothing`" begin + prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing) + @test sort(prob.denominator([0.0], parameter_values(prob))) ≈ [2.0, 4.0^3] + end +end