abs2 of complex CUDA array fails with Zygote.gradient #961

Closed
roflmaostc opened this issue May 3, 2021 · 16 comments
Labels
CUDA All things GPU

Comments

@roflmaostc

roflmaostc commented May 3, 2021

Hey,

when applying abs2 to a complex CUDA array and differentiating with Zygote.gradient, I get ERROR: MethodError: no method matching iterate(::Nothing).
I'm using CUDA 3.1.0, Julia 1.6.1 and Zygote 0.6.10.
I also tried Julia 1.5.4, CUDA v2.4.0 and Zygote v0.5.0, so this is not a recently introduced issue.

See the MWE below:

julia> using Zygote, CUDA

julia> x = rand(ComplexF32, (2,2))
2×2 Matrix{ComplexF32}:
 0.0598111+0.678913im  0.767138+0.77825im
  0.548067+0.98656im   0.306103+0.166084im

julia> x_c = CuArray(x);

julia> f(x) = sum(abs2.(x))
f (generic function with 1 method)

julia> g(x) = sum(real(x .* conj.(x)))
g (generic function with 1 method)

julia> f(x) ≈ f(x_c) ≈ g(x) ≈ g(x_c)
true

julia> Zygote.gradient(f, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(g, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(f, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[18]:1 [inlined]
  [7] (::typeof(∂(f)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[24]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

julia> Zygote.gradient(g, x_c)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

Manifest.toml

# This file is machine-generated - editing it directly is not advised

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.0"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[BFloat16s]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.1.0"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
git-tree-sha1 = "d4fa6486e94c4087f1d081d7be2d501a170bd51d"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.1.0"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "1f410fba5c04d03ab712f348f1542e6059376547"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.61"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "bd0cc939d94b8bd736dce5bbbe0d635db9f94af7"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.41"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.0.2"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"

[[FFTW]]
deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"]
git-tree-sha1 = "1dc6ca6ad69eb9beadd3ce82b90910f4fa63d7c3"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.4.0"

[[FFTW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "5a0d4b6a22a34d17d53543bd124f4b08ed78e8b0"
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
version = "3.3.9+7"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.18"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "3e10e95ddc385e1589c27b1a58f21bf3008b559c"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.3.0"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.11.4"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.2"

[[IntelOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
version = "2018.0.3+2"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.6.0"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MKL_jll]]
deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "c253236b0ed414624b083e6b72bfe891fbd2c7af"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2021.1.1+1"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.6"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Memoize]]
deps = ["MacroTools"]
git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
version = "0.4.4"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.4+0"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "2653e9c769343808781a8bd5010ee7a17c01152e"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.8"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "927209c83efa62256788a9880c191774c07c5b51"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.10"

[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.1"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

Thanks a lot!

Felix

@roflmaostc
Author

roflmaostc commented May 4, 2021

Uh, I believe the problem is more general than posted above. Several basic functions are broken:

julia> f_exp(x) = sum(real(exp.(x)))
f_exp (generic function with 1 method)

julia> Zygote.gradient(f_exp, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[19]:1 [inlined]
  [7] (::typeof(∂(f_exp)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f_exp))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[20]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

julia> f_abs(x) = sum(real(abs.(x)))
f_abs (generic function with 1 method)

julia> Zygote.gradient(f_abs, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[27]:1 [inlined]
  [7] (::typeof(∂(f_abs)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f_abs))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[28]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

@roflmaostc roflmaostc changed the title abs2 of complex array fails with Zygote.gradient abs2 of complex CUDA array fails with Zygote.gradient May 4, 2021
@DhairyaLGandhi
Member

Interesting, it's usually the complex numbers, right? We might be better off making sure our adjoints handle complex numbers properly. Some of it probably needs fixing in ChainRules too. Ideally we'd be able to reproduce this without relying on CUDA; is that possible?

@roflmaostc
Author

roflmaostc commented May 4, 2021

Without CUDA it seems to work fine; I haven't encountered such an issue there. The examples above cover both cases, with and without CUDA. The only combination that fails is Zygote + Array{<:Complex} + CUDA.

@ToucheSir
Member

The relevant lines: https://github.com/FluxML/Zygote.jl/blob/v0.6.10/src/lib/broadcast.jl#L195-L233

It appears that broadcasting over complex numbers hits the early bailout in https://github.com/FluxML/Zygote.jl/blob/v0.6.10/src/lib/broadcast.jl#L213, returning a back function that itself returns nothing (and thus can't be splatted). I would assume this is not the intended behaviour and that at least some value should be propagated?
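
The error message itself can be reproduced without Zygote or CUDA. The following is a minimal sketch, not the code at the linked lines: `bailout_back` and `wrap` are hypothetical names that only mimic the shape described above, a pullback that returns nothing and a caller that splats its result.

```julia
# Hypothetical stand-in for the early-bailout pullback: it ignores the
# incoming sensitivity and returns `nothing`.
bailout_back(ȳ) = nothing

# Hypothetical caller that splats the pullback's result into a tuple.
wrap(back, ȳ) = (nothing, nothing, back(ȳ)...)

# Capture the exception instead of letting it abort the script.
err = try
    wrap(bailout_back, 1f0)
catch e
    e
end

# Splatting iterates its argument, and `nothing` is not iterable.
println(err isa MethodError && err.f === iterate)
```

If that reading is right, any fix has to make the pullback return a tuple (or something iterable) with one entry per broadcast argument, even when some entries are `nothing`.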

@roflmaostc
Author

roflmaostc commented May 4, 2021

Hm, I tried to do some debugging.

T = Broadcast.combine_eltypes(f, args) is always Float32, regardless of whether we have a CuArray{<:Complex} or a CuArray{<:Real}. However, that T is not used anywhere.

I'm not sure, but this looks strange, doesn't it?

julia> Zygote.dual_function(abs2).(CuArray(randn(Float32, (2,2))))
2×2 CuArray{Dual{Nothing, Float32, 1}, 2}:
 Dual{Nothing}(0.0782065,-0.559308)  Dual{Nothing}(0.000581367,-0.0482231)
 Dual{Nothing}(0.189928,-0.871615)   Dual{Nothing}(0.37119,1.21851)

julia> Zygote.dual_function(abs2).(CuArray(randn(ComplexF32, (2,2))))
2×2 CuArray{Float32, 2}:
 2.71749   0.407771
 0.735002  0.240149

The following

 @adjoint broadcasted(::typeof(abs2), x::CuArray) = 
            abs2.(x), a -> (nothing, 2 .* a .* x) # copied but replaced Numeric with CuArray

also seems to fix the issue. It works, but it is not a general solution.
There must be a more general pattern for this...

julia> using Zygote, CUDA

julia> using Zygote: @adjoint, Numeric, broadcasted

julia> x_c = CUDA.rand(ComplexF32, 2, 2)
2×2 CuArray{ComplexF32, 2}:
 0.508057+0.580545im  0.846364+0.0125523im
 0.965085+0.773818im  0.962015+0.289069im

julia> x = Array(x_c);

julia> Zygote.gradient(x -> sum(abs2.(x)), x)
(ComplexF32[1.016113f0 + 1.1610899f0im 1.6927286f0 + 0.025104642f0im; 1.9301703f0 + 1.5476352f0im 1.9240301f0 + 0.5781378f0im],)

julia> Zygote.gradient(x -> sum(abs2.(x)), x_c)  # old error we know
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
 [...]

julia> @adjoint broadcasted(::typeof(abs2), x::Numeric) =
         abs2.(x), a -> (nothing, 2 .* a .* x)

julia> Zygote.gradient(x -> sum(abs2.(x)), x_c)
(ComplexF32[1.016113f0 + 1.1610899f0im 1.6927286f0 + 0.025104642f0im; 1.9301703f0 + 1.5476352f0im 1.9240301f0 + 0.5781378f0im],)

@roflmaostc
Author

roflmaostc commented May 4, 2021

Hm, I'm confused why the code distinguishes between complex and real numbers.

DualNumbers.jl doesn't

Edit:
OK reading some discussions, that's not exactly the same as in ForwardDiff.
I believe that's beyond my understanding of Zygote etc. atm.

@mcabbott
Member

mcabbott commented Jun 30, 2021

ForwardDiff's Dual numbers should work with complex numbers, but the way they are produced and consumed would need to change. The current functions are a bit sloppy: they add a dual perturbation to real numbers and ignore all other types. It would be much better to throw an error on plausibly differentiable types that can't be handled.

Anyway, a first look:

using Zygote
gradient(x -> sum(sqrt.(x)), [1,2,3])

y, b = Zygote.broadcast_forward(sqrt, [1,2,3]) # method used for CuArrays
b([1,1,1])  # same as normal

y1, b1 = pullback(x -> abs.(x), [1,2+im,3-im])
y2, b2 = Zygote.broadcast_forward(abs, [1,2+im,3-im])
b1([1,1,1])
b2([1,1,1]) # nothing

@eval Zygote dual(x::Complex, p) = Complex(Dual(real(x), p), imag(x))

b2([1,1,1]) # on re-running, now has the real parts of the sensitivity

So I think you need to do something like this, with 2N perturbations when there are complex numbers present (untested!)

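using ForwardDiff: Dual  # provides the Dual number type

# Non-differentiable types pass through unchanged.
dual(x, p, pc=()) = x
# Real inputs get one set of perturbations.
dual(x::Real, p, pc=()) = Dual(x, p)
# Complex inputs get separate perturbations for the real and imaginary parts.
dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc))

function dual_function(f::F) where F
  function dual_f(args::Vararg{Any,N}) where N
    if any(a isa Complex for a in args)
      # 2N slots: slot i for the real part of argument i, slot N+i for its imaginary part.
      ds = map(args, ntuple(identity, Val(N))) do x, i
        dual(x, ntuple(j -> i==j, Val(2N)), ntuple(j -> N+i==j, Val(2N)))
      end
      return f(ds...)
    else
      ds = map(args, ntuple(identity, Val(N))) do x, i
        dual(x, ntuple(j -> i==j, Val(N)))
      end
      return f(ds...)
    end
  end
  return dual_f
end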

And then broadcast_forward needs to extract these.

That's for the general case, which would also now be useful for broadcasting complex numbers on the CPU.
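
To make the extraction step concrete for a single scalar, here is a small sketch that assumes ForwardDiff is installed. It is not Zygote's broadcast_forward; the variable names and the way the two partials are recombined are illustrative assumptions, chosen so that the result agrees with the CPU gradients shown earlier in this thread.

```julia
using ForwardDiff: Dual, partials

z = 0.3 + 0.4im

# Seed two perturbations: slot 1 tracks the real part, slot 2 the imaginary part.
d = Complex(Dual(real(z), (true, false)), Dual(imag(z), (false, true)))

# abs2 of a complex dual is a real Dual carrying both partial derivatives.
y = abs2(d)
d_re, d_im = partials(y)

# Recombine as d_re + im * d_im (assumed convention, see lead-in).
grad = complex(d_re, d_im)
println(grad ≈ 2z)
```

For abs2 this gives 2z, the same value the CPU path returns in the original report.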

But in the meantime, adding the @adjoint rule you suggest for abs2 would not be crazy, there are already rules for real, conj etc. here:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L108-L115

@roflmaostc
Author

roflmaostc commented Oct 20, 2022

Hi,
since this is still an issue, I usually work around it by defining custom adjoints.
But this one confuses me:

Is that rule wrong?

julia> using Zygote, CUDA

julia> using Zygote:@adjoint, broadcasted

(jl_bqEEuk) pkg> status
Status `/tmp/jl_bqEEuk/Project.toml`
  [052768ef] CUDA v3.12.0
  [e88e6eb3] Zygote v0.6.49

julia> @adjoint broadcasted(::typeof(exp), x::CuArray) = 
                   exp.(x), a -> (nothing, exp.(x) .* a)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), Array([1.1im * pi]))
(ComplexF64[-0.9510565162951535 + 0.30901699437494773im],)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), CuArray([1.1im * pi]))
(ComplexF64[-0.9510565162951535 - 0.30901699437494773im],)

@roflmaostc
Author

So the correct rule apparently is:

julia> @adjoint broadcasted(::typeof(exp), x::CuArray) = 
                   exp.(x), a -> (nothing, exp.(conj.(x)) .* a)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), CuArray([1.1im * pi]))

(ComplexF64[-0.9510565162951535 + 0.30901699437494773im],)

Where is this rule defined? I couldn't really find it in the jungle of Zygote/ChainRules. I'm still not really familiar with the correct terminology, etc.

@mcabbott
Member

Yes I think that's right.

There is no such rule in Zygote. In the Array case I believe it will call the most generic fallback path, which is here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L197-L206 .

@roflmaostc
Author

Thanks a lot for the hint!

And where is the rule for exp defined? I'm still kind of confused about the conj.

@mcabbott
Member

mcabbott commented Oct 20, 2022

The rule for exp(x::Complex) will I think come from ChainRules.

CR now has broadcasting rules too, BTW, but Zygote doesn't use them (yet, or ever, not sure).

I'm frequently confused about the conj but there are essentially two conventions for what gradient(real∘f, x+iy) could mean, and Zygote picked one of them.
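
One way to check which convention the CPU result above follows is to treat u = real∘f as a function of two real variables and form ∂u/∂x + i ∂u/∂y by finite differences. This is only a sketch in plain Julia; `fd_gradient` is a hypothetical helper, and the conclusion is drawn from the numbers in this thread rather than from Zygote's source.

```julia
# Real-valued function of a complex input, as in the example above.
u(z) = real(exp(z))

# Central finite differences along the real and imaginary axes,
# combined as ∂u/∂x + im * ∂u/∂y.
function fd_gradient(u, z; h=1e-6)
    du_dx = (u(z + h) - u(z - h)) / (2h)
    du_dy = (u(z + h * im) - u(z - h * im)) / (2h)
    return complex(du_dx, du_dy)
end

z = 1.1im * pi
g = fd_gradient(u, z)

println(isapprox(g, exp(conj(z)); atol=1e-6))  # agrees with the conj version
println(isapprox(g, exp(z); atol=1e-6))        # differs from the version without conj
```

Under this convention the pullback of a holomorphic f multiplies the incoming sensitivity by conj(f'(x)), which for exp is exp(conj(x)). That is consistent with the corrected CuArray rule and with the CPU value -0.951 + 0.309im.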

@roflmaostc
Author

Yeah I thought so too but the only one I found is:

src/rulesets/Base/fastmath_able.jl:        @scalar_rule exp(x) Ω

Does @scalar_rule insert the conj at the right positions?

@mcabbott
Member

Yes, it should.

@roflmaostc
Author

Ok, thanks! That's interesting to know :)

@CarloLucibello
Member

closed #1324
