Zygote.gradient failed with CubaCuhre and batch !=0 #49

Open
KirillZubov opened this issue Nov 21, 2020 · 1 comment
@KirillZubov (Member)

```julia
using Quadrature, ForwardDiff, FiniteDiff, Zygote, Cuba
f(x,p) = sum(sin.(x .* p))
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

# Non-batched evaluation (batch=0)
prob = QuadratureProblem(f,lb,ub,p; batch=0)
sol = solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=100)[1]

function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=0)
    solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=100)[1]
end
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# works fine

# Batched evaluation (batch=10)
prob = QuadratureProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=100)[1]

function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=100)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# works fine
```

```
dp1 = Zygote.gradient(testf,p)
ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
  convert(::Type{T}, ::ArrayInterface.StaticInt{N}) where {T<:Number, N} at /Users/kirill/.julia/packages/ArrayInterface/rw2kK/src/static.jl:18
  convert(::Type{R}, ::T) where {R<:Real, T<:ReverseDiff.TrackedReal} at /Users/kirill/.julia/packages/ReverseDiff/jFRo1/src/tracked.jl:255
  convert(::Type{T}, ::Unitful.Quantity) where T<:Real at /Users/kirill/.julia/packages/Unitful/1t88N/src/conversion.jl:145
  ...
Stacktrace:
 [1] setindex! at ./array.jl:849 [inlined]
 [2] macro expansion at ./multidimensional.jl:802 [inlined]
 [3] macro expansion at ./cartesian.jl:64 [inlined]
 [4] macro expansion at ./multidimensional.jl:797 [inlined]
 [5] _unsafe_setindex!(::IndexLinear, ::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64) at ./multidimensional.jl:789
 [6] _setindex! at ./multidimensional.jl:785 [inlined]
 [7] setindex!(::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Function, ::Int64) at ./abstractarray.jl:1153
 [8] (::Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}})(::Array{Float64,2}, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:524
 [9] __solvebp_call(::QuadratureProblem{false,Array{Float64,1},Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}}}, ::CubaCuhre, ::Quadrature.ReCallVJP{Quadrature.ZygoteVJP}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}; reltol::Float64, abstol::Float64, maxiters::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:437
 [10] (::Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:546
 [11] #65#back at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [12] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [13] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{Quadrature.var"#65#back#64"{Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}}},Tuple{NTuple{8,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [14] #solve#10 at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:149 [inlined]
 [15] (::typeof(∂(#solve#10)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [16] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [17] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#10)),Tuple{NTuple{5,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [18] (::typeof(∂(solve##kw)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] testf at ./none:3 [inlined]
 [20] (::typeof(∂(testf)))(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(testf))})(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
 [22] gradient(::Function, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [23] top-level scope at none:1
```
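Frames [5] to [7] show an `Array{Array{Float64,1},1}` (a vector of vectors) being assigned into a column of an `Array{Float64,2}` via `A[:, i] = ...`. The same `MethodError` can be reproduced in plain Julia with no packages. This is a sketch of the error class only, not of Quadrature's actual code path, and `A` and `cols` are hypothetical names:

```julia
A = zeros(3, 2)                               # a Matrix{Float64}
cols = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # a Vector{Vector{Float64}}

# Each element of `cols` is itself a Vector, which cannot be converted to the
# Float64 element type of `A`, so this column assignment throws a MethodError.
err = try
    A[:, 1] = cols
    nothing
catch e
    e
end
println(err isa MethodError)   # true

# Assigning one Float64 per row works as expected.
A[:, 1] = [sum(c) for c in cols]
```

In the issue, the vector-of-vectors shape is consistent with a per-point result that is not a plain scalar reaching a buffer that expects one number per entry.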
@lxvm (Collaborator)

lxvm commented Mar 9, 2024

I think there is a bug in the MWE since a scalar-valued f is incompatible with batching. Namely, the batch integrand should return a vector whose length matches the last axis of the input points (see the FAQ for more details).
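To make the shape requirement concrete, here is a Base-only sketch that evaluates both integrands on a hypothetical batch of 10 points stored as the columns of a 3×10 matrix (the column-per-point layout is an assumption read from the "last axis" wording above; `f_scalar` and `f_batched` are hypothetical names for the original and adapted integrands):

```julia
p = [1.5, 2.0, 3.0]

# Integrand from the original MWE: sums over *every* entry, so a whole
# batch of points collapses to a single scalar.
f_scalar(x, p) = sum(sin.(x .* p))

# Integrand from the adapted MWE: sums over the first axis (the coordinates)
# only, leaving one value per point along the last axis.
f_batched(x, p) = sum(sin.(x .* p); dims=1)

x = 1 .+ 2 .* rand(3, 10)   # 10 hypothetical points in [1,3]^3, one per column

size(f_scalar(x, p))    # () : one number for the whole batch
size(f_batched(x, p))   # (1, 10) : one value per point
```

Each entry of `f_batched(x, p)` equals `f_scalar` applied to the corresponding column, which is the per-point output a batched integrand is expected to produce.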

I've adapted the MWE to the current version of Integrals, modified the integrand to do what I think was intended, and can confirm it works on the master branch.

```julia
using Integrals, ForwardDiff, FiniteDiff, Zygote, Cuba
f(x,p) = sum(sin.(x .* p); dims=1)
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

prob = IntegralProblem(f,lb,ub,p)
sol = solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=1000)[1]

function testf(p)
    prob = IntegralProblem(f,lb,ub,p)
    solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=1000)[1]
end
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# works fine

prob = IntegralProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=1000)[1]

function testf(p)
    prob = IntegralProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol=1e-4,abstol=1e-4,maxiters=1000)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# works fine

dp1 = Zygote.gradient(testf,p)
# also works on the master branch
```
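The test integrand is separable, so its integral over the box [1,3]^3 has a closed form, I(p) = 4 Σ_i (cos p_i - cos 3p_i)/p_i: each term sin(p_i x_i) integrates to (cos p_i - cos 3p_i)/p_i along x_i, and the other two axes each contribute a factor of 3 - 1 = 2. The Base-only sketch below (the helper names `I_exact` and `dI_exact` are hypothetical) evaluates that value and its analytic gradient, which can serve as a backend-independent reference for `sol`, `dp1`, `dp2` and `dp3` in either version of the MWE:

```julia
# Closed form of the integral of sum(sin.(x .* p)) over [1,3]^3:
# each sin(p_i x_i) integrates to (cos(p_i) - cos(3p_i)) / p_i along x_i,
# times 2 * 2 = 4 from the remaining two axes.
I_exact(p) = 4 * sum((cos.(p) .- cos.(3 .* p)) ./ p)

# Analytic gradient: 4 * d/dp_i [(cos(p_i) - cos(3p_i)) / p_i], via the quotient rule.
function dI_exact(p)
    num = (-sin.(p) .+ 3 .* sin.(3 .* p)) .* p .- (cos.(p) .- cos.(3 .* p))
    return 4 .* num ./ p .^ 2
end

p = [1.5, 2.0, 3.0]
I_exact(p)    # reference value for `sol`
dI_exact(p)   # reference gradient for dp1, dp2, dp3
```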

Some bugs in the current release that affect CubaCuhre are already fixed on the master branch, so I'll wait until the next release to close the issue.
