Skip to content

Commit

Permalink
Merge pull request #3090 from AayushSabharwal/as/promote-resid-prototype
Browse files Browse the repository at this point in the history
fix: promote `resid_prototype` using tunables
  • Loading branch information
ChrisRackauckas authored Oct 4, 2024
2 parents 23b7b2e + 794a421 commit 40b1f7c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
35 changes: 28 additions & 7 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
end

function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing;
ps = parameters(sys), u0 = nothing, p = nothing;
version = nothing,
jac = false,
eval_expression = false,
Expand Down Expand Up @@ -327,11 +327,22 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

if length(dvs) == length(equations(sys))
resid_prototype = nothing
else
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end
resid_prototype = zeros(u0ElType, length(equations(sys)))
end

NonlinearFunction{iip}(f,
sys = sys,
jac = _jac === nothing ? nothing : _jac,
resid_prototype = length(dvs) == length(equations(sys)) ? nothing :
zeros(length(equations(sys))),
resid_prototype = resid_prototype,
jac_prototype = sparse ?
similar(calculate_jacobian(sys, sparse = sparse),
Float64) : nothing,
Expand All @@ -355,7 +366,7 @@ variable and parameter vectors, respectively.
struct NonlinearFunctionExpr{iip} end

function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing;
ps = parameters(sys), u0 = nothing, p = nothing;
version = nothing, tgrad = false,
jac = false,
linenumbers = false,
Expand All @@ -376,8 +387,18 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
end

jp_expr = sparse ? :(similar($(get_jac(sys)[]), Float64)) : :nothing
resid_expr = length(dvs) == length(equations(sys)) ? :nothing :
:(zeros($(length(equations(sys)))))
if length(dvs) == length(equations(sys))
resid_expr = :nothing
else
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end

resid_expr = :(zeros($u0ElType, $(length(equations(sys)))))
end
ex = quote
f = $f
jac = $_jac
Expand Down Expand Up @@ -412,7 +433,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
check_eqs_u0(eqs, dvs, u0; kwargs...)
end

f = constructor(sys, dvs, ps, u0; jac = jac, checkbounds = checkbounds,
f = constructor(sys, dvs, ps, u0, p; jac = jac, checkbounds = checkbounds,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression, eval_module = eval_module,
kwargs...)
Expand Down
19 changes: 19 additions & 0 deletions test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ModelingToolkit: get_metadata
using DiffEqBase, SparseArrays
using Test
using NonlinearSolve
using ForwardDiff
using ModelingToolkit: value
using ModelingToolkit: get_default_or_guess, MTKParameters

Expand Down Expand Up @@ -325,3 +326,21 @@ end
prob = @test_nowarn NonlinearProblem(sys, nothing)
@test_nowarn solve(prob)
end

@testset "resid_prototype when system has no unknowns and an equation" begin
@variables x
@parameters p
@named sys = NonlinearSystem([x ~ 1, x^2 - p ~ 0])
for sys in [
structural_simplify(sys, fully_determined = false),
structural_simplify(sys, fully_determined = false, split = false)
]
@test length(equations(sys)) == 1
@test length(unknowns(sys)) == 0
T = typeof(ForwardDiff.Dual(1.0))
prob = NonlinearProblem(sys, [], [p => ForwardDiff.Dual(1.0)]; check_length = false)
@test prob.f(Float64[], prob.p) isa Vector{T}
@test prob.f.resid_prototype isa Vector{T}
@test_nowarn solve(prob)
end
end

0 comments on commit 40b1f7c

Please sign in to comment.