Adding complex broadcasting for gradients on the GPU #1324
Conversation
Mostly leaving this here to say the test failures are expected, but a couple suggestions while I'm at it:
This looks good, thanks for tackling it!
I meant to take a closer read but haven't yet, sorry.
I believe there is already sufficient testing done there.
Sadly I would not assume this. There may be very few tests of complex broadcasting, not sure (maybe I missed a section). It might be worth trying to come up with some evil test cases, including e.g. fused broadcasts where only parts are complex.
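For instance, a hypothetical evil case of this kind (values and function made up purely for illustration) could mix a real and a complex argument inside one fused broadcast:

using Zygote, Test
r = rand(Float32, 8)
c = rand(ComplexF32, 8)
# fused broadcast where only some of the arguments are complex
g = gradient((r, c) -> sum(abs2, @. r * conj(c) + exp(im * r)), r, c)
@test length(g[1]) == length(r) && length(g[2]) == length(c)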
src/lib/broadcast.jl
Outdated
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
T = eltype(out)
T <: Union{Dual, Complex} || return (out, _ -> nothing)
Should this be Union{Dual, Dual{<:Complex}}? You'd have to try pretty hard but I think the Complex path expects Dual inside.
I thought it was the other way around? At least that is what I am constructing in the dual_function. ForwardDiff.jl also defines Dual <: Real, so I think defining it the other way would break things. However, I probably want to be a little more specific here and do:
T <: Union{Dual, Complex} || return (out, _ -> nothing)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
Yes, sorry, that's what I was thinking but didn't type...
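For the record, a minimal sketch (illustrative values only) of the wrapping order being discussed: dual_function puts the Dual numbers inside the Complex, not the other way around.

using ForwardDiff: Dual
z = Complex(Dual(1.0, 1.0, 0.0), Dual(2.0, 0.0, 1.0))  # re and im each carry partials
@assert z isa Complex{<:Dual}   # this is what the complex path sees
@assert !(z isa Dual)           # Dual{<:Complex} never arises here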
test/cuda.jl
Outdated
@testset "CUDA complex broadcasting" begin | ||
# Issue 961 and 1121 and 1215 | ||
x = rand(Float32, 50) | ||
y = complex(rand(Float32, 50)) |
Why define x here at all? Also, this y has zero imaginary part. rand(ComplexF64, 50) would be a stronger test.
julia> complex(rand(Float32, 50))
50-element Vector{ComplexF32}:
0.89825445f0 + 0.0f0im
0.40070343f0 + 0.0f0im
0.29411656f0 + 0.0f0im
0.44503874f0 + 0.0f0im
Oops! That x was for a test I was doing on my machine. I think overall that the testing could be a bit better though, so I've added another test that uses both real and complex arguments. I probably need to add some additional tests.
Cool. I think x.^2 .* y .+ y uses only functions which have special rules, and ought to work without this PR. I think even broadcasting trivial functions like add(x,y) = x+y will change the path it takes. But messy examples (e.g. with trig, conj/real/imag, in all sorts of ways) are much more likely to expose mistakes like a conj missing somewhere.
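As a sketch of that point (hypothetical test, not part of the PR), a user-defined function with no special broadcast rule changes which path broadcasting takes:

using Zygote, Test
add(x, y) = x + y                      # no special rule, unlike +, *, abs2, ...
x = rand(Float32, 8)
y = rand(ComplexF32, 8)
g = gradient((x, y) -> sum(abs2, add.(x, y)), x, y)
@test length(g[2]) == length(y)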
Trying to invent some functions, did not try them on GPU:
r3 = Float32.(inv.(2:4))
c3 = ComplexF32.(inv.(5:7) .+ im ./ (8:10))
@test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
@test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im]
@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]
But locally, with this branch, I expected them to use the new code... but adding printing doesn't seem to work?
(jl_S8DfLf) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_S8DfLf/Project.toml`
[e88e6eb3] Zygote v0.6.49 `https://github.com/ptiede/Zygote.jl#pt-complexbroadcast`
julia> @eval Zygote function dual(x::Complex, i, N) # from PR, with printing
@show x
re_dual = Dual(real(x), ntuple(==(i), 2N))
im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
return Complex(re_dual, im_dual)
end;
julia> Zygote.refresh()
julia> @test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
Test Passed
So I looked into this, and it occurred because I hadn't added a Complex method for _dual_safearg. When I added this, some issues started to appear. One of them was because the partials for the real and imaginary parts had different lengths.
However, that is not the big issue. The big issue is that certain functions seem to be causing some type instabilities during the evaluation of the dual numbers. For instance,
x = rand(Complex{Float32}, 100)
f(x) = sum(abs2, log.(x))
@code_warntype Zygote.dual_function(f).(x)
MethodInstance for (::var"##dotfunction#314#7")(::Vector{ComplexF32})
from (::var"##dotfunction#314#7")(x1) in Main
Arguments
#self#::Core.Const(var"##dotfunction#314#7"())
x1::Vector{ComplexF32}
Body::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│ %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│ %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF32}}}
│ %4 = Base.materialize(%3)::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
└── return %4
This has a problem where the broadcast can't seem to figure out that the eltype of the partials field in Dual should be Float32. What is really annoying is that this problem does not occur for Float64, where I get
x64 = Complex{Float64}.(x)
@code_warntype Zygote.dual_function(f).(x64)
MethodInstance for (::var"##dotfunction#313#6")(::Vector{ComplexF64})
from (::var"##dotfunction#313#6")(x1) in Main
Arguments
#self#::Core.Const(var"##dotfunction#313#6"())
x1::Vector{ComplexF64}
Body::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│ %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│ %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF64}}}
│ %4 = Base.materialize(%3)::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
└── return %4
Ok, looking into this more: it appears that log with Complex{Dual{Float32}} arguments is type unstable. My guess is that this occurs because there isn't a specific forward rule for log of a complex number, or likely for any common complex functions.
That is weird, @code_warntype log(Dual(1f0, 1f0) + im) is bad. Inside Base.ssqs, it looks like ldexp(Dual(1f0, 2f0), 3) makes a Float64 dual, by a method from ForwardDiff.
Anyway not this PR's problem! Maybe make an issue on ForwardDiff (or DiffRules) and test inference etc. with other functions here?
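A rough sketch of such an inference check (illustrative, not the PR's tests): inspect the inferred return types of a few functions applied to Complex{<:Dual} arguments and look for Union or abstract results.

using ForwardDiff: Dual
z = Complex(Dual(1f0, 1f0, 0f0), Dual(0.5f0, 0f0, 1f0))
for f in (exp, sin, sqrt, log, abs2)
    # any Union or abstract entry here flags an inference problem
    println(f, " => ", Base.return_types(f, (typeof(z),)))
end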
Ok sounds good! I'll skip log for now and make tests for other functions.
Alright I was able to add the last test,
@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]
and everything passes! The other two suggested tests both run into the ldexp problem with Float32. I have opened up an issue, JuliaDiff/ForwardDiff.jl#604, detailing the problem. The good news is that when I fix the problem locally, all the tests pass!
Here are a couple of updates on my end. First, I just realized I was running the previous test on the CPU. When I run it on the GPU, I get a scalar indexing error. The stack trace is
julia> @test gradcheck_gpu((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)
Error During Test at /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186
Test threw exception
Expression: gradcheck_gpu(((r, c)->begin
sum(abs2, #= /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186 =# @__dot__(sin(conj(c) / r' - im) - imag(c + tanh(r / c'))))
end), r3, c3)
Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
[3] getindex(::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
[4] getindex
@ ~/.julia/juliaup/julia-1.8.2+0.x64/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:180 [inlined]
[5] _unsafe_getindex_rs
@ ./reshapedarray.jl:250 [inlined]
[6] _unsafe_getindex
@ ./reshapedarray.jl:247 [inlined]
[7] getindex
@ ./reshapedarray.jl:235 [inlined]
[8] iterate
@ ./abstractarray.jl:1167 [inlined]
[9] iterate
@ ./abstractarray.jl:1165 [inlined]
[10] iterate
@ ./generator.jl:44 [inlined]
[11] _collect(c::Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:807
[12] collect_similar
@ ./array.jl:716 [inlined]
[13] map
@ ./abstractarray.jl:2933 [inlined]
[14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:236
[15] ProjectTo
@ ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:414 [inlined]
[16] _project
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:184 [inlined]
[17] unbroadcast(x::LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:58
[18] (::Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:97
[19] (::Zygote.var"#3669#back#859"{Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[20] Pullback
@ ./none:0 [inlined]
[21] (::typeof(∂(#13)))(Δ::Float32)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[22] (::Zygote.var"#60#61"{typeof(∂(#13))})(Δ::Float32)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
[23] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[24] gradcheck_gpu(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
@ Main ~/.julia/dev/Zygote/test/cuda.jl:9
[25] top-level scope
From the look of the stack trace, this isn't due to this pull request. In fact, if I change the function definition to
sin(conj(c)/$(transpose(r)) - im) - imag(c + tanh(r/c'))
then everything is fine, so my guess is that this is some funkiness related to the pullback of an adjoint of a real vector. I'll take a look into this, but I am not sure if that's part of this pull request.
Second, I have added some additional tests to ensure we hit every one of the _broadcast_forward branches.
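For context, a sketch of what branch-coverage cases look like (not the exact tests added; gradcheck_gpu is the helper from test/cuda.jl, assumed to compare the CPU and GPU gradients):

using Test
r = rand(Float32, 16)
c = rand(ComplexF32, 16)
@test gradcheck_gpu(x -> sum(abs2, cos.(x)), r)          # real input,    real output
@test gradcheck_gpu(x -> sum(abs2, cis.(x)), r)          # real input,    complex output
@test gradcheck_gpu(x -> sum(abs.(x) .+ real.(x)), c)    # complex input, real output
@test gradcheck_gpu(x -> sum(abs2, exp.(conj.(x))), c)   # complex input, complex output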
@mcabbott mostly good news. The ldexp type instability was fixed in JuliaDiff/DiffRules.jl#89. However,
r3 = Float32.(inv.(2:4))
f(r) = sum(abs2, log.(1 .+ im .* r)./2)
Zygote.gradient(f, r3)
still gives an error on 1.6. Digging into this issue a bit more, I can create the following MWE:
julia> rd3 = first.(Zygote.dualize(r3)) # CuArray{ForwardDiff.Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}
julia> log.(1im .* rd3)
ERROR: InvalidIRError: compiling kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceVector{Complex{ForwardDiff.Dual{Nothing, Float32, 1}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(log), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(*), Tuple{Complex{Int64}, Base.Broadcast.Extruded{CuDeviceVector{ForwardDiff.Dual{Nothing, Float32, 1}, 1}, Tuple{Bool}, Tuple{Int64}}}}}}, Int64) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to exponent)
Stacktrace:
[1] ssqs
@ ./complex.jl:474
[2] log
@ ./complex.jl:594
[3] _broadcast_getindex_evalf
@ ./broadcast.jl:648
[4] _broadcast_getindex
@ ./broadcast.jl:621
[5] getindex
@ ./broadcast.jl:575
[6] broadcast_kernel
@ ~/.julia/packages/GPUArrays/fqD8z/src/host/broadcast.jl:57
... which suggests the problem is in the call to exponent inside Base.ssqs.
Does more inlining help at all?
Sadly no :( The MWE also shouldn't have an inlining issue, right?
Ok figured it out! Analyzing the kernel, the offending call is
%282 = call #exponent(::ForwardDiff.Dual{Nothing, Float32, 1})::Union{}
which errors because there is no applicable exponent method for Dual. After defining
Base.exponent(x::ForwardDiff.Dual{<:Real}) = Base.exponent(ForwardDiff.value(x))
everything works and we pass the tests on Julia 1.6. I believe this function definition makes sense since exponent: Real -> Int, so we only really care about the value of the function. I don't really understand why this didn't cause an issue on 1.7/1.8, but maybe this got optimized away?
Alright, the dual exponent issue has been fixed. When a new version of ForwardDiff.jl is released, the 1.6 tests should pass.
DynamicPPL test failures are caused by JuliaDiff/ForwardDiff.jl#606.
@mcabbott I think this is finally ready to review again. All the tests are passing, and I have added some additional tests to ensure that every branch is getting hit.
Is this ready to merge?
Just a bump to see if this is ready to be merged, or if there are some outstanding items that I still need to fix.
Thanks! I'll tag a new release shortly.
@ptiede which of the issues mentioned in the OP should be closed?
This should fix 961, 1121, 1215, 1276, i.e. all of them, since they were all the same problem in disguise.
Do you think we need some extra tests, or do the ones in this PR cover them all?
I think the tests should cover all of those cases. Lines 191 to 192 in 616bf6c should cover 1276 intrinsically, because the type instability that was causing the slowdown is fixed. Line 175 in 616bf6c should cover the abs2 bug. But coming up with tests was tricky, so it is possible that I missed something.
I just tested and closed all of them.
This is a first attempt to add support for taking gradients of complex numbers when broadcasting and on the GPU. This targets issues #961 #1121 #1215.
A nice side effect of this pull request is that complex broadcasting doesn't have to take the slow route on the CPU anymore, which fixes the performance issues in #1276.
On the current Zygote.jl release, I get:
With this pull request I get:
Approach
To fix these issues, I changed how broadcast_forward and dual_function work. This was inspired by @mcabbott's comment, but with some changes to ensure there are no dynamic dispatches or type instabilities. Specifically, I had to change the dual function, since the previous definition was leading to some type instability warnings on the GPU and some other strange issues.
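For reference, a minimal sketch of the idea (the Complex method mirrors the version printed earlier in this thread; the Real method here is only illustrative): each of the N broadcast arguments gets its own slots among 2N partials, and a complex argument gets two slots, one for its real part and one for its imaginary part.

using ForwardDiff: Dual
dual(x::Real, i, N) = Dual(x, ntuple(==(i), 2N))        # illustrative real case
function dual(x::Complex, i, N)
    re_dual = Dual(real(x), ntuple(==(i), 2N))          # seeds d/d(real part)
    im_dual = Dual(imag(x), ntuple(==(N + i), 2N))      # seeds d/d(imag part)
    return Complex(re_dual, im_dual)
end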
On top of the change to dual, another change is how broadcast_forward works. I had to make four separate functions depending on the output and the arguments being broadcast. I am not sure if there is a better way to do this, but it currently works and passes all tests on my machine. One concern I had was what to do for complex -> complex functions. For this, I just followed what was listed in https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html, but maybe we don't want to follow that?
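As a worked illustration of that convention (hypothetical code, not the PR's internals): with z = x + iy seeded as Complex(Dual(x, 1, 0), Dual(y, 0, 1)), one forward pass yields the four partials of u and v for f(z) = u + iv, and the pullback of a cotangent Δ is z̄ = (re(Δ)*∂u/∂x + im(Δ)*∂v/∂x) + i*(re(Δ)*∂u/∂y + im(Δ)*∂v/∂y).

using ForwardDiff: Dual, partials
f(z) = z^2                                   # u = x^2 - y^2, v = 2xy
x, y = 1.0, 2.0
out = f(Complex(Dual(x, 1.0, 0.0), Dual(y, 0.0, 1.0)))
du, dv = partials(real(out)), partials(imag(out))
Δ = 1.0 + 0.0im
z̄ = (real(Δ)*du[1] + imag(Δ)*dv[1]) + im*(real(Δ)*du[2] + imag(Δ)*dv[2])
# gives 2.0 - 4.0im, matching conj(f'(z)) * Δ for this holomorphic f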
Testing
In terms of testing, I have added some small tests to cuda.jl to ensure that nothing is not returned and that the gradients on the GPU and CPU are the same. Since I also changed broadcast_forward on the CPU (always taking the fast path), I believe there is already sufficient testing done there.
PR Checklist