Adding complex broadcasting for gradients on the GPU #1324

Merged: 25 commits, Jan 10, 2023
113 changes: 97 additions & 16 deletions src/lib/broadcast.jl
@@ -120,6 +120,9 @@ end
@adjoint broadcasted(::typeof(imag), x::Numeric) =
imag.(x), z̄ -> (nothing, im .* real.(z̄))

@adjoint broadcasted(::typeof(abs2), x::Numeric) =
abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
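# Illustration (not part of the diff): with an incoming cotangent z̄ = 1 the rule above
# returns 2 .* x, so for a complex vector the expected behaviour is
#   Zygote.gradient(x -> sum(abs2.(x)), [1.0 + 2.0im])   # expected ([2.0 + 4.0im],)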

@adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool)
y = b === false ? a : a .+ b
y, Δ -> (nothing, Δ, nothing)
@@ -181,7 +184,9 @@ _dual_purefun(::Type) = false
_dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers

_dual_safearg(x::Numeric{<:Real}) = true
_dual_safearg(x::Numeric{<:Complex}) = true
_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
_dual_safearg(x) = false

@@ -190,7 +195,7 @@ _dual_safearg(x) = false
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
len = inclen(args)
@@ -230,35 +235,112 @@ end
# Forward Mode -- necessary for CUDA, also used as a fast path above

import ForwardDiff
using ForwardDiff: Dual
using ForwardDiff: Dual, Partials, value, partials


# We do this because it ensures type stability, so it compiles nicely on the GPU.
# The Val is needed for type stability.
@inline dual(x, i, ::Val{N}) where {N} = x
@inline dual(x::Bool, i, ::Val{N}) where {N} = x
@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), N))
# Since ForwardDiff.jl doesn't play nicely with complex numbers, for complex inputs we
# construct a Complex dual number and tag the real and imaginary parts separately.
@inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N}
re_dual = Dual(real(x), ntuple(==(i), 2N))
im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
return Complex(re_dual, im_dual)
end
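# Illustration (not part of the diff): with N broadcast arguments each complex input gets a
# 2N-wide partials vector, with the real part seeded in slot i and the imaginary part in
# slot N + i. For N = 2 and i = 1:
#   z = dual(3.0 + 4.0im, 1, Val(2))
#   # real(z) carries Dual(3.0, (1, 0, 0, 0)); imag(z) carries Dual(4.0, (0, 0, 1, 0))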

dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
dual(x::Bool, p) = x
function dualize(args::Vararg{Any, N}) where {N}
ds = map(args, ntuple(identity,N)) do x, i
return dual(x, i, Val(N))
end
return ds
end
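# Illustration (not part of the diff): dualize seeds each argument against its own slot, e.g.
#   dualize(2.0, 3.0)   # (Dual(2.0, (1, 0)), Dual(3.0, (0, 1)))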

function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
ds = map(args, ntuple(identity,Val(N))) do x, i
dual(x, ntuple(j -> i==j, Val(N)))
@inline function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
ds = dualize(args...)
return f(ds...)
end
return f(ds...)
end
end
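# Illustration (not part of the diff): the wrapped function evaluates f on the seeded duals,
# so each input's derivative can be read back with partials, e.g.
#   d = dual_function(*)(2.0, 3.0)
#   value(d)          # 6.0
#   partials(d, 1)    # 3.0 == ∂(a*b)/∂a at (2, 3)
#   partials(d, 2)    # 2.0 == ∂(a*b)/∂b at (2, 3)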


@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
valN = Val(N)
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
y = broadcast(x -> x.value, out)
T = eltype(out)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
if any(eltype(a) <: Complex for a in args)
_broadcast_forward_complex(T, out, args...)
else
_broadcast_forward(T, out, args...)
end
end
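# Illustration (not part of the diff): the four pullbacks below are picked by the dual output
# eltype T and by whether any input is complex, e.g.
#   abs2.(x), x real     -> _broadcast_forward(::Type{<:Dual}, ...)
#   cis.(x),  x real     -> _broadcast_forward(::Type{<:Complex}, ...)
#   abs2.(z), z complex  -> _broadcast_forward_complex(::Type{<:Dual}, ...)
#   sin.(z),  z complex  -> _broadcast_forward_complex(::Type{<:Complex}, ...)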

# Real input and real output pullback
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end
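# Illustration (not part of the diff): for out = exp.(x) with real x, each output carries
# value exp(x[i]) and partial exp(x[i]), so the pullback reduces to ȳ .* exp.(x).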

# This handles the complex output and real input pullback
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end
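# Illustration (not part of the diff): writing the output as u .+ im .* v, the broadcast above
# accumulates real(ȳ) .* ∂u/∂x .+ imag(ȳ) .* ∂v/∂x. For y = cis.(x) with real x that is
# real(ȳ) .* (-sin.(x)) .+ imag(ȳ) .* cos.(x).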

# This handles complex input and real output. We use the gradient definition from ChainRules here
# since it agrees with what Zygote did for real(x).
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end
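# Illustration (not part of the diff): for z = x .+ im .* y and a real output f, the broadcast
# above returns ȳ .* Complex.(∂f/∂x, ∂f/∂y). For f = abs2 and a real cotangent ȳ this is
# 2 .* ȳ .* z, agreeing with the abs2 adjoint added at the top of this file.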

# This is for complex input and complex output.
# If we write
#   f(x + iy) = u(x, y) + i v(x, y)
# then the adjoint we accumulate is
#   Δu ∂u/∂x + Δv ∂v/∂x + i (Δu ∂u/∂y + Δv ∂v/∂y),
# following https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html
function _adjoint_complex(N, Δz, df, i)
Δu, Δv = reim(Δz)
du, dv = reim(df)
return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N))
end

@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end
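# Illustration (not part of the diff): a quick consistency check with f = conj, i.e.
# u(x, y) = x and v(x, y) = -y. _adjoint_complex then yields
#   Complex(Δu*1 + Δv*0, Δu*0 + Δv*(-1)) = Δu - im*Δv = conj(Δz),
# the familiar pullback of conj.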

using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame

# Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
@@ -287,4 +369,3 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
end

pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]

1 change: 0 additions & 1 deletion test/complex.jl
@@ -120,4 +120,3 @@ end
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

98 changes: 77 additions & 21 deletions test/cuda.jl
@@ -2,8 +2,17 @@ using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
import FiniteDifferences
CUDA.allowscalar(false)

function gradcheck_gpu(f, xs...)
grad_zygote = gradient(f, xs...)
m = FiniteDifferences.central_fdm(5,1)
grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...)
return all(isapprox.(collect.(grad_zygote), grad_finite_difference))
end
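# Sketch (not part of the test file): gradcheck_gpu compares Zygote's GPU gradient with a
# 5-point central finite-difference estimate taken on CPU copies of the inputs, e.g.
#   gradcheck_gpu(x -> sum(abs2, x), cu(rand(Float32, 4)))   # expected to return true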


# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
@@ -26,7 +35,7 @@ end
g_gpu = gradient(x -> v(x, 7), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

w(x) = sum(broadcast(log, x))
g = gradient(x -> w(x), a)[1]
g_gpu = gradient(x -> w(x), a_gpu)[1]
@@ -38,7 +47,7 @@ end
@test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@@ -90,40 +99,40 @@ end
@testset "gradient algebra" begin
w, b = rand(2) |> cu, rand(2) |> cu
x1, x2 = rand(2) |> cu, rand(2) |> cu
gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test (2 .* gs1)[w] ≈ 2 * gs1[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

gs12 = gs1 .+ gs2
gs1 .+= gs2
@test gs12[w] ≈ gs1[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]

@test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
end
@@ -140,3 +149,50 @@
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end


@testset "CUDA complex broadcasting" begin
# Issues 961, 1121 and 1215
x = 2*rand(Float32, 10) .- 1f0
y = 2*rand(ComplexF32, 10) .- 1f0

xgpu = cu(x)
ygpu = cu(y)

g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1]
g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1]
g3 = Zygote.gradient(x->sum(abs2, x), y)[1]
@test g1 isa CUDA.CuArray{ComplexF32}
@test g2 isa CUDA.CuArray{ComplexF32}
@test collect(g1) ≈ collect(g2)
@test collect(g1) ≈ g3
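# Note (not part of the test file): per the abs2 adjoint the expected gradient of sum(abs2, y)
# is 2 .* y, so g1 and g2 (GPU) and g3 (CPU) should agree up to Float32 precision.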


r3 = cu(Float32.(inv.(2:4)))
c3 = cu(ComplexF32.(inv.(5:7) .+ im ./ (8:10)))


# These check _broadcast_forward(::Type{<:Dual}, ...)
@test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu)

# These check _broadcast_forward_complex(::Type{<:Complex}, ...)
@test gradcheck_gpu((x,y)->sum(abs2, cos.(x) .+ sin.(y)), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu)
@test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu)
@test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3)

# Painful test!
@test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)
@test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)


# These check _broadcast_forward(::Type{<:Complex}, ...)
@test gradcheck_gpu(x->sum(real, cis.(x)), xgpu)
@test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu)

# These check _broadcast_forward_complex(::Type{<:Dual}, ...)
@test gradcheck_gpu(x->sum(imag, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu)


end