Adding complex broadcasting for gradients on the GPU (#1324)
* Added complex broadcasting support

* Added tests and clean up the code

* Fix up type instability

* Add testing

* Everything passes tests now

* switch to more generic broadcast_forward

* clean up submission

* Remove various Val's

* change to Complex{<:Dual}

* add mixed complex and real to cuda testing

* import not using

* Add complex to _dual_safearg

* Type stable on my computer

* Fix Dual tagging

* Add more tests

* update tests

* First attempt to fix real performance regression

* Uncomment ldexp rules

* cleanup broadcast and inline

* update tests

* specify more reasonable tolerance for float32

* revert testing bug

* clean up the submission
ptiede authored Jan 10, 2023
1 parent f3857d1 commit 616bf6c
Showing 3 changed files with 174 additions and 38 deletions.
113 changes: 97 additions & 16 deletions src/lib/broadcast.jl
@@ -120,6 +120,9 @@ end
@adjoint broadcasted(::typeof(imag), x::Numeric) =
  imag.(x), z̄ -> (nothing, im .* real.(z̄))

@adjoint broadcasted(::typeof(abs2), x::Numeric) =
  abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
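
For reference, a small CPU sketch (not part of the diff) of what this abs2 rule produces under Zygote's ∂/∂Re + i·∂/∂Im convention for real-valued outputs:

using Zygote
z = 3.0 + 4.0im
g = gradient(v -> sum(abs2.(v)), [z])[1]
# the rule above returns 2 .* real.(z̄) .* x, so with z̄ = 1 this is 2z
@assert g ≈ [2z]   # 2x + 2im*y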

@adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool)
  y = b === false ? a : a .+ b
  y, Δ -> (nothing, Δ, nothing)
@@ -181,7 +184,9 @@ _dual_purefun(::Type) = false
_dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers

_dual_safearg(x::Numeric{<:Real}) = true
_dual_safearg(x::Numeric{<:Complex}) = true
_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
_dual_safearg(x) = false

@@ -190,7 +195,7 @@ _dual_safearg(x) = false
  # Avoid generic broadcasting in two easy cases:
  if T == Bool
    return (f.(args...), _ -> nothing)
-  elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
+  elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
    return broadcast_forward(f, args...)
  end
len = inclen(args)
@@ -230,35 +235,112 @@ end
# Forward Mode -- necessary for CUDA, also used as a fast path above

import ForwardDiff
-using ForwardDiff: Dual
+using ForwardDiff: Dual, Partials, value, partials


# We do this because it ensures type stability, so it compiles nicely on the GPU.
# The Val is needed for type stability.
@inline dual(x, i, ::Val{N}) where {N} = x
@inline dual(x::Bool, i, ::Val{N}) where {N} = x
@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), N))
# ForwardDiff.jl doesn't play nicely with complex numbers, so we construct a
# Complex dual number and tag the real and imaginary parts separately.
@inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N}
  re_dual = Dual(real(x), ntuple(==(i), 2N))
  im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
  return Complex(re_dual, im_dual)
end
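
For illustration (not part of the change itself), with a single broadcast argument (N = 1) the constructor above tags ∂/∂Re in slot i and ∂/∂Im in slot N + i of a length-2N partials tuple:

using ForwardDiff: Dual, value, partials   # `dual` here is the internal helper defined above
d = dual(1.0 + 2.0im, 1, Val(1))
value(real(d)), partials(real(d))   # 1.0 with partials (1.0, 0.0): the ∂/∂Re slot
value(imag(d)), partials(imag(d))   # 2.0 with partials (0.0, 1.0): the ∂/∂Im slot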

-dual(x, p) = x
-dual(x::Real, p) = Dual(x, p)
-dual(x::Bool, p) = x
+function dualize(args::Vararg{Any, N}) where {N}
+  ds = map(args, ntuple(identity, N)) do x, i
+    return dual(x, i, Val(N))
+  end
+  return ds
+end

-function dual_function(f::F) where F
-  function (args::Vararg{Any,N}) where N
-    ds = map(args, ntuple(identity,Val(N))) do x, i
-      dual(x, ntuple(j -> i==j, Val(N)))
-    end
-    return f(ds...)
-  end
-end
+@inline function dual_function(f::F) where F
+  function (args::Vararg{Any,N}) where N
+    ds = dualize(args...)
+    return f(ds...)
+  end
+end


@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
-  valN = Val(N)
  out = dual_function(f).(args...)
-  eltype(out) <: Dual || return (out, _ -> nothing)
-  y = broadcast(x -> x.value, out)
+  T = eltype(out)
+  T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
+  if any(eltype(a) <: Complex for a in args)
+    _broadcast_forward_complex(T, out, args...)
+  else
+    _broadcast_forward(T, out, args...)
+  end
end

# Real input and real output pullback
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
  valN = Val(N)
  y = broadcast(x -> value(x), out)
  function bc_fwd_back(ȳ)
    dargs = ntuple(valN) do i
-     unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
+     unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
    end
    (nothing, nothing, dargs...) # nothings for broadcasted & f
  end
  return y, bc_fwd_back
end

# This handles the complex output and real input pullback
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
  valN = Val(N)
  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
  function bc_fwd_back(ȳ)
    dargs = ntuple(valN) do i
      unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
    end
    (nothing, nothing, dargs...) # nothings for broadcasted & f
  end
  return y, bc_fwd_back
end
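
A minimal sketch (editorial, with hypothetical values) of the rule this pullback applies, ∂L/∂x = Re(ȳ)·∂u/∂x + Im(ȳ)·∂v/∂x, written out for f = cis where u = cos and v = sin:

using ForwardDiff: Dual, partials
x = 0.3
ȳ = 1.0 + 2.0im                                           # an arbitrary complex cotangent
o = Complex(Dual(cos(x), -sin(x)), Dual(sin(x), cos(x)))  # the dual output the forward pass would give
dx = real(ȳ)*partials(real(o), 1) + imag(ȳ)*partials(imag(o), 1)
# dx == real(ȳ)*(-sin(x)) + imag(ȳ)*cos(x): the real and imaginary parts of ȳ pick up ∂u/∂x and ∂v/∂x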

# This handles complex input and real output. We use the gradient definition from ChainRules here
# since it agrees with what Zygote did for real(x).
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
  valN = Val(N)
  y = broadcast(x -> value(x), out)
  function bc_fwd_back(ȳ)
    dargs = ntuple(valN) do i
      unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
    end
    (nothing, nothing, dargs...) # nothings for broadcasted & f
  end
  return y, bc_fwd_back
end
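
A sketch (again editorial) of this complex-input/real-output rule for f = abs2 with one argument (N = 1): the pullback reads ∂u/∂x from slot i and ∂u/∂y from slot i + N, returning ȳ·(∂u/∂x + i·∂u/∂y), which reproduces the 2z gradient of the abs2 adjoint further up:

using ForwardDiff: Dual, partials
o = abs2(Complex(Dual(1.0, 1.0, 0.0), Dual(2.0, 0.0, 1.0)))  # forward pass on the tagged dual for z = 1 + 2im
ȳ = 1.0
g = ȳ * Complex(partials(o, 1), partials(o, 2))              # 2.0 + 4.0im, i.e. 2z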

# This is for complex input and complex output.
# If we assume that
#   f(x + iy) = u(x,y) + iv(x,y)
# then the adjoint we return is
#   Δu ∂u/∂x + Δv ∂v/∂x + i(Δu ∂u/∂y + Δv ∂v/∂y)
# This follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
function _adjoint_complex(N, Δz, df, i)
  Δu, Δv = reim(Δz)
  du, dv = reim(df)
  return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N))
end
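
A quick check (editorial sketch, assuming the `dual` and `_adjoint_complex` helpers above are in scope): for a holomorphic map such as f(z) = z², the Cauchy–Riemann equations reduce the expression above to conj(f′(z))·Δ, the convention described in the ChainRules docs linked above:

z  = 0.7 + 0.2im
d  = dual(z, 1, Val(1))   # Complex dual with separately tagged real/imaginary parts
df = d^2                  # forward pass through f(z) = z^2
Δ  = 0.3 - 0.5im          # an arbitrary cotangent
@assert _adjoint_complex(1, Δ, df, 1) ≈ conj(2z) * Δ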

@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
  valN = Val(N)
  y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
  function bc_fwd_back(ȳ)
    dargs = ntuple(valN) do i
      unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
    end
    (nothing, nothing, dargs...) # nothings for broadcasted & f
  end
  return y, bc_fwd_back
end

using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame

# Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
@@ -287,4 +369,3 @@ end
end

pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]

1 change: 0 additions & 1 deletion test/complex.jl
@@ -120,4 +120,3 @@ end
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

98 changes: 77 additions & 21 deletions test/cuda.jl
@@ -2,8 +2,17 @@ using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
import FiniteDifferences
CUDA.allowscalar(false)

function gradcheck_gpu(f, xs...)
grad_zygote = gradient(f, xs...)
m = FiniteDifferences.central_fdm(5,1)
grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...)
return all(isapprox.(collect.(grad_zygote), grad_finite_difference))
end


# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
@@ -26,7 +35,7 @@ end
g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

w(x) = sum(broadcast(log, x))
g = gradient(x -> w(x), a)[1]
g_gpu = gradient(x -> w(x), a_gpu)[1]
@@ -38,7 +47,7 @@ end
@test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@@ -90,40 +99,40 @@ end
@testset "gradient algebra" begin
w, b = rand(2) |> cu, rand(2) |> cu
x1, x2 = rand(2) |> cu, rand(2) |> cu

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test (2 .* gs1)[w] ≈ 2 * gs1[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

gs12 = gs1 .+ gs2
gs1 .+= gs2
@test gs12[w] ≈ gs1[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]

@test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
end
@@ -140,3 +149,50 @@ end
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end


@testset "CUDA complex broadcasting" begin
# Issue 961 and 1121 and 1215
x = 2*rand(Float32, 10) .- 1f0
y = 2*rand(ComplexF32, 10) .- 1f0

xgpu = cu(x)
ygpu = cu(y)

g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1]
g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1]
g3 = Zygote.gradient(x->sum(abs2, x), y)[1]
@test g1 isa CUDA.CuArray{ComplexF32}
@test g2 isa CUDA.CuArray{ComplexF32}
@test collect(g1) ≈ collect(g2)
@test collect(g1) ≈ g3

r3 = cu(Float32.(inv.(2:4)))
c3 = cu(ComplexF32.(inv.(5:7) .+ im ./ (8:10)))

# These check _broadcast_forward(::Type{<:Dual}, ...)
@test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu)

# These check _broadcast_forward_complex(::Type{<:Complex}, ...)
@test gradcheck_gpu((x,y)->sum(abs2, cos.(x) .+ sin.(y)), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu)
@test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3)

# Painful test!
@test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)
@test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)

# These check _broadcast_forward(::Type{<:Complex}, ...)
@test gradcheck_gpu(x->sum(real, cis.(x)), xgpu)
@test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu)

# These check _broadcast_forward_complex(::Type{<:Dual}, ...)
@test gradcheck_gpu(x->sum(imag, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu)
end
