From 807d6898c8317733a849b9e5a5c2e5c1b6a921a2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 23 Oct 2022 20:46:57 -0400 Subject: [PATCH 01/23] Added complex broadcasting support --- Project.toml | 4 +- src/lib/broadcast.jl | 104 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 7d277f688..209da2578 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -33,7 +33,7 @@ ChainRulesTestUtils = "1" DiffRules = "1.4" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13" ForwardDiff = "0.10" -GPUArrays = "8.4.2" # not loaded, just a version bound +GPUArrays = "8.4.2" GPUArraysCore = "0.1.1" IRTools = "0.4.4" LogExpFunctions = "0.3.1" diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 58f7ecf99..4c81a2dac 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,6 +120,9 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) +# @adjoint broadcasted(::typeof(abs2), x::Numeric) = +# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) + @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b y, Δ -> (nothing, Δ, nothing) @@ -190,7 +193,7 @@ _dual_safearg(x) = false # Avoid generic broadcasting in two easy cases: if T == Bool return (f.(args...), _ -> nothing) - elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() + elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() return broadcast_forward(f, args...) end len = inclen(args) @@ -232,23 +235,64 @@ end import ForwardDiff using ForwardDiff: Dual -dual(x, p) = x -dual(x::Real, p) = Dual(x, p) -dual(x::Bool, p) = x +dual(x, p, pc=()) = x +dual(x::Real, p, pc=()) = Dual(x, p) +dual(x::Bool, p, pc=()) = x +dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc)) function dual_function(f::F) where F function (args::Vararg{Any,N}) where N - ds = map(args, ntuple(identity,Val(N))) do x, i - dual(x, ntuple(j -> i==j, Val(N))) + if any(a isa Complex for a in args) + ds = map(args, ntuple(identity, Val(N))) do x, i + dual(x, ntuple(j -> i==j, Val(2N)), ntuple(j -> N+i==j, Val(2N))) + end + return f(ds...) + else + ds = map(args, ntuple(identity,Val(N))) do x, i + dual(x, ntuple(j -> i==j, Val(N))) + end + return f(ds...) end - return f(ds...) end end +# function dual_function(f::F) where F +# function (args::Vararg{Any,N}) where N +# ds = map(args, ntuple(identity,Val(N))) do x, i +# dual(x, ntuple(j -> i==j, Val(N))) +# end +# return f(ds...) +# end +# end + +# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N +# valN = Val(N) +# out = dual_function(f).(args...) 
+# eltype(out) <: Dual || return (out, _ -> nothing) +# y = broadcast(x -> x.value, out) +# function bc_fwd_back(ȳ) +# dargs = ntuple(valN) do i +# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) +# end +# (nothing, nothing, dargs...) # nothings for broadcasted & f +# end +# return y, bc_fwd_back +# end + + @inline function broadcast_forward(f, args::Vararg{Any,N}) where N - valN = Val(N) out = dual_function(f).(args...) - eltype(out) <: Dual || return (out, _ -> nothing) + eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) + if any(eltype(a) <: Complex for a in args) + _broadcast_forward_complex(out, args...) + else + _broadcast_forward(out, args...) + end +end + +# Real input and real output +function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} + valN = Val(N) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i @@ -259,6 +303,47 @@ end return y, bc_fwd_back end +# This handles complex output and real input +function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back + end + +# This handles complex input and real output +function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> x.value, out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N]), ȳ, out)) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back +end + +# This is for complex input and complex output +# I am a little confused what derivative we want to use here so it hasn't been implemented +function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} + throw("Complex output and input not supported in Zygote broadcast_forward") + # valN = Val(N) + # y = broadcast(x -> Complex(x.re.value, x.im.value), out) + # function bc_fwd_back(ȳ) + # dargs = ntuple(valN) do i + # unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N-1]), ȳ, out)) + # end + # (nothing, nothing, dargs...) 
# nothings for broadcasted & f + # end + # return y, bc_fwd_back +end + using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe, @@ -287,4 +372,3 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve end pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz] - From 2972fafacc79e11a26c5bf36be6caf1deb39ffc2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 24 Oct 2022 17:08:50 -0400 Subject: [PATCH 02/23] Added tests and clean up the code --- src/lib/broadcast.jl | 60 +++++++++++------------------------------ test/complex.jl | 1 - test/cuda.jl | 64 +++++++++++++++++++++++++++++--------------- 3 files changed, 58 insertions(+), 67 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 4c81a2dac..5e5ed38ed 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,8 +120,8 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) -# @adjoint broadcasted(::typeof(abs2), x::Numeric) = -# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) +@adjoint broadcasted(::typeof(abs2), x::Numeric) = + abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b @@ -235,11 +235,13 @@ end import ForwardDiff using ForwardDiff: Dual +# Updated to use proposal from 961 dual(x, p, pc=()) = x dual(x::Real, p, pc=()) = Dual(x, p) dual(x::Bool, p, pc=()) = x dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc)) +# Updated to use proposal from 961 function dual_function(f::F) where F function (args::Vararg{Any,N}) where N if any(a isa Complex for a in args) @@ -256,30 +258,6 @@ function dual_function(f::F) where F end end -# function dual_function(f::F) where F -# function (args::Vararg{Any,N}) where N -# ds = map(args, ntuple(identity,Val(N))) do x, i -# dual(x, ntuple(j -> i==j, Val(N))) -# end -# return f(ds...) -# end -# end - -# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N -# valN = Val(N) -# out = dual_function(f).(args...) -# eltype(out) <: Dual || return (out, _ -> nothing) -# y = broadcast(x -> x.value, out) -# function bc_fwd_back(ȳ) -# dargs = ntuple(valN) do i -# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) -# end -# (nothing, nothing, dargs...) # nothings for broadcasted & f -# end -# return y, bc_fwd_back -# end - - @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) @@ -303,18 +281,19 @@ function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) wh return y, bc_fwd_back end -# This handles complex output and real input -function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) - end - (nothing, nothing, dargs...) 
# nothings for broadcasted & f +# This handles complex output and real input and uses the definition from +# ChainRules.jl's section on complex numbers +function _broadcast_forward(out::AbstractArray{<:Complex{<:Dual}}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) end - return y, bc_fwd_back + (nothing, nothing, dargs...) # nothings for broadcasted & f end + return y, bc_fwd_back +end # This handles complex input and real output function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} @@ -333,15 +312,6 @@ end # I am a little confused what derivative we want to use here so it hasn't been implemented function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} throw("Complex output and input not supported in Zygote broadcast_forward") - # valN = Val(N) - # y = broadcast(x -> Complex(x.re.value, x.im.value), out) - # function bc_fwd_back(ȳ) - # dargs = ntuple(valN) do i - # unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N-1]), ȳ, out)) - # end - # (nothing, nothing, dargs...) # nothings for broadcasted & f - # end - # return y, bc_fwd_back end using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame diff --git a/test/complex.jl b/test/complex.jl index efb1e06dd..e50c57486 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -120,4 +120,3 @@ end end @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0] end - diff --git a/test/cuda.jl b/test/cuda.jl index 5cb1c8cdc..113def69e 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -26,7 +26,7 @@ end g_gpu = gradient(x -> v(x, 7), a_gpu)[1] @test g_gpu isa CuArray @test g_gpu |> collect ≈ g - + w(x) = sum(broadcast(log, x)) g = gradient(x -> w(x), a)[1] g_gpu = gradient(x -> w(x), a_gpu)[1] @@ -38,7 +38,7 @@ end @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,) g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018 - @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] + @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] # Projection: eltype preservation: @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32} @@ -90,40 +90,40 @@ end @testset "gradient algebra" begin w, b = rand(2) |> cu, rand(2) |> cu x1, x2 = rand(2) |> cu, rand(2) |> cu - - gs1 = gradient(() -> sum(w .* x1), Params([w])) - gs2 = gradient(() -> sum(w .* x2), Params([w])) + + gs1 = gradient(() -> sum(w .* x1), Params([w])) + gs2 = gradient(() -> sum(w .* x2), Params([w])) @test .- gs1 isa Grads - @test gs1 .- gs2 isa Grads + @test gs1 .- gs2 isa Grads @test .+ gs1 isa Grads - @test gs1 .+ gs2 isa Grads - @test 2 .* gs1 isa Grads + @test gs1 .+ gs2 isa Grads + @test 2 .* gs1 isa Grads @test (2 .* gs1)[w] ≈ 2 * gs1[w] - @test gs1 .* 2 isa Grads - @test gs1 ./ 2 isa Grads - @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] + @test gs1 .* 2 isa Grads + @test gs1 ./ 2 isa Grads + @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] gs12 = gs1 .+ 
gs2 gs1 .+= gs2 - @test gs12[w] ≈ gs1[w] + @test gs12[w] ≈ gs1[w] gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b - gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) + gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) @test .- gs3 isa Grads - @test gs3 .- gs4 isa Grads + @test gs3 .- gs4 isa Grads @test .+ gs3 isa Grads - @test gs3 .+ gs4 isa Grads - @test 2 .* gs3 isa Grads - @test gs3 .* 2 isa Grads - @test gs3 ./ 2 isa Grads + @test gs3 .+ gs4 isa Grads + @test 2 .* gs3 isa Grads + @test gs3 .* 2 isa Grads + @test gs3 ./ 2 isa Grads @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w] - @test (gs3 .+ gs4)[b] ≈ gs4[b] - + @test (gs3 .+ gs4)[b] ≈ gs4[b] + @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3)) - @test gs3 isa Grads + @test gs3 isa Grads @test_throws ArgumentError gs1 .+ gs4 end @@ -140,3 +140,25 @@ end @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end + +@testset "CUDA complex broadcasting" begin + # Issue #995 test + x = rand(Float32, 50) + y = complex(rand(Float32, 50)) + + xgpu = cu(x) + ygpu = cu(y) + + f995(A) = norm(@. A*xgpu*ygpu) + g1 = Zygote.gradient(f995, 1f0) + gradcheck(f995, 1f0) + + # Issue 961 and 1121 and 1215 + g1 = Zygote.gradient(x->sum(abs2, x), ygpu) + g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu) + g3 = Zygote.graient(x->sum(abs2, x), y) + @test g1 isa CUDA.CuArray{Float32} + @test g2 isa CUDA.CuArray{Float32} + @test g1 ≈ g2 + @test g1 ≈ g3 +end From 51dc88265ba880caf4519e2c3fc8107b6a33dae7 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 24 Oct 2022 23:37:36 -0400 Subject: [PATCH 03/23] Fix up type instability --- src/lib/broadcast.jl | 96 ++++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 38 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 5e5ed38ed..d2186d280 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,8 +120,8 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) -@adjoint broadcasted(::typeof(abs2), x::Numeric) = - abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) +# @adjoint broadcasted(::typeof(abs2), x::Numeric) = +# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? 
a : a .+ b @@ -235,37 +235,57 @@ end import ForwardDiff using ForwardDiff: Dual -# Updated to use proposal from 961 -dual(x, p, pc=()) = x -dual(x::Real, p, pc=()) = Dual(x, p) -dual(x::Bool, p, pc=()) = x -dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc)) -# Updated to use proposal from 961 +# We do this because it ensures type stability so it compiles nicely on the gpu +dual(x, i, N) = x +dual(x::Bool, i, ::Val{N}) where {N} = x +dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(j-> i==j, Val(N))) +# For complex since ForwardDiff.jl doesn't play nicely with complex numbers we +# construct a Complex dual number and tag the real and imaginary parts separately +function dual(x::Complex, i, ::Val{N}) where {N} + re_dual = Dual(real(x), ntuple(j->i==j, Val(2N))) + im_dual = Dual(imag(x), ntuple(j->(N+i)==j, Val(2N))) + return Complex(re_dual, im_dual) +end + function dual_function(f::F) where F - function (args::Vararg{Any,N}) where N - if any(a isa Complex for a in args) - ds = map(args, ntuple(identity, Val(N))) do x, i - dual(x, ntuple(j -> i==j, Val(2N)), ntuple(j -> N+i==j, Val(2N))) - end - return f(ds...) - else - ds = map(args, ntuple(identity,Val(N))) do x, i - dual(x, ntuple(j -> i==j, Val(N))) - end - return f(ds...) + function (args::Vararg{Any,N}) where N + ds = map(args, ntuple(identity,Val(N))) do x, i + tmp = dual(x, i, Val(N)) + return tmp + end + return f(ds...) end end -end + +# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N +# valN = Val(N) +# out = dual_function(f).(args...) +# eltype(out) <: Dual || return (out, _ -> nothing) +# y = broadcast(x -> x.value, out) +# function bc_fwd_back(ȳ) +# dargs = ntuple(valN) do i +# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) +# end +# (nothing, nothing, dargs...) # nothings for broadcasted & f +# end +# return y, bc_fwd_back +# end + @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) - if any(eltype(a) <: Complex for a in args) - _broadcast_forward_complex(out, args...) - else + ifelse( + any(eltype(a) isa Complex for a in args), + _broadcast_forward_complex(out, args...), _broadcast_forward(out, args...) - end + ) +# if any(eltype(a) <: Complex for a in args) +# _broadcast_forward_complex(out, args...) +# else +# _broadcast_forward(out, args...) +# end end # Real input and real output @@ -281,21 +301,21 @@ function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) wh return y, bc_fwd_back end -# This handles complex output and real input and uses the definition from -# ChainRules.jl's section on complex numbers -function _broadcast_forward(out::AbstractArray{<:Complex{<:Dual}}, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) +# This handles complex output and real input +function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) + end + (nothing, nothing, dargs...) 
# nothings for broadcasted & f end - (nothing, nothing, dargs...) # nothings for broadcasted & f + return y, bc_fwd_back end - return y, bc_fwd_back -end -# This handles complex input and real output +# This handles complex input and real output we use the gradient definition from ChainRules here +# since it agrees with what Zygote did for real(x). function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> x.value, out) @@ -308,8 +328,8 @@ function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any return y, bc_fwd_back end -# This is for complex input and complex output -# I am a little confused what derivative we want to use here so it hasn't been implemented +# # # This is for complex input and complex output +# # # I am a little confused what derivative we want to use here so it hasn't been implemented function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} throw("Complex output and input not supported in Zygote broadcast_forward") end From 0635ba469eea04f2ef1e3087c4b1c0f95a51d170 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 24 Oct 2022 23:37:54 -0400 Subject: [PATCH 04/23] Add testing --- test/cuda.jl | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 113def69e..7c4456e99 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -142,23 +142,19 @@ end @testset "CUDA complex broadcasting" begin - # Issue #995 test + # Issue 961 and 1121 and 1215 x = rand(Float32, 50) y = complex(rand(Float32, 50)) xgpu = cu(x) ygpu = cu(y) - f995(A) = norm(@. A*xgpu*ygpu) - g1 = Zygote.gradient(f995, 1f0) - gradcheck(f995, 1f0) - # Issue 961 and 1121 and 1215 - g1 = Zygote.gradient(x->sum(abs2, x), ygpu) - g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu) - g3 = Zygote.graient(x->sum(abs2, x), y) - @test g1 isa CUDA.CuArray{Float32} - @test g2 isa CUDA.CuArray{Float32} - @test g1 ≈ g2 - @test g1 ≈ g3 + g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1] + g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1] + g3 = Zygote.gradient(x->sum(abs2, x), y)[1] + @test g1 isa CUDA.CuArray{ComplexF32} + @test g2 isa CUDA.CuArray{ComplexF32} + @test collect(g1) ≈ collect(g2) + @test collect(g1) ≈ g3 end From 739e8969368f741a4873e378d2a98f83d7f1a638 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 25 Oct 2022 00:44:00 -0400 Subject: [PATCH 05/23] Everything passes tests now --- src/lib/broadcast.jl | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index d2186d280..9fdc3ad6f 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,8 +120,8 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) -# @adjoint broadcasted(::typeof(abs2), x::Numeric) = -# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) +@adjoint broadcasted(::typeof(abs2), x::Numeric) = + abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b @@ -276,16 +276,11 @@ function dual_function(f::F) where F @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) 
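   # `out` is an array of Dual (real output) or Complex{<:Dual} (complex output);
   # for a complex argument the partials are laid out as 1..N for ∂/∂x and N+1..2N for ∂/∂y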
eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) - ifelse( - any(eltype(a) isa Complex for a in args), - _broadcast_forward_complex(out, args...), + if any(eltype(a) <: Complex for a in args) + _broadcast_forward_complex(out, args...) + else _broadcast_forward(out, args...) - ) -# if any(eltype(a) <: Complex for a in args) -# _broadcast_forward_complex(out, args...) -# else -# _broadcast_forward(out, args...) -# end + end end # Real input and real output @@ -329,9 +324,29 @@ function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any end # # # This is for complex input and complex output -# # # I am a little confused what derivative we want to use here so it hasn't been implemented +# # # I am a little confused what derivative we want to use here but this should match +# what is done for all the tests + +# If we assume that +# f(x + iy) = u(x,y) + iv(x,y) +# them we do the following for the adjoint +# Δu ∂/∂xu + Δv∂/∂x v + i(Δu∂/∂yu + Δv ∂/∂y v) +function _adjoint_complex(Δz, df, i) + Δu, Δv = reim(Δz) + du, dv = reim(df) + return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N]) +end + function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} - throw("Complex output and input not supported in Zygote broadcast_forward") + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(y1, o1, i), ȳ, out)) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back end using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame From a0e21e69e8a221659ebab8a1c4c4404b4345c555 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 25 Oct 2022 09:54:04 -0400 Subject: [PATCH 06/23] switch to more generic broadcast_forward --- src/lib/broadcast.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 9fdc3ad6f..463f4322e 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -275,16 +275,17 @@ function dual_function(f::F) where F @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) - eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) + T = eltype(out) + T <: Union{Dual, Complex} || return (out, _ -> nothing) if any(eltype(a) <: Complex for a in args) - _broadcast_forward_complex(out, args...) + _broadcast_forward_complex(T, out, args...) else - _broadcast_forward(out, args...) + _broadcast_forward(T, out, args...) 
end end # Real input and real output -function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} +function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) @@ -297,7 +298,7 @@ function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) wh end # This handles complex output and real input -function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} +function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) function bc_fwd_back(ȳ) @@ -311,7 +312,7 @@ function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) # This handles complex input and real output we use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). -function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} +function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) @@ -329,15 +330,16 @@ end # If we assume that # f(x + iy) = u(x,y) + iv(x,y) -# them we do the following for the adjoint -# Δu ∂/∂xu + Δv∂/∂x v + i(Δu∂/∂yu + Δv ∂/∂y v) +# then we do the following for the adjoint +# Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y ) +# this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html function _adjoint_complex(Δz, df, i) Δu, Δv = reim(Δz) du, dv = reim(df) return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N]) end -function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} +function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) function bc_fwd_back(ȳ) From 6742644da91ad22e3ab7081c4d8b77eeeafe4e79 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 25 Oct 2022 13:38:27 -0400 Subject: [PATCH 07/23] clean up submission --- src/lib/broadcast.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 463f4322e..0d51ab73f 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -258,20 +258,6 @@ function dual_function(f::F) where F end end -# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N -# valN = Val(N) -# out = dual_function(f).(args...) -# eltype(out) <: Dual || return (out, _ -> nothing) -# y = broadcast(x -> x.value, out) -# function bc_fwd_back(ȳ) -# dargs = ntuple(valN) do i -# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) -# end -# (nothing, nothing, dargs...) # nothings for broadcasted & f -# end -# return y, bc_fwd_back -# end - @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) 
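-# eltype(out) <: Dual || return (out, _ -> nothing)
-# y = broadcast(x -> x.value, out)
-# function bc_fwd_back(ȳ)
-# dargs = ntuple(valN) do i
-# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
-# end
-# (nothing, nothing, dargs...) # nothings for broadcasted & f
-# end
-# return y, bc_fwd_back
-# end
-
 @inline function broadcast_forward(f, args::Vararg{Any,N}) where N
   out = dual_function(f).(args...)
   T = eltype(out)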
From 2aa06c62945dd239316e9b967857e46d05bcc6ab Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 25 Oct 2022 23:20:43 -0400 Subject: [PATCH 08/23] Remove various Val's --- src/lib/broadcast.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 0d51ab73f..d50e45808 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -238,19 +238,19 @@ using ForwardDiff: Dual # We do this because it ensures type stability so it compiles nicely on the gpu dual(x, i, N) = x -dual(x::Bool, i, ::Val{N}) where {N} = x -dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(j-> i==j, Val(N))) +dual(x::Bool, i, N) = x +dual(x::Real, i, N) = Dual(x, ntuple(j-> i==j, N)) # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately -function dual(x::Complex, i, ::Val{N}) where {N} - re_dual = Dual(real(x), ntuple(j->i==j, Val(2N))) - im_dual = Dual(imag(x), ntuple(j->(N+i)==j, Val(2N))) +function dual(x::Complex, i, N) + re_dual = Dual(real(x), ntuple(==(i), 2N)) + im_dual = Dual(imag(x), ntuple(==(N+i), 2N)) return Complex(re_dual, im_dual) end function dual_function(f::F) where F function (args::Vararg{Any,N}) where N - ds = map(args, ntuple(identity,Val(N))) do x, i + ds = map(args, ntuple(identity,N)) do x, i tmp = dual(x, i, Val(N)) return tmp end From 851ab3318b1af7a2dde79e31f4676f349f86407f Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 31 Oct 2022 12:42:38 -0400 Subject: [PATCH 09/23] change to Complex{<:Dual} --- src/lib/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index d50e45808..c628858b2 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -262,7 +262,7 @@ function dual_function(f::F) where F @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) T = eltype(out) - T <: Union{Dual, Complex} || return (out, _ -> nothing) + T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing) if any(eltype(a) <: Complex for a in args) _broadcast_forward_complex(T, out, args...) else From f42d9408364de2aec90e51ae5137f98740b18284 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 31 Oct 2022 13:17:38 -0400 Subject: [PATCH 10/23] add mixed complex and real to cuda testing --- test/cuda.jl | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 7c4456e99..cdf1fc629 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -2,6 +2,7 @@ using CUDA using Zygote: Grads using LinearAlgebra using Random: randn! 
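+# FiniteDifferences supplies reference gradients for the new complex broadcasting tests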
+using FiniteDifferences CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @@ -144,12 +145,11 @@ end @testset "CUDA complex broadcasting" begin # Issue 961 and 1121 and 1215 x = rand(Float32, 50) - y = complex(rand(Float32, 50)) + y = rand(ComplexF32, 50) xgpu = cu(x) ygpu = cu(y) - g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1] g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1] g3 = Zygote.gradient(x->sum(abs2, x), y)[1] @@ -157,4 +157,22 @@ end @test g2 isa CUDA.CuArray{ComplexF32} @test collect(g1) ≈ collect(g2) @test collect(g1) ≈ g3 + + + + # Test real and complex mixed derivates + fm1(x,y) = sum(abs2, x.^2 .*y .+ y) + + m = central_fdm(5,1) + gx_fd, gy_fd = grad(m, fm1, x, y) + + # Test mixed derivatives on CUDA + gx_gpu, gy_gpu = Zygote.gradient(fm1, xgpu, ygpu) + gx_cpu, gy_cpu = Zygote.gradient(fm1, x, y) + @test collect(gx_gpu) ≈ gx_cpu + @test collect(gy_gpu) ≈ gy_cpu + + @test collect(gx_cpu) ≈ gx_fd + @test collect(gx_cpu) ≈ gx_fd + end From 95a6b5bd14fbc9eb5fac1bea33add3559d79ff0d Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 31 Oct 2022 16:14:07 -0400 Subject: [PATCH 11/23] import not using --- test/cuda.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index cdf1fc629..26e2c2aa9 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -2,7 +2,7 @@ using CUDA using Zygote: Grads using LinearAlgebra using Random: randn! -using FiniteDifferences +import FiniteDifferences CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @@ -163,8 +163,8 @@ end # Test real and complex mixed derivates fm1(x,y) = sum(abs2, x.^2 .*y .+ y) - m = central_fdm(5,1) - gx_fd, gy_fd = grad(m, fm1, x, y) + m = FiniteDifferences.central_fdm(5,1) + gx_fd, gy_fd = FiniteDifferences.grad(m, fm1, x, y) # Test mixed derivatives on CUDA gx_gpu, gy_gpu = Zygote.gradient(fm1, xgpu, ygpu) From b29f0904dbdd2f1feb51d1f676e5105473d51738 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 1 Nov 2022 10:33:12 -0400 Subject: [PATCH 12/23] Add complex to _dual_safearg --- src/lib/broadcast.jl | 32 +++++++++++++++++++------------- test/cuda.jl | 37 ++++++++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index c628858b2..74a9c3729 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -184,6 +184,8 @@ _dual_purefun(::Type) = false _dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers _dual_safearg(x::Numeric{<:Real}) = true +_dual_safearg(x::Numeric{<:Complex}) = true +_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true _dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types _dual_safearg(x) = false @@ -233,26 +235,30 @@ end # Forward Mode -- necessary for CUDA, also used as a fast path above import ForwardDiff -using ForwardDiff: Dual +using ForwardDiff: Dual, Partials # We do this because it ensures type stability so it compiles nicely on the gpu -dual(x, i, N) = x -dual(x::Bool, i, N) = x -dual(x::Real, i, N) = Dual(x, ntuple(j-> i==j, N)) -# For complex since ForwardDiff.jl doesn't play nicely with complex numbers we +@inline dual(x, i, ::Val) = x +@inline dual(x::Bool, i, ::Val) = x +@inline dual(x::Real, i, ::Val{N}) where {N} = Dual{typeof(x)}(x, ntuple(j -> (i==j & j < (N+1)), 2N)) +# function dual(x::Real, i, ::Val{N}) where {N} +# re = Dual(x, ntuple(j -> i==j, 2*N)) +# im = Dual(zero(x), ntuple(j -> i==j, 2*N)) +# return 
Complex(re, im) +# end + # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately -function dual(x::Complex, i, N) - re_dual = Dual(real(x), ntuple(==(i), 2N)) - im_dual = Dual(imag(x), ntuple(==(N+i), 2N)) +@inline function dual(x::Complex{T}, i, ::Val{N}) where {T<:Real,N} + re_dual = Dual{T}(T(real(x)), Partials{2N,T}(ntuple(==(i), Val(2N)))) + im_dual = Dual{T}(T(imag(x)), Partials{2N,T}(ntuple(==(N+i), Val(2N)))) return Complex(re_dual, im_dual) end function dual_function(f::F) where F function (args::Vararg{Any,N}) where N ds = map(args, ntuple(identity,N)) do x, i - tmp = dual(x, i, Val(N)) - return tmp + return dual(x, i, Val(N)) end return f(ds...) end @@ -298,7 +304,7 @@ function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where # This handles complex input and real output we use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). -function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} +@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) @@ -319,7 +325,7 @@ end # then we do the following for the adjoint # Δu ∂u/∂x + Δv∂v/∂x + i(Δu∂u/∂y + Δv ∂v/∂y ) # this follows https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html -function _adjoint_complex(Δz, df, i) +function _adjoint_complex(N, Δz, df, i) Δu, Δv = reim(Δz) du, dv = reim(df) return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N]) @@ -330,7 +336,7 @@ function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N} y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(y1, o1, i), ȳ, out)) + unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out)) end (nothing, nothing, dargs...) # nothings for broadcasted & f end diff --git a/test/cuda.jl b/test/cuda.jl index 26e2c2aa9..f9f48268b 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -5,6 +5,14 @@ using Random: randn! import FiniteDifferences CUDA.allowscalar(false) +function gradcheck_gpu(f, xs...) + grad_zygote = gradient(f, xs...) + m = FiniteDifferences.central_fdm(5,1) + grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...) 
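+    # FiniteDifferences only works on CPU arrays, so `collect.(xs)` copies the GPU
+    # inputs to the host; `collect.(grad_zygote)` does the same for Zygote's gradients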
+ return all(isapprox.(collect.(grad_zygote), grad_finite_difference)) + end + + # Test GPU movement inside the call to `gradient` @testset "GPU movement" begin r = rand(Float32, 3,3) @@ -160,19 +168,26 @@ end - # Test real and complex mixed derivates - fm1(x,y) = sum(abs2, x.^2 .*y .+ y) - m = FiniteDifferences.central_fdm(5,1) - gx_fd, gy_fd = FiniteDifferences.grad(m, fm1, x, y) + @test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs2, cos.(x) .+ sin.(y)), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ log.(y)), xgpu, ygpu) + + + # m = FiniteDifferences.central_fdm(5,1) + # gx_fd, gy_fd = FiniteDifferences.grad(m, fm1, x, y) - # Test mixed derivatives on CUDA - gx_gpu, gy_gpu = Zygote.gradient(fm1, xgpu, ygpu) - gx_cpu, gy_cpu = Zygote.gradient(fm1, x, y) - @test collect(gx_gpu) ≈ gx_cpu - @test collect(gy_gpu) ≈ gy_cpu + # # Test mixed derivatives on CUDA + # gx_gpu, gy_gpu = Zygote.gradient(fm1, xgpu, ygpu) + # gx_cpu, gy_cpu = Zygote.gradient(fm1, x, y) + # @test collect(gx_gpu) ≈ gx_cpu + # @test collect(gy_gpu) ≈ gy_cpu - @test collect(gx_cpu) ≈ gx_fd - @test collect(gx_cpu) ≈ gx_fd + # @test collect(gx_cpu) ≈ gx_fd + # @test collect(gx_cpu) ≈ gx_fd end From 40fdb296a8cba2eecf6a1cf27b80950b888fc92c Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 1 Nov 2022 16:02:34 -0400 Subject: [PATCH 13/23] Type stable on my computer --- src/lib/broadcast.jl | 29 +++++++++++++++++------------ test/cuda.jl | 12 +++++++++--- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 74a9c3729..b6d402bec 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -235,7 +235,7 @@ end # Forward Mode -- necessary for CUDA, also used as a fast path above import ForwardDiff -using ForwardDiff: Dual, Partials +using ForwardDiff: Dual, Partials, value, partials # We do this because it ensures type stability so it compiles nicely on the gpu @@ -250,16 +250,21 @@ using ForwardDiff: Dual, Partials # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately @inline function dual(x::Complex{T}, i, ::Val{N}) where {T<:Real,N} - re_dual = Dual{T}(T(real(x)), Partials{2N,T}(ntuple(==(i), Val(2N)))) - im_dual = Dual{T}(T(imag(x)), Partials{2N,T}(ntuple(==(N+i), Val(2N)))) + re_dual = Dual{T}(T(real(x)), Partials{2N,T}(ntuple(==(i), 2N))) + im_dual = Dual{T}(T(imag(x)), Partials{2N,T}(ntuple(==(N+i), 2N))) return Complex(re_dual, im_dual) end -function dual_function(f::F) where F - function (args::Vararg{Any,N}) where N - ds = map(args, ntuple(identity,N)) do x, i +function dualize(args::Vararg{Any, N}) where {N} + ds = map(args, ntuple(identity,N)) do x, i return dual(x, i, Val(N)) end + return ds +end + +function dual_function(f::F) where F + function (args::Vararg{Any,N}) where N + ds = dualize(args...) return f(ds...) 
end end @@ -279,10 +284,10 @@ end # Real input and real output function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> x.value, out) + y = broadcast(x -> value(x), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) + unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out)) end (nothing, nothing, dargs...) # nothings for broadcasted & f end @@ -292,10 +297,10 @@ end # This handles complex output and real input function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + y = broadcast(x -> Complex.(value(real(x)), value(imag(x))), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out)) end (nothing, nothing, dargs...) # nothings for broadcasted & f end @@ -306,10 +311,10 @@ function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where # since it agrees with what Zygote did for real(x). @inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> x.value, out) + y = broadcast(x -> value(x), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N]), ȳ, out)) + unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out)) end (nothing, nothing, dargs...) # nothings for broadcasted & f end diff --git a/test/cuda.jl b/test/cuda.jl index f9f48268b..1f2286dfa 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -7,6 +7,7 @@ CUDA.allowscalar(false) function gradcheck_gpu(f, xs...) grad_zygote = gradient(f, xs...) + #@inferred gradient(f, xs...) m = FiniteDifferences.central_fdm(5,1) grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...) return all(isapprox.(collect.(grad_zygote), grad_finite_difference)) @@ -155,8 +156,8 @@ end x = rand(Float32, 50) y = rand(ComplexF32, 50) - xgpu = cu(x) - ygpu = cu(y) + xgpu =cu(x) + ygpu =cu(y) g1 = Zygote.gradient(x->sum(abs2, x), ygpu)[1] g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu)[1] @@ -167,6 +168,8 @@ end @test collect(g1) ≈ g3 + r3 = Float32.(inv.(2:4)) + c3 = ComplexF32.(inv.(5:7) .+ im ./ (8:10)) @test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu) @@ -175,7 +178,10 @@ end @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu) - @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ log.(y)), xgpu, ygpu) + # @test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im] + # @inferred gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) + @test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im] + @inferred gradient((r,c) -> sum(abs2, @. 
sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3) # m = FiniteDifferences.central_fdm(5,1) From 15c33ada4bd82bf25ac6610f37d5bb1aa1735443 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Tue, 1 Nov 2022 17:54:50 -0400 Subject: [PATCH 14/23] Fix Dual tagging --- src/lib/broadcast.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index b6d402bec..8c962b0a4 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -185,8 +185,8 @@ _dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powe _dual_safearg(x::Numeric{<:Real}) = true _dual_safearg(x::Numeric{<:Complex}) = true -_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true +_dual_safearg(x::Ref{<:Numeric{<:Complex}}) = true _dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types _dual_safearg(x) = false @@ -239,9 +239,9 @@ using ForwardDiff: Dual, Partials, value, partials # We do this because it ensures type stability so it compiles nicely on the gpu -@inline dual(x, i, ::Val) = x -@inline dual(x::Bool, i, ::Val) = x -@inline dual(x::Real, i, ::Val{N}) where {N} = Dual{typeof(x)}(x, ntuple(j -> (i==j & j < (N+1)), 2N)) +@inline dual(x, i, ::Val{N}) where {N} = x +@inline dual(x::Bool, i, ::Val{N}) where {N} = x +@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), 2N)) # function dual(x::Real, i, ::Val{N}) where {N} # re = Dual(x, ntuple(j -> i==j, 2*N)) # im = Dual(zero(x), ntuple(j -> i==j, 2*N)) @@ -249,9 +249,9 @@ using ForwardDiff: Dual, Partials, value, partials # end # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately -@inline function dual(x::Complex{T}, i, ::Val{N}) where {T<:Real,N} - re_dual = Dual{T}(T(real(x)), Partials{2N,T}(ntuple(==(i), 2N))) - im_dual = Dual{T}(T(imag(x)), Partials{2N,T}(ntuple(==(N+i), 2N))) +@inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N} + re_dual = Dual(real(x), ntuple(==(i), 2N)) + im_dual = Dual(imag(x), ntuple(==(N+i), 2N)) return Complex(re_dual, im_dual) end From 5e53ada0ad5af703173f7808236916bc06e750aa Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 3 Nov 2022 11:26:18 -0400 Subject: [PATCH 15/23] Add more tests --- test/cuda.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 1f2286dfa..14aad71ae 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -168,21 +168,34 @@ end @test collect(g1) ≈ g3 - r3 = Float32.(inv.(2:4)) - c3 = ComplexF32.(inv.(5:7) .+ im ./ (8:10)) + r3 = cu(Float32.(inv.(2:4))) + c3 = cu(ComplexF32.(inv.(5:7) .+ im ./ (8:10))) + # These check _broadcast_forward(::Type{<:Dual}, ...) @test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu) + @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu) + + # These check _broadcast_forward_complex(::Type{<:Complex}, ...) 
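    # (complex input, complex output: for the complex argument, the broadcast kernel is
    # f(x + iy) = u + iv, so the pullback mixes all four partials via _adjoint_complex)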
@test gradcheck_gpu((x,y)->sum(abs2, cos.(x) .+ sin.(y)), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) - @test gradcheck_gpu((x,y)->sum(abs, exp.(x) .+ imag.(y)), xgpu, ygpu) - # @test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im] + @test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im] # @inferred gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) @test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im] @inferred gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3) + # These check _broadcast_forward(::Type{<:Complex}, ...) + @test gradcheck_gpu(x->sum(real, cis.(x)), xgpu) + @test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu) + + # These check _broadcast_forward_complex(::Type{<:Dual}, ...) + @test gradcheck_gpu(x->sum(real, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu) + + + + # m = FiniteDifferences.central_fdm(5,1) # gx_fd, gy_fd = FiniteDifferences.grad(m, fm1, x, y) From 2c4857b8b6d2280a37c771a10a2a83f833386a5c Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 4 Nov 2022 00:31:23 -0400 Subject: [PATCH 16/23] update tests --- test/cuda.jl | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 14aad71ae..171eb7304 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -181,10 +181,12 @@ end @test gradcheck_gpu((x,y)->sum(abs, cos.(x).*sin.(y)), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) - @test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im] - # @inferred gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) - @test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im] - @inferred gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3) + @test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3) + + # Commented out for now because of the ldexp bug + # @test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) + # @test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3) + # These check _broadcast_forward(::Type{<:Complex}, ...) 
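    # (real input, complex output: cis(x) = exp(im*x) takes the real xgpu to ComplexF32)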
@test gradcheck_gpu(x->sum(real, cis.(x)), xgpu) @@ -194,19 +196,4 @@ end @test gradcheck_gpu(x->sum(real, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu) - - - - # m = FiniteDifferences.central_fdm(5,1) - # gx_fd, gy_fd = FiniteDifferences.grad(m, fm1, x, y) - - # # Test mixed derivatives on CUDA - # gx_gpu, gy_gpu = Zygote.gradient(fm1, xgpu, ygpu) - # gx_cpu, gy_cpu = Zygote.gradient(fm1, x, y) - # @test collect(gx_gpu) ≈ gx_cpu - # @test collect(gy_gpu) ≈ gy_cpu - - # @test collect(gx_cpu) ≈ gx_fd - # @test collect(gx_cpu) ≈ gx_fd - end From 9fc218032f83f156b0e024719285c1e047e3604c Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Nov 2022 19:25:39 -0400 Subject: [PATCH 17/23] First attempt to fix real performance regression --- src/lib/broadcast.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 8c962b0a4..0e3e4f6a3 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -241,12 +241,7 @@ using ForwardDiff: Dual, Partials, value, partials # We do this because it ensures type stability so it compiles nicely on the gpu @inline dual(x, i, ::Val{N}) where {N} = x @inline dual(x::Bool, i, ::Val{N}) where {N} = x -@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), 2N)) -# function dual(x::Real, i, ::Val{N}) where {N} -# re = Dual(x, ntuple(j -> i==j, 2*N)) -# im = Dual(zero(x), ntuple(j -> i==j, 2*N)) -# return Complex(re, im) -# end +@inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), N)) # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately @inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N} @@ -255,6 +250,8 @@ using ForwardDiff: Dual, Partials, value, partials return Complex(re_dual, im_dual) end + + function dualize(args::Vararg{Any, N}) where {N} ds = map(args, ntuple(identity,N)) do x, i return dual(x, i, Val(N)) From c6857987a6c4ff50e16e09d98e3482c2f2e7f8e2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 10 Nov 2022 00:08:06 -0500 Subject: [PATCH 18/23] Uncomment ldexp rules --- test/cuda.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 171eb7304..31a4f60ae 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -184,8 +184,8 @@ end @test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3) # Commented out for now because of the ldexp bug - # @test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) - # @test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3) + @test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) + @test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3) # These check _broadcast_forward(::Type{<:Complex}, ...) From efc4f6733823d4504d9e14de9a3d40ccc6fa14a5 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 10 Nov 2022 13:07:45 -0500 Subject: [PATCH 19/23] cleanup broadcast and inline --- src/lib/broadcast.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 0e3e4f6a3..e4adc1cd3 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -259,7 +259,7 @@ function dualize(args::Vararg{Any, N}) where {N} return ds end -function dual_function(f::F) where F +@inline function dual_function(f::F) where F function (args::Vararg{Any,N}) where N ds = dualize(args...) return f(ds...) 
@@ -279,7 +279,7 @@ function dual_function(f::F) where F end # Real input and real output -function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} +@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> value(x), out) function bc_fwd_back(ȳ) @@ -292,9 +292,9 @@ function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} end # This handles complex output and real input -function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} +@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> Complex.(value(real(x)), value(imag(x))), out) + y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out)) @@ -333,9 +333,9 @@ function _adjoint_complex(N, Δz, df, i) return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N]) end -function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} +@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) - y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out)) From 51e3ba39f805934ebd15ad30d91909c56956cea8 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 11 Nov 2022 12:45:43 -0500 Subject: [PATCH 20/23] update tests --- test/cuda.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 31a4f60ae..23c1aebe7 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -7,11 +7,10 @@ CUDA.allowscalar(false) function gradcheck_gpu(f, xs...) grad_zygote = gradient(f, xs...) - #@inferred gradient(f, xs...) m = FiniteDifferences.central_fdm(5,1) grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...) return all(isapprox.(collect.(grad_zygote), grad_finite_difference)) - end +end # Test GPU movement inside the call to `gradient` @@ -183,7 +182,7 @@ end @test gradcheck_gpu((x,y)->sum(abs, cos.(x) .+ sin.(conj.(y))), xgpu, ygpu) @test gradcheck_gpu((r,c) -> sum(abs2, sin.(conj.(c)./transpose(r) .- im) .- imag.(c .+ tanh.(r./c'))), r3, c3) - # Commented out for now because of the ldexp bug + # Painful test! @test gradcheck_gpu(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3) @test gradcheck_gpu(r -> sum(abs2, log.(1 .+ im .* r)./2), r3) @@ -193,7 +192,7 @@ end @test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu) # These check _broadcast_forward_complex(::Type{<:Dual}, ...) - @test gradcheck_gpu(x->sum(real, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu) + @test gradcheck_gpu(x->sum(imag, x.^2 .+ abs.(sinh.(conj.(x)))), ygpu) end From c888db876617a15cfdcdba8a5b9c009b4b474f10 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Fri, 11 Nov 2022 18:10:56 -0500 Subject: [PATCH 21/23] specify more reasonable tolerance for float32 --- test/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda.jl b/test/cuda.jl index 23c1aebe7..8f852c571 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -9,7 +9,7 @@ function gradcheck_gpu(f, xs...) grad_zygote = gradient(f, xs...) 
m = FiniteDifferences.central_fdm(5,1) grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...) - return all(isapprox.(collect.(grad_zygote), grad_finite_difference)) + return all(isapprox.(collect.(grad_zygote), grad_finite_difference, atol=1e-5)) end From 83ed91753d1e87f379c8676e051f0976e06a3212 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 12 Nov 2022 14:29:45 -0500 Subject: [PATCH 22/23] revert testing bug --- test/cuda.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 8f852c571..171fa45db 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -9,7 +9,7 @@ function gradcheck_gpu(f, xs...) grad_zygote = gradient(f, xs...) m = FiniteDifferences.central_fdm(5,1) grad_finite_difference = FiniteDifferences.grad(m, f, collect.(xs)...) - return all(isapprox.(collect.(grad_zygote), grad_finite_difference, atol=1e-5)) + return all(isapprox.(collect.(grad_zygote), grad_finite_difference)) end @@ -152,8 +152,8 @@ end @testset "CUDA complex broadcasting" begin # Issue 961 and 1121 and 1215 - x = rand(Float32, 50) - y = rand(ComplexF32, 50) + x = 2*rand(Float32, 10) .- 1f0 + y = 2*rand(ComplexF32, 10) .- 1f0 xgpu =cu(x) ygpu =cu(y) From 7b0044b51ca625be2688bae72edc5299cc5fc520 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 16 Nov 2022 09:48:29 -0500 Subject: [PATCH 23/23] clean up the submission --- src/lib/broadcast.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index e4adc1cd3..dc7b053c1 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -239,10 +239,11 @@ using ForwardDiff: Dual, Partials, value, partials # We do this because it ensures type stability so it compiles nicely on the gpu +# The val is needed for some type stability @inline dual(x, i, ::Val{N}) where {N} = x @inline dual(x::Bool, i, ::Val{N}) where {N} = x @inline dual(x::Real, i, ::Val{N}) where {N} = Dual(x, ntuple(==(i), N)) - # For complex since ForwardDiff.jl doesn't play nicely with complex numbers we +# For complex since ForwardDiff.jl doesn't play nicely with complex numbers we # construct a Complex dual number and tag the real and imaginary parts separately @inline function dual(x::Complex{T}, i, ::Val{N}) where {T,N} re_dual = Dual(real(x), ntuple(==(i), 2N)) @@ -250,8 +251,6 @@ using ForwardDiff: Dual, Partials, value, partials return Complex(re_dual, im_dual) end - - function dualize(args::Vararg{Any, N}) where {N} ds = map(args, ntuple(identity,N)) do x, i return dual(x, i, Val(N)) @@ -278,7 +277,7 @@ end end end -# Real input and real output +# Real input and real output pullback @inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> value(x), out) @@ -291,7 +290,7 @@ end return y, bc_fwd_back end -# This handles complex output and real input +# This handles the complex output and real input pullback @inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) @@ -304,7 +303,7 @@ end return y, bc_fwd_back end -# This handles complex input and real output we use the gradient definition from ChainRules here +# This handles complex input and real output. We use the gradient definition from ChainRules here # since it agrees with what Zygote did for real(x). 
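# For f : C -> R the pullback below returns ȳ * (∂u/∂x + im*∂u/∂y); e.g. for
# abs2(x + im*y) = x^2 + y^2 it gives ȳ * (2x + 2im*y) = 2ȳ*z, matching the
# explicit abs2 adjoint above.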
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) @@ -319,9 +318,6 @@ end end # # # This is for complex input and complex output -# # # I am a little confused what derivative we want to use here but this should match -# what is done for all the tests - # If we assume that # f(x + iy) = u(x,y) + iv(x,y) # then we do the following for the adjoint @@ -330,7 +326,7 @@ end function _adjoint_complex(N, Δz, df, i) Δu, Δv = reim(Δz) du, dv = reim(df) - return Complex(Δu*du.partials[i] + Δv*dv.partials[i], Δu*du.partials[i+N] + Δv*dv.partials[i+N]) + return Complex(Δu*partials(du, i) + Δv*partials(dv, i), Δu*partials(du, i+N) + Δv*partials(dv, i+N)) end @inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
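   valN = Val(N)
   y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
       unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
     end
     (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
   return y, bc_fwd_back
 end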