Adding complex broadcasting for gradients on the GPU #1324
test/cuda.jl
```diff
@@ -26,7 +26,7 @@ end
 g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
 @test g_gpu isa CuArray
 @test g_gpu |> collect ≈ g

 w(x) = sum(broadcast(log, x))
 g = gradient(x -> w(x), a)[1]
 g_gpu = gradient(x -> w(x), a_gpu)[1]
```
```diff
@@ -38,7 +38,7 @@ end
 @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
 g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]  # was Can't differentiate gc_preserve_end expression
-@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1]  # was KernelException -- not fixed by PR #1018
+@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

 # Projection: eltype preservation:
 @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
```
```diff
@@ -90,40 +90,40 @@ end
 @testset "gradient algebra" begin
   w, b = rand(2) |> cu, rand(2) |> cu
   x1, x2 = rand(2) |> cu, rand(2) |> cu
   gs1 = gradient(() -> sum(w .* x1), Params([w]))
   gs2 = gradient(() -> sum(w .* x2), Params([w]))

   @test .- gs1 isa Grads
   @test gs1 .- gs2 isa Grads
   @test .+ gs1 isa Grads
   @test gs1 .+ gs2 isa Grads
   @test 2 .* gs1 isa Grads
   @test (2 .* gs1)[w] ≈ 2 * gs1[w]
   @test gs1 .* 2 isa Grads
   @test gs1 ./ 2 isa Grads
   @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

   gs12 = gs1 .+ gs2
   gs1 .+= gs2
   @test gs12[w] ≈ gs1[w]

   gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
   gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

   @test .- gs3 isa Grads
   @test gs3 .- gs4 isa Grads
   @test .+ gs3 isa Grads
   @test gs3 .+ gs4 isa Grads
   @test 2 .* gs3 isa Grads
   @test gs3 .* 2 isa Grads
   @test gs3 ./ 2 isa Grads
   @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
   @test (gs3 .+ gs4)[b] ≈ gs4[b]

   @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
   gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
   @test gs3 isa Grads

   @test_throws ArgumentError gs1 .+ gs4
 end
```
```diff
@@ -140,3 +140,21 @@ end
 @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
 end

+@testset "CUDA complex broadcasting" begin
+  # Issue 961 and 1121 and 1215
+  x = rand(Float32, 50)
+  y = complex(rand(Float32, 50))
```
Why define … Also, this …

Oops! That …

Cool. I think …

Trying to invent some functions, did not try them on GPU: … But locally, with this branch, I expected them to use the new code... but adding printing doesn't seem to work?
So I looked into this; it occurred because I hadn't added a Complex method for … However, that is not the big issue. The big issue is that certain functions seem to be causing some type instabilities during the evaluation of the dual numbers. For instance,

```julia
x = rand(Complex{Float32}, 100)
f(x) = sum(abs2, log.(x))
@code_warntype Zygote.dual_function(f).(x)

MethodInstance for (::var"##dotfunction#314#7")(::Vector{ComplexF32})
  from (::var"##dotfunction#314#7")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#314#7"())
  x1::Vector{ComplexF32}
Body::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF32}}}
│   %4 = Base.materialize(%3)::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
└──      return %4
```

has a problem where the broadcast can't seem to figure out the eltype of the partials field in … With `Float64` inputs everything infers:

```julia
x64 = Complex{Float64}.(x)
@code_warntype Zygote.dual_function(f).(x64)

MethodInstance for (::var"##dotfunction#313#6")(::Vector{ComplexF64})
  from (::var"##dotfunction#313#6")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#313#6"())
  x1::Vector{ComplexF64}
Body::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF64}}}
│   %4 = Base.materialize(%3)::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
└──      return %4
```
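For orientation, the forward mode used here seeds each complex element as a Complex whose components are Duals with two partials, one for the real and one for the imaginary direction. A minimal sketch of that seeding, illustrative only and not necessarily the exact code this PR adds:

```julia
using ForwardDiff: Dual

# Seed z = a + b*im with partials tracking d/da and d/db respectively.
# The tag and partials layout are assumptions for illustration; Zygote's
# dual_function does something along these lines for Complex inputs.
dualize_complex(z::Complex{T}) where {T<:Real} =
    Complex(Dual(real(z), one(T), zero(T)),
            Dual(imag(z), zero(T), one(T)))

dualize_complex(1.0f0 + 2.0f0im)   # isa Complex{<:Dual}, partials (1, 0) and (0, 1)
```

Whether `log` of such a value infers then depends on how the Dual arithmetic promotes the partials, which seems to be where the Float32 instability above comes from.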
Ok, looking into this more. It appears the …

That is weird, … Anyway, not this PR's problem! Maybe make an issue on ForwardDiff (or DiffRules), and test inference etc. with other functions here?

Ok, sounds good! I'll skip log for now and make tests for other functions.

Alright, I was able to add the last test,

```julia
@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈
      [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]
```

and everything passes! The other two tests suggested both run into the …

Here are a couple of updates on my end. First, I just realized I was running the previous test on the CPU. When I run it on the GPU, I get a scalar indexing error. The stack trace is

```julia
julia> @test gradcheck_gpu((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)
Error During Test at /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186
  Test threw exception
  Expression: gradcheck_gpu(((r, c)->begin
            sum(abs2, #= /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186 =# @__dot__(sin(conj(c) / r' - im) - imag(c + tanh(r / c'))))
        end), r3, c3)
  Scalar indexing is disallowed.
  Invocation of getindex resulted in scalar indexing of a GPU array.
  This is typically caused by calling an iterating implementation of a method.
  Such implementations *do not* execute on the GPU, but very slowly on the CPU,
  and therefore are only permitted from the REPL for prototyping purposes.
  If you did intend to index this array, annotate the caller with @allowscalar.
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] assertscalar(op::String)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
    [3] getindex(::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
      @ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
    [4] getindex
      @ ~/.julia/juliaup/julia-1.8.2+0.x64/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:180 [inlined]
    [5] _unsafe_getindex_rs
      @ ./reshapedarray.jl:250 [inlined]
    [6] _unsafe_getindex
      @ ./reshapedarray.jl:247 [inlined]
    [7] getindex
      @ ./reshapedarray.jl:235 [inlined]
    [8] iterate
      @ ./abstractarray.jl:1167 [inlined]
    [9] iterate
      @ ./abstractarray.jl:1165 [inlined]
   [10] iterate
      @ ./generator.jl:44 [inlined]
   [11] _collect(c::Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
      @ Base ./array.jl:807
   [12] collect_similar
      @ ./array.jl:716 [inlined]
   [13] map
      @ ./abstractarray.jl:2933 [inlined]
   [14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})
      @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:236
   [15] ProjectTo
      @ ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:414 [inlined]
   [16] _project
      @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:184 [inlined]
   [17] unbroadcast(x::LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:58
   [18] (::Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:97
   [19] (::Zygote.var"#3669#back#859"{Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
   [20] Pullback
      @ ./none:0 [inlined]
   [21] (::typeof(∂(#13)))(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
   [22] (::Zygote.var"#60#61"{typeof(∂(#13))})(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
   [23] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
   [24] gradcheck_gpu(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Main ~/.julia/dev/Zygote/test/cuda.jl:9
   [25] top-level scope
```

From the look of the stack trace, this isn't due to this pull request. In fact, if I change the function definition to `sin(conj(c)/$(transpose(r)) - im) - imag(c + tanh(r/c'))`, then everything is fine, so my guess is that this is some funkiness related to the pullback of an adjoint of a real vector. I'll take a look into this, but I am not sure if that's part of this pull request.

Second, I have added some additional tests to ensure we hit every one of the …
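For context, the `gradcheck_gpu` called above is a helper defined near the top of test/cuda.jl (frame [24] points at test/cuda.jl:9) and is not part of this diff. A hypothetical sketch of what such a check does, not the file's actual definition:

```julia
using CUDA, Zygote

# Hypothetical helper, for illustration only: take GPU inputs, compute the
# gradient on the GPU and on CPU copies of the same data, and compare.
function gradcheck_gpu(f, xs::CUDA.CuArray...)
    gs_gpu = Zygote.gradient(f, xs...)
    gs_cpu = Zygote.gradient(f, map(collect, xs)...)
    return all(map((g, c) -> collect(g) ≈ c, gs_gpu, gs_cpu))
end
```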
```diff
+  xgpu = cu(x)
+  ygpu = cu(y)
+
+  g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1]
+  g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1]
+  g3 = Zygote.gradient(x->sum(abs2, x), y)[1]
+  @test g1 isa CUDA.CuArray{ComplexF32}
+  @test g2 isa CUDA.CuArray{ComplexF32}
+  @test collect(g1) ≈ collect(g2)
+  @test collect(g1) ≈ g3
+end
```
Should this be `Union{Dual, Dual{<:Complex}}`? You'd have to try pretty hard, but I think the Complex path expects Dual inside.

I thought it was the other way around? At least, that is what I am constructing in the `dual_function`. ForwardDiff.jl also defines `Dual <: Real`, so I think defining it the other way would break things. However, I probably want to be a little more specific here and do …

Yes, sorry, that's what I was thinking but didn't type...
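To make the nesting concrete, here is a small illustration (not code from this PR) of why the complex forward pass carries `Complex{<:Dual}` rather than `Dual{<:Complex}`:

```julia
using ForwardDiff: Dual

# Both components are Duals; this works because Dual <: Real and
# Complex{T} requires T <: Real.
z = Complex(Dual(1.0f0, 1.0f0, 0.0f0), Dual(2.0f0, 0.0f0, 1.0f0))
z isa Complex{<:Dual}    # true

# The reverse nesting is not constructible: a Dual's value must be Real,
# so Dual(1.0f0 + 2.0f0im, 1.0f0, 0.0f0) throws.
```

So any dispatch meant to catch the complex forward path needs to look for Duals inside a Complex, not a Complex inside a Dual.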