From 306ff2dfd81e3ec7dbb005493da8a9be67f3c7d1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:22:18 +0530 Subject: [PATCH 01/13] feat: add `get_abstol` and `get_reltol` interface methods --- lib/NonlinearSolveBase/src/abstract_types.jl | 9 +++++++++ lib/NonlinearSolveBase/src/forward_diff.jl | 7 +++++++ lib/NonlinearSolveBase/src/polyalg.jl | 7 +++++++ lib/NonlinearSolveBase/src/solve.jl | 7 +++++++ lib/NonlinearSolveBase/src/termination_conditions.jl | 3 +++ lib/NonlinearSolveQuasiNewton/src/solve.jl | 7 +++++++ 6 files changed, 40 insertions(+) diff --git a/lib/NonlinearSolveBase/src/abstract_types.jl b/lib/NonlinearSolveBase/src/abstract_types.jl index 6829e19c3..00b4c127b 100644 --- a/lib/NonlinearSolveBase/src/abstract_types.jl +++ b/lib/NonlinearSolveBase/src/abstract_types.jl @@ -259,6 +259,8 @@ Abstract Type for all NonlinearSolveBase Caches. `u0` and any additional keyword arguments. - `SciMLBase.isinplace(cache)`: whether or not the solver is inplace. - `CommonSolve.step!(cache; kwargs...)`: See [`CommonSolve.step!`](@ref) for more details. + - `get_abstol(cache)`: get the `abstol` provided to the cache. + - `get_reltol(cache)`: get the `reltol` provided to the cache. Additionally implements `SymbolicIndexingInterface` interface Functions. @@ -304,6 +306,13 @@ end SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob) +function get_abstol(cache::AbstractNonlinearSolveCache) + get_abstol(cache.termination_cache) +end +function get_reltol(cache::AbstractNonlinearSolveCache) + get_reltol(cache.termination_cache) +end + ## SII Interface SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob) diff --git a/lib/NonlinearSolveBase/src/forward_diff.jl b/lib/NonlinearSolveBase/src/forward_diff.jl index a588aa52d..e780bf554 100644 --- a/lib/NonlinearSolveBase/src/forward_diff.jl +++ b/lib/NonlinearSolveBase/src/forward_diff.jl @@ -6,3 +6,10 @@ values_p partials_p end + +function NonlinearSolveBase.get_abstol(cache::NonlinearSolveForwardDiffCache) + NonlinearSolveBase.get_abstol(cache.cache) +end +function NonlinearSolveBase.get_reltol(cache::NonlinearSolveForwardDiffCache) + NonlinearSolveBase.get_reltol(cache.cache) +end diff --git a/lib/NonlinearSolveBase/src/polyalg.jl b/lib/NonlinearSolveBase/src/polyalg.jl index c2101af0e..ea873064c 100644 --- a/lib/NonlinearSolveBase/src/polyalg.jl +++ b/lib/NonlinearSolveBase/src/polyalg.jl @@ -61,6 +61,13 @@ end alias_u0::Bool end +function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache) + NonlinearSolveBase.get_abstol(cache.caches[cache.current]) +end +function NonlinearSolveBase.get_reltol(cache::NonlinearSolvePolyAlgorithmCache) + NonlinearSolveBase.get_reltol(cache.caches[cache.current]) +end + function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache) return cache.caches[cache.current] end diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 66ad6a0e4..b5d94af8b 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -248,6 +248,13 @@ end kwargs::Any end +function get_abstol(cache::NonlinearSolveNoInitCache) + get(cache.kwargs, :abstol, get_tolerance(nothing, eltype(cache.prob.u0))) +end +function get_reltol(cache::NonlinearSolveNoInitCache) + get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0))) +end + get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) function SciMLBase.reinit!( diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index cca9134d1..e6ab4a579 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -23,6 +23,9 @@ const AbsNormModes = Union{ u_diff_cache::uType end +get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol +get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol + function update_u!!(cache::NonlinearTerminationModeCache, u) cache.u === nothing && return if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u) diff --git a/lib/NonlinearSolveQuasiNewton/src/solve.jl b/lib/NonlinearSolveQuasiNewton/src/solve.jl index c52a425ae..3e413b2f6 100644 --- a/lib/NonlinearSolveQuasiNewton/src/solve.jl +++ b/lib/NonlinearSolveQuasiNewton/src/solve.jl @@ -95,6 +95,13 @@ end kwargs end +function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache) + NonlinearSolveBase.get_abstol(cache.termination_cache) +end +function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache) + NonlinearSolveBase.get_reltol(cache.termination_cache) +end + function InternalAPI.reinit_self!( cache::QuasiNewtonCache, args...; p = cache.p, u0 = cache.u, alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs... From 1cbb4c48d2341dbec0c09c8196f0c94878162c9e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:22:44 +0530 Subject: [PATCH 02/13] feat: add `initialize_cache!` --- .../src/NonlinearSolveBase.jl | 1 + lib/NonlinearSolveBase/src/initialization.jl | 60 +++++++++++++++++++ src/default.jl | 7 +++ 3 files changed, 68 insertions(+) create mode 100644 lib/NonlinearSolveBase/src/initialization.jl diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index f45ba9242..9087a5c98 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -55,6 +55,7 @@ include("descent/damped_newton.jl") include("descent/dogleg.jl") include("descent/geodesic_acceleration.jl") +include("initialization.jl") include("solve.jl") include("forward_diff.jl") diff --git a/lib/NonlinearSolveBase/src/initialization.jl b/lib/NonlinearSolveBase/src/initialization.jl new file mode 100644 index 000000000..e3612e2cb --- /dev/null +++ b/lib/NonlinearSolveBase/src/initialization.jl @@ -0,0 +1,60 @@ +struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end + +function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob) + _run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache))) +end + +function _run_initialization!( + cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}}) + if SciMLBase.has_initialization_data(prob.f) && + prob.f.initialization_data isa SciMLBase.OverrideInitData + return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace) + end + return cache, true +end + +function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob, + isinplace::Union{Val{true}, Val{false}}) + if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff) + autodiff = cache.alg.autodiff + else + autodiff = ADTypes.AutoForwardDiff() + end + alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff) + if alg === nothing && cache isa AbstractNonlinearSolveCache + alg = cache.alg + end + u0, p, success = SciMLBase.get_initial_values( + prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg, + abstol = get_abstol(cache), reltol = get_reltol(cache)) + cache = update_initial_values!(cache, u0, p) + if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success + cache.retcode = ReturnCode.InitialFailure + end + + return cache, success +end + +function get_abstol(prob::AbstractNonlinearProblem) + get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob))) +end +function get_reltol(prob::AbstractNonlinearProblem) + get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob))) +end + +initialization_alg(initprob, autodiff) = nothing + +function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p) + InternalAPI.reinit!(cache; u0, p) + cache.prob = SciMLBase.remake(cache.prob; u0, p) + return cache +end + +function update_initial_values!(prob::AbstractNonlinearProblem, u0, p) + return SciMLBase.remake(prob; u0, p) +end + +function _run_initialization!( + cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace) + return cache, true +end diff --git a/src/default.jl b/src/default.jl index 6021a98b1..734305ef6 100644 --- a/src/default.jl +++ b/src/default.jl @@ -50,3 +50,10 @@ function SciMLBase.__solve( prob, FastShortcutNLLSPolyalg(eltype(prob.u0)), args...; kwargs... ) end + +function NonlinearSolveBase.initialization_alg(::AbstractNonlinearProblem, autodiff) + FastShortcutNonlinearPolyalg(; autodiff) +end +function NonlinearSolveBase.initialization_alg(::NonlinearLeastSquaresProblem, autodiff) + FastShortcutNLLSPolyalg(; autodiff) +end From 63384cdb90116cf044a6130a1ec37dff4934f30f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:23:14 +0530 Subject: [PATCH 03/13] feat: implement initialization for polyalg cache --- lib/NonlinearSolveBase/src/polyalg.jl | 22 ++++++++++++++++---- lib/NonlinearSolveBase/src/solve.jl | 30 +++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/lib/NonlinearSolveBase/src/polyalg.jl b/lib/NonlinearSolveBase/src/polyalg.jl index ea873064c..9e6510093 100644 --- a/lib/NonlinearSolveBase/src/polyalg.jl +++ b/lib/NonlinearSolveBase/src/polyalg.jl @@ -59,6 +59,16 @@ end u0 u0_aliased alias_u0::Bool + + initializealg +end + +function update_initial_values!(cache::NonlinearSolvePolyAlgorithmCache, u0, p) + foreach(cache.caches) do subcache + update_initial_values!(subcache, u0, p) + end + cache.prob = SciMLBase.remake(cache.prob; u0, p) + return cache end function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache) @@ -104,7 +114,8 @@ end function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...; stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000, - internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs... + internalnorm = L2_NORM, alias_u0 = false, verbose = true, + initializealg = NonlinearSolveDefaultInit(), kwargs... ) if alias_u0 && !ArrayInterface.ismutable(prob.u0) verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ @@ -116,18 +127,21 @@ function SciMLBase.__init( u0_aliased = alias_u0 ? copy(u0) : u0 alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased)) - return NonlinearSolvePolyAlgorithmCache( + cache = NonlinearSolvePolyAlgorithmCache( alg.static_length, prob, map(alg.algs) do solver SciMLBase.__init( prob, solver, args...; - stats, maxtime, internalnorm, alias_u0, verbose, kwargs... + stats, maxtime, internalnorm, alias_u0, verbose, + initializealg = SciMLBase.NoInit(), kwargs... ) end, alg, -1, alg.start_index, 0, stats, 0.0, maxtime, ReturnCode.Default, false, maxiters, internalnorm, - u0, u0_aliased, alias_u0 + u0, u0_aliased, alias_u0, initializealg ) + run_initialization!(cache) + return cache end @generated function InternalAPI.step!( diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index b5d94af8b..0784b9c92 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -40,6 +40,17 @@ end sol_syms = [gensym("sol") for i in 1:N] u_result_syms = [gensym("u_result") for i in 1:N] + push!(calls, + quote + if cache.retcode == ReturnCode.InitialFailure + u = $(SII.state_values)(cache) + return build_solution_less_specialize( + cache.prob, cache.alg, u, $(Utils.evaluate_f)(cache.prob, u); + retcode = cache.retcode + ) + end + end) + for i in 1:N push!(calls, quote @@ -111,7 +122,8 @@ end @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; - stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs... + stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, + initializealg = NonlinearSolveDefaultInit(), kwargs... ) where {N} sol_syms = [gensym("sol") for _ in 1:N] prob_syms = [gensym("prob") for _ in 1:N] @@ -123,9 +135,23 @@ end immutable (checked using `ArrayInterface.ismutable`)." alias_u0 = false # If immutable don't care about aliasing end + end] + + push!(calls, + quote + prob, success = $(run_initialization!)(prob, initializealg, prob) + if !success + u = $(SII.state_values)(prob) + return build_solution_less_specialize( + prob, alg, u, $(Utils.evaluate_f)(prob, u); + retcode = $(ReturnCode.InitialFailure)) + end + end) + + push!(calls, quote u0 = prob.u0 u0_aliased = alias_u0 ? zero(u0) : u0 - end] + end) for i in 1:N cur_sol = sol_syms[i] push!(calls, From 8d861aaeec4b8663721cb0e2249e394f9cb27b08 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Dec 2024 13:30:36 +0530 Subject: [PATCH 04/13] feat: implement initialization for no-init cache --- lib/NonlinearSolveBase/src/solve.jl | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 0784b9c92..4035422cc 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -272,6 +272,9 @@ end alg args kwargs::Any + initializealg + + retcode::ReturnCode.T end function get_abstol(cache::NonlinearSolveNoInitCache) @@ -297,11 +300,20 @@ end function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; + initializealg = NonlinearSolveDefaultInit(), kwargs... ) - return NonlinearSolveNoInitCache(prob, alg, args, kwargs) + cache = NonlinearSolveNoInitCache( + prob, alg, args, kwargs, initializealg, ReturnCode.Default) + run_initialization!(cache) + return cache end function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) + if cache.retcode == ReturnCode.InitialFailure + u = SII.state_values(cache) + return SciMLBase.build_solution( + cache.prob, cache.alg, u, Utils.evaluate_f(cache.prob, u); cache.retcode) + end return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end From 5521d40b60e0e48d6b2d73958c7151701fc2aa01 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:23:38 +0530 Subject: [PATCH 05/13] feat: implement initialization for first order cache --- lib/NonlinearSolveFirstOrder/src/solve.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveFirstOrder/src/solve.jl b/lib/NonlinearSolveFirstOrder/src/solve.jl index c9c8c77a8..ec0d54da6 100644 --- a/lib/NonlinearSolveFirstOrder/src/solve.jl +++ b/lib/NonlinearSolveFirstOrder/src/solve.jl @@ -87,6 +87,8 @@ end retcode::ReturnCode.T force_stop::Bool kwargs + + initializealg end function InternalAPI.reinit_self!( @@ -121,7 +123,7 @@ function SciMLBase.__init( stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, maxtime = nothing, termination_condition = nothing, internalnorm = L2_NORM, - linsolve_kwargs = (;), kwargs... + linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs... ) @set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) provided_jvp_autodiff = alg.jvp_autodiff !== nothing @@ -206,13 +208,17 @@ function SciMLBase.__init( prob, alg, u, fu, J, du; kwargs... ) - return GeneralizedFirstOrderAlgorithmCache( + cache = GeneralizedFirstOrderAlgorithmCache( fu, u, u_cache, prob.p, du, J, alg, prob, globalization, jac_cache, descent_cache, linesearch_cache, trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, - 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs + 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs, + initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( From a2d75ced59f9ac8668e3059fc757693afd844f55 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:23:51 +0530 Subject: [PATCH 06/13] feat: implement initialization for `QuasiNewtonCache` --- lib/NonlinearSolveQuasiNewton/src/solve.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveQuasiNewton/src/solve.jl b/lib/NonlinearSolveQuasiNewton/src/solve.jl index 3e413b2f6..53289c117 100644 --- a/lib/NonlinearSolveQuasiNewton/src/solve.jl +++ b/lib/NonlinearSolveQuasiNewton/src/solve.jl @@ -93,6 +93,9 @@ end force_stop::Bool force_reinit::Bool kwargs + + # Initialization + initializealg end function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache) @@ -137,7 +140,8 @@ function SciMLBase.__init( stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing, reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing, - internalnorm::F = L2_NORM, kwargs... + internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), + kwargs... ) where {F} timer = get_timer_output() @static_timeit timer "cache construction" begin @@ -211,15 +215,18 @@ function SciMLBase.__init( uses_jacobian_inverse = inverted_jac, kwargs... ) - return QuasiNewtonCache( + cache = QuasiNewtonCache( fu, u, u_cache, prob.p, du, J, alg, prob, globalization, initialization_cache, descent_cache, linesearch_cache, trustregion_cache, update_rule_cache, reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace, - ReturnCode.Default, false, false, kwargs + ReturnCode.Default, false, false, kwargs, initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( From a632a51ff6d2f7b48b7035240f7ff7b7c78a8e52 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 14:24:12 +0530 Subject: [PATCH 07/13] feat: implement initialization for `GeneralizedDFSaneCache` --- lib/NonlinearSolveSpectralMethods/src/solve.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveSpectralMethods/src/solve.jl b/lib/NonlinearSolveSpectralMethods/src/solve.jl index b3a7d216e..899d08a5d 100644 --- a/lib/NonlinearSolveSpectralMethods/src/solve.jl +++ b/lib/NonlinearSolveSpectralMethods/src/solve.jl @@ -68,6 +68,8 @@ end retcode::ReturnCode.T force_stop::Bool kwargs + + initializealg end function InternalAPI.reinit_self!( @@ -112,7 +114,7 @@ function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, termination_condition = nothing, - maxtime = nothing, kwargs... + maxtime = nothing, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs... ) timer = get_timer_output() @@ -145,13 +147,16 @@ function SciMLBase.__init( σ_n = T(alg.σ_1) end - return GeneralizedDFSaneCache( + cache = GeneralizedDFSaneCache( fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min), T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime, timer, 0.0, - tc_cache, trace, ReturnCode.Default, false, kwargs + tc_cache, trace, ReturnCode.Default, false, kwargs, initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( From 8434699454536d39cccccda90dcef1088a2a8183 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Dec 2024 13:26:24 +0530 Subject: [PATCH 08/13] fix: fix `SII.parameter_values` --- lib/NonlinearSolveBase/src/abstract_types.jl | 2 +- lib/NonlinearSolveBase/src/polyalg.jl | 3 +++ lib/NonlinearSolveBase/src/solve.jl | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/src/abstract_types.jl b/lib/NonlinearSolveBase/src/abstract_types.jl index 00b4c127b..f43d59012 100644 --- a/lib/NonlinearSolveBase/src/abstract_types.jl +++ b/lib/NonlinearSolveBase/src/abstract_types.jl @@ -315,7 +315,7 @@ end ## SII Interface SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob -SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob) +SII.parameter_values(cache::AbstractNonlinearSolveCache) = cache.p SII.state_values(cache::AbstractNonlinearSolveCache) = get_u(cache) function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol) diff --git a/lib/NonlinearSolveBase/src/polyalg.jl b/lib/NonlinearSolveBase/src/polyalg.jl index 9e6510093..935019e97 100644 --- a/lib/NonlinearSolveBase/src/polyalg.jl +++ b/lib/NonlinearSolveBase/src/polyalg.jl @@ -84,6 +84,9 @@ end function SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) SII.state_values(SII.symbolic_container(cache)) end +function SII.parameter_values(cache::NonlinearSolvePolyAlgorithmCache) + SII.parameter_values(SII.symbolic_container(cache)) +end function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache) println(io, "NonlinearSolvePolyAlgorithmCache with \ diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 4035422cc..5386df90f 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -284,6 +284,8 @@ function get_reltol(cache::NonlinearSolveNoInitCache) get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0))) end +SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob) + get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) function SciMLBase.reinit!( From a6e2f32467615e907c8969b1847d7b43eec8e025 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:54:10 +0530 Subject: [PATCH 09/13] feat: implement initialization for `SimpleNonlinearSolve` --- lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 3d8258f7e..2a87eb54e 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -59,6 +59,12 @@ function CommonSolve.solve( prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs... ) + cache = SciMLBase.__init(prob, alg, args...; kwargs...) + prob = cache.prob + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution(prob, alg, prob.u0, + NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode) + end prob = convert(ImmutableNonlinearProblem, prob) return solve(prob, alg, args...; kwargs...) end @@ -97,6 +103,12 @@ function CommonSolve.solve( alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs... ) + cache = SciMLBase.__init(prob, alg, args...; kwargs...) + prob = cache.prob + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution(prob, alg, prob.u0, + NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode) + end if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end From a64124732a1821adb3f891b919d987f1bcac3520 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:54:49 +0530 Subject: [PATCH 10/13] fix: fix `InternalAPI.reinit_self!` for `GeneralizedDFSaneCache` --- lib/NonlinearSolveSpectralMethods/src/solve.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/NonlinearSolveSpectralMethods/src/solve.jl b/lib/NonlinearSolveSpectralMethods/src/solve.jl index 899d08a5d..fd71527c0 100644 --- a/lib/NonlinearSolveSpectralMethods/src/solve.jl +++ b/lib/NonlinearSolveSpectralMethods/src/solve.jl @@ -77,6 +77,7 @@ function InternalAPI.reinit_self!( alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs... ) Utils.reinit_common!(cache, u0, p, alias_u0) + T = eltype(u0) if cache.alg.σ_1 === nothing σ_n = Utils.safe_dot(cache.u, cache.u) / Utils.safe_dot(cache.u, cache.fu) From 279727aa80e4f5ba0e8c007b6a7bf19abb4e37c8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:55:49 +0530 Subject: [PATCH 11/13] fix: fix `SII.state_values` for `NoInitCache` --- lib/NonlinearSolveBase/src/solve.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 5386df90f..c763cf192 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -285,6 +285,7 @@ function get_reltol(cache::NonlinearSolveNoInitCache) end SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob) +SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) From d0dc96c9938fa3748b571ace44e9a17eff792307 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:55:59 +0530 Subject: [PATCH 12/13] feat: run initialiation on `solve!` --- lib/NonlinearSolveBase/src/solve.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index c763cf192..41dc1d5fb 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -7,6 +7,13 @@ function SciMLBase.__solve( end function CommonSolve.solve!(cache::AbstractNonlinearSolveCache) + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution( + cache.prob, cache.alg, get_u(cache), get_fu(cache); + cache.retcode, cache.stats, cache.trace + ) + end + while not_terminated(cache) CommonSolve.step!(cache) end From a56f4e3be6ee8a976d2908e3173c52662f329f9c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Dec 2024 11:25:44 +0530 Subject: [PATCH 13/13] build: bump SciMLBase compat in NonlinearSolveBase --- lib/NonlinearSolveBase/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 2f5404dc7..cc60c308e 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -68,7 +68,7 @@ MaybeInplace = "0.1.4" Preferences = "1.4" Printf = "1.10" RecursiveArrayTools = "3" -SciMLBase = "2.58" +SciMLBase = "2.68.1" SciMLJacobianOperators = "0.1.1" SciMLOperators = "0.3.10" SparseArrays = "1.10"