From 616bf6cfb7d7cea2b755a31f36274e9fff18f678 Mon Sep 17 00:00:00 2001
From: Paul Tiede
Date: Tue, 10 Jan 2023 11:02:30 -0500
Subject: [PATCH] Adding complex broadcasting for gradients on the GPU (#1324)

* Added complex broadcasting support
* Added tests and clean up the code
* Fix up type instability
* Add testing
* Everything passes tests now
* switch to more generic broadcast_forward
* clean up submission
* Remove various Val's
* change to Complex{<:Dual}
* add mixed complex and real to cuda testing
* import not using
* Add complex to _dual_safearg
* Type stable on my computer
* Fix Dual tagging
* Add more tests
* update tests
* First attempt to fix real performance regression
* Uncomment ldexp rules
* cleanup broadcast and inline
* update tests
* specify more reasonable tolerance for float32
* revert testing bug
* clean up the submission
---
 src/lib/broadcast.jl | 113 +++++++++++++++++++++++++++++++++++++------
 test/complex.jl      |   1 -
 test/cuda.jl         |  98 +++++++++++++++++++++++++++++--------
 3 files changed, 174 insertions(+), 38 deletions(-)

diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 58f7ecf99..dc7b053c1 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -120,6 +120,9 @@ end
 @adjoint broadcasted(::typeof(imag), x::Numeric) =
   imag.(x), z̄ -> (nothing, im .* real.(z̄))
 
+@adjoint broadcasted(::typeof(abs2), x::Numeric) =
+  abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
+
 @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool)
   y = b === false ? a : a .+ b
   y, Δ -> (nothing, Δ, nothing)
@@ -181,7 +184,9 @@ _dual_purefun(::Type) = false
 _dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers
 
 _dual_safearg(x::Numeric{<:Real}) = true
+_dual_safearg(x::Numeric{<:Complex}) = true
 _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
+_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true
 _dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
 _dual_safearg(x) = false
 
@@ -190,7 +195,7 @@ _dual_safearg(x) = false
   # Avoid generic broadcasting in two easy cases:
   if T == Bool
     return (f.(args...), _ -> nothing)
-  elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
+  elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
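+    # Illustration (an editor's note, not needed for the logic): with the
+    # `_dual_safearg` methods added above, a broadcast like `sin.(z)` for
+    # z::CuArray{ComplexF32} now infers T = ComplexF32 and takes this
+    # forward-mode fast path instead of the generic broadcast adjoint.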
     return broadcast_forward(f, args...)
   end
   len = inclen(args)
@@ -230,35 +235,112 @@ end
 
 # Forward Mode -- necessary for CUDA, also used as a fast path above
 
 import ForwardDiff
-using ForwardDiff: Dual
+using ForwardDiff: Dual, Partials, value, partials
+
+# Hard-coding these cases keeps `dual` type stable, so the broadcast compiles
+# nicely on the GPU; the Val{N} argument keeps the number of partials known at
+# compile time, again for type stability.
+@inline dual(x, i, ::Val{N}) where {N} = x
+@inline dual(x::Bool, i, ::Val{N}) where {N} = x
+@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), N))
+# ForwardDiff.jl doesn't play nicely with complex numbers, so for a complex
+# argument we instead construct a Complex number whose real and imaginary
+# parts are Duals, tagged separately (2N partials in total for N arguments).
+@inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N}
+  re_dual = Dual(real(x), ntuple(==(i), 2N))
+  im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
+  return Complex(re_dual, im_dual)
+end
 
-dual(x, p) = x
-dual(x::Real, p) = Dual(x, p)
-dual(x::Bool, p) = x
+function dualize(args::Vararg{Any, N}) where {N}
+  ds = map(args, ntuple(identity, N)) do x, i
+    return dual(x, i, Val(N))
+  end
+  return ds
+end
 
-function dual_function(f::F) where F
-  function (args::Vararg{Any,N}) where N
-    ds = map(args, ntuple(identity,Val(N))) do x, i
-      dual(x, ntuple(j -> i==j, Val(N)))
-    end
-    return f(ds...)
-  end
-end
+@inline function dual_function(f::F) where F
+  function (args::Vararg{Any,N}) where N
+    ds = dualize(args...)
+    return f(ds...)
+  end
+end
 
 @inline function broadcast_forward(f, args::Vararg{Any,N}) where N
-  valN = Val(N)
   out = dual_function(f).(args...)
-  eltype(out) <: Dual || return (out, _ -> nothing)
-  y = broadcast(x -> x.value, out)
+  T = eltype(out)
+  T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
+  if any(eltype(a) <: Complex for a in args)
+    _broadcast_forward_complex(T, out, args...)
+  else
+    _broadcast_forward(T, out, args...)
+  end
+end
+
+# Real input and real output pullback
+@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
+  valN = Val(N)
+  y = broadcast(x -> value(x), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
+      unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1, i), ȳ, out))
     end
     (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
   return y, bc_fwd_back
 end
 
+# Real input and complex output pullback
+@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
+  valN = Val(N)
+  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
+  function bc_fwd_back(ȳ)
+    dargs = ntuple(valN) do i
+      unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1), i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
+    end
+    (nothing, nothing, dargs...) # nothings for broadcasted & f
+  end
+  return y, bc_fwd_back
+end
+
+# Complex input and real output pullback. We use the gradient definition from
+# ChainRules here, since it agrees with what Zygote did for real(x).
+@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
+  valN = Val(N)
+  y = broadcast(x -> value(x), out)
+  function bc_fwd_back(ȳ)
+    dargs = ntuple(valN) do i
+      unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
+    end
+    (nothing, nothing, dargs...) # nothings for broadcasted & f
+  end
+  return y, bc_fwd_back
+end
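+# As an illustration of that convention (an editor's sketch, not load-bearing
+# for the code above): for a real-valued f of a complex argument, Zygote's
+# gradient is 2*∂f/∂z̄, e.g.
+#   julia> Zygote.gradient(abs2, 1.0 + 2.0im)
+#   (2.0 + 4.0im,)
+# matching the `abs2` broadcast adjoint added at the top of this file.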
+
+# Complex input and complex output.
+# If we write
+#   f(x + iy) = u(x, y) + i v(x, y)
+# then the adjoint is
+#   Δu ∂u/∂x + Δv ∂v/∂x + i (Δu ∂u/∂y + Δv ∂v/∂y),
+# following https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
+# (for holomorphic f this reduces to conj(f'(z)) * Δ, e.g. f(z) = z² gives
+# 2conj(z) * Δ).
+function _adjoint_complex(N, Δz, df, i)
+  Δu, Δv = reim(Δz)
+  du, dv = reim(df)
+  return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N))
+end
+
+@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
+  valN = Val(N)
+  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
+  function bc_fwd_back(ȳ)
+    dargs = ntuple(valN) do i
+      unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
+    end
+    (nothing, nothing, dargs...) # nothings for broadcasted & f
+  end
+  return y, bc_fwd_back
+end
+
 using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame
 
 # Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
@@ -287,4 +369,3 @@
   end
 
 pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
-
diff --git a/test/complex.jl b/test/complex.jl
index efb1e06dd..e50c57486 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -120,4 +120,3 @@ end
   end
   @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
 end
-
diff --git a/test/cuda.jl b/test/cuda.jl
index 5cb1c8cdc..171fa45db 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -2,8 +2,17 @@ using CUDA
 using Zygote: Grads
 using LinearAlgebra
 using Random: randn!
+import FiniteDifferences
 
 CUDA.allowscalar(false)
 
+function gradcheck_gpu(f, xs...)
+  grad_zygote = gradient(f, xs...)
+  m = FiniteDifferences.central_fdm(5, 1)
+  grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...)
+  return all(isapprox.(collect.(grad_zygote), grad_finite_difference))
+end
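+# Usage sketch for the helper above, mirroring the tests added below:
+# compare Zygote's GPU gradient against CPU central finite differences, e.g.
+#   gradcheck_gpu(x -> sum(abs2, x), cu(rand(Float32, 4)))
+# (the CuArrays are collected to CPU Arrays before differencing).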
+
+
 # Test GPU movement inside the call to `gradient`
 @testset "GPU movement" begin
   r = rand(Float32, 3,3)
@@ -26,7 +35,7 @@ end
   g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
   @test g_gpu isa CuArray
   @test g_gpu |> collect ≈ g
-  
+
   w(x) = sum(broadcast(log, x))
   g = gradient(x -> w(x), a)[1]
   g_gpu = gradient(x -> w(x), a_gpu)[1]
@@ -38,7 +47,7 @@ end
   @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
   g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
   @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
-  @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] 
+  @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
 
   # Projection: eltype preservation:
   @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@@ -90,40 +99,40 @@ end
 @testset "gradient algebra" begin
   w, b = rand(2) |> cu, rand(2) |> cu
   x1, x2 = rand(2) |> cu, rand(2) |> cu
-  
-  gs1 = gradient(() -> sum(w .* x1), Params([w])) 
-  gs2 = gradient(() -> sum(w .* x2), Params([w])) 
+
+  gs1 = gradient(() -> sum(w .* x1), Params([w]))
+  gs2 = gradient(() -> sum(w .* x2), Params([w]))
 
   @test .- gs1 isa Grads
-  @test gs1 .- gs2 isa Grads 
+  @test gs1 .- gs2 isa Grads
   @test .+ gs1 isa Grads
-  @test gs1 .+ gs2 isa Grads 
-  @test 2 .* gs1 isa Grads 
+  @test gs1 .+ gs2 isa Grads
+  @test 2 .* gs1 isa Grads
   @test (2 .* gs1)[w] ≈ 2 * gs1[w]
-  @test gs1 .* 2 isa Grads 
-  @test gs1 ./ 2 isa Grads 
-  @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] 
+  @test gs1 .* 2 isa Grads
+  @test gs1 ./ 2 isa Grads
+  @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]
 
   gs12 = gs1 .+ gs2
   gs1 .+= gs2
-  @test gs12[w] ≈ gs1[w] 
+  @test gs12[w] ≈ gs1[w]
 
   gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
-  gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) 
+  gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))
 
   @test .- gs3 isa Grads
-  @test gs3 .- gs4 isa Grads 
+  @test gs3 .- gs4 isa Grads
   @test .+ gs3 isa Grads
-  @test gs3 .+ gs4 isa Grads 
-  @test 2 .* gs3 isa Grads 
-  @test gs3 .* 2 isa Grads 
-  @test gs3 ./ 2 isa Grads 
+  @test gs3 .+ gs4 isa Grads
+  @test 2 .* gs3 isa Grads
+  @test gs3 .* 2 isa Grads
+  @test gs3 ./ 2 isa Grads
   @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
-  @test (gs3 .+ gs4)[b] ≈ gs4[b] 
-  
+  @test (gs3 .+ gs4)[b] ≈ gs4[b]
+
   @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
   gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
-  @test gs3 isa Grads 
+  @test gs3 isa Grads
 
   @test_throws ArgumentError gs1 .+ gs4
 end
@@ -140,3 +149,50 @@ end
   @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
 end
 
+
+@testset "CUDA complex broadcasting" begin
+  # Issues #961, #1121 and #1215
+  x = 2*rand(Float32, 10) .- 1f0
+  y = 2*rand(ComplexF32, 10) .- 1f0
+
+  xgpu = cu(x)
+  ygpu = cu(y)
+
+  g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1]
+  g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1]
+  g3 = Zygote.gradient(x->sum(abs2, x), y)[1]
+  @test g1 isa CUDA.CuArray{ComplexF32}
+  @test g2 isa CUDA.CuArray{ComplexF32}
+  @test collect(g1) ≈ collect(g2)
+  @test collect(g1) ≈ g3
+
+  r3 = cu(Float32.(inv.(2:4)))
+  c3 = cu(ComplexF32.(inv.(5:7) .+ im ./ (8:10)))
+
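+  # For reference (plain arithmetic on the definitions above): r3 collects to
+  # roughly Float32[0.5, 0.333, 0.25] and c3 to entries like 0.2f0 + 0.125f0im,
+  # small nonzero values that presumably keep the finite-difference comparisons
+  # away from branch cuts and division by zero.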
+  # These check _broadcast_forward(::Type{<:Dual}, ...)
+  @test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu)
+  @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu)
+
+  # These check _broadcast_forward_complex(::Type{<:Complex}, ...)
+  @test gradcheck_gpu((x,y)->sum(abs2, cos.(x) .+ sin.(y)), xgpu, ygpu)
+  @test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu)
+  @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu)
+  @test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3)
+
+  # Painful test!
+  @test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)
+  @test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)
+
+  # These check _broadcast_forward(::Type{<:Complex}, ...)
+  @test gradcheck_gpu(x->sum(real, cis.(x)), xgpu)
+  @test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu)
+
+  # These check _broadcast_forward_complex(::Type{<:Dual}, ...)
+  @test gradcheck_gpu(x->sum(imag, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu)
+
+end