diff --git a/Project.toml b/Project.toml index e209009ad..4a71b47cb 100644 --- a/Project.toml +++ b/Project.toml @@ -23,9 +23,9 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.8.12" -ChainRulesCore = "0.10.4" -ChainRulesTestUtils = "0.7.1" +ChainRules = "1" +ChainRulesCore = "1.0.1" +ChainRulesTestUtils = "1" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" ForwardDiff = "0.10" @@ -33,7 +33,7 @@ IRTools = "0.4" MacroTools = "0.5" NaNMath = "0.3" Requires = "1.1" -SpecialFunctions = "0.10, 1.0" +SpecialFunctions = "1.6" StatsFuns = "0.9.8" ZygoteRules = "0.2.1" julia = "1.3" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index eac8a8401..dfeb0b184 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,12 +1,72 @@ # This file is machine-generated - editing it directly is not advised +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.0.1" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[ChainRules]] +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "346588c81effb94da6a30c1617e56af6a878e4d6" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.0.1" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "ad613c934ec3a3aa0ff19b91f15a16d56ed404b5" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.0.2" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.31.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + [[Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.2" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + [[DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" @@ -19,39 +79,126 @@ git-tree-sha1 = "395fa1554c69735802bba37d9e7d9586fd44326c" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" version = "0.24.11" +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.12.1" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.18" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.3" + [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + [[JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.1" +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + [[LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.2.5" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.6" + [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + [[NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + [[Parsers]] deps = ["Dates"] git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "1.1.0" +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + [[Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -64,18 +211,84 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.6.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "5b2f81eeb66bcfe379947c500aae773c85c31033" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.8" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + [[Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +path = ".." +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.17" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.1" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/docs/Project.toml b/docs/Project.toml index 1b9ab1f81..2a4c85433 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.24" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c4e72f07e..6fcdcdf40 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -4,8 +4,7 @@ end ZygoteRuleConfig() = ZygoteRuleConfig(Context()) -const rrule_fallback_method = Base.which(rrule, Tuple{Any, Vararg{Any}}) -const rrule_redispatcher_method = Base.which(rrule, Tuple{RuleConfig, Any, Vararg{Any}}) +_is_rrule_redispatcher(m::Method) = m.sig == Tuple{typeof(rrule), RuleConfig, Vararg} """ has_chain_rrule(T) @@ -18,19 +17,68 @@ such that if a suitable rule is defined later, the generated function will recom """ function has_chain_rrule(T) config_T, arg_Ts = Iterators.peel(T.parameters) - m_with_config = meta(Tuple{typeof(rrule), config_T, arg_Ts...}) - if m_with_config.method === rrule_redispatcher_method - # it is being redispatched without config, so check it that hits the fallback - m_without_config = meta(Tuple{typeof(rrule), arg_Ts...}) - if m_without_config.method === rrule_fallback_method - # no rrule exists, return instance for m_with_config as that will be invalidated - # directly if configured rule added, or indirectly if unconfigured rule added - return false, m_with_config.instance - end + configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...}) + if _is_rrule_redispatcher(configured_rrule_m.method) + # The config is not being used: + # it is being redispatched without config, so we need the method it redispatches to + rrule_m = meta(Tuple{typeof(rrule), arg_Ts...}) + # Thus any no_rrule that might apply must also not have a config because if there was a + # no_rrule with a config that applied then there would also be a rrule with config that applied + no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...}) + else + # Not being redispatched: it does have a config + rrule_m = configured_rrule_m + # Thus any no_rrule that might apply must also have a config because if it applied + # it will be identical, and if it doesn't we don't care what it is. + no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...}) + end + + # To understand why we only need to check if the sigs match between no_rrule_m and rrule_m + # in order to decide if to use, one must consider the following facts: + # - for every method in `no_rrule` there is a identical one in `rrule` that returns nothing + # - this includes the general fallback `rrule(::Any...)=nothing`. + # - a configured rrule/no_rrule is always more specific than a otherwise equivalent unconfigured rrule/no_rrule + # + # Consider the following truth table, for what can occur: + # rrule: fallback, no_rrule: fallback => matches => do not use rrule. + # rrule: specific, no_rrule: fallback => !matches => do use rrule, as haven't opted out. + # rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rule is identical to some rrule + # rrule: specific, no_rrule: specific => matches => do not use rrule as opted out + # rrule: specific, no_rrule: general => !matches => do use rrule as a more specific rrule takes preciedent over more general opted out + # rrule: general , no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule so can't have a more general rrule being hit, as the specific one would hit first + # + # Note that the fallback cases are the same outcome as the general cases as fallback is just most general. + # It can be seen that checking if it matches is the correct way to decide if we should use the rrule or not. + + + do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m) + if do_not_use_rrule + # Return instance for configured_rrule_m as that will be invalidated + # directly if configured rule added, or indirectly if unconfigured rule added + # Do not need an edge for `no_rrule` as no addition of methods to that can cause this + # decision to need to be revisited (only changes to `rrule`), since we are already not + # using the rrule, so not using more rules wouldn't change anything. + return false, configured_rrule_m.instance + else + # Otherwise found a rrule, no need to add any edges for `rrule`, as it will generate + # code with natural edges if a new method is defined there. + # We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule` + # a corresponding method is added to `rrule` (to return `nothing`), thus we will already + # be revisiting this decision when a new opt-out is added. + return true, nothing end - # otherwise found a rrule, no need to add any edges, as it will generate code with - # natural edges. - return true, nothing +end + +matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig) +matching_cr_sig(::DataType, ::UnionAll) = false +matching_cr_sig(::UnionAll, ::DataType) = false +matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s) + +type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...} +function type_tuple_tail(d::UnionAll) + body = Base.unwrap_unionall(d) + body_tt = type_tuple_tail(body) + return Base.rewrap_unionall(body_tt, d) end """ @@ -52,7 +100,8 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally. """ -@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks +@inline wrap_chainrules_output(x) = x +@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks @inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing @@ -63,7 +112,7 @@ for T_outer in (:Tuple, :NamedTuple) # than happy. @eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer} xp = map(wrap_chainrules_output, canonicalize(x)) - convert($T_outer, xp) + ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest end end diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 6ca8257d5..e4db33471 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -58,11 +58,11 @@ See also [`withgradient`](@ref) to keep the value `f(args...)`, and [`pullback`](@ref) for value and back-propagator. ```jldoctest; setup=:(using Zygote) -julia> gradient(*, 2, 3, 5) -(15, 10, 6) +julia> gradient(*, 2.0, 3.0, 5.0) +(15.0, 10.0, 6.0) -julia> gradient(x -> sum(abs2,x), [7, 11, 13]) -([14, 22, 26],) +julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0]) +([14.0, 22.0, 26.0],) julia> gradient([7, 11], 0, 1) do x, y, d p = size(x, d) diff --git a/test/chainrules.jl b/test/chainrules.jl index 32bdd3799..30b758d04 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -232,6 +232,35 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2) @test aug_primal_val.x == 2.0 end + + @testset "@opt_out" begin + oout_id(x) = x + oout_id_rrule_hitcount = Ref(0) + function ChainRulesCore.rrule(::typeof(oout_id), x::Any) + oout_id_rrule_hitcount[] += 1 + oout_id_pullback(ȳ) = (NoTangent(), ȳ) + return oout_id(x), oout_id_pullback + end + + @opt_out ChainRulesCore.rrule(::typeof(oout_id), x::AbstractArray) + + # Hit one we haven't opted out + oout_id_rrule_hitcount[] = 0 + oout_id_outer(x) = sum(oout_id(x)) + @test (1.0,) == Zygote.gradient(oout_id_outer, π) + @test oout_id_rrule_hitcount[] == 1 + + # make sure don't hit the one we have opted out + oout_id_rrule_hitcount[] = 0 + @test ([1.0],) == Zygote.gradient(oout_id_outer, [π]) + @test oout_id_rrule_hitcount[] == 0 + + # Now try opting out After we have already used it + @opt_out ChainRulesCore.rrule(::typeof(oout_id), x::Real) + oout_id_rrule_hitcount[] = 0 + @test (1.0,) == Zygote.gradient(oout_id_outer, π) + @test oout_id_rrule_hitcount[] == 0 + end end @testset "ChainRulesCore.rrule_via_ad" begin @@ -275,7 +304,7 @@ end ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad ) test_rrule( - ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad + ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad ) test_rrule(ZygoteRuleConfig(), sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad) test_rrule( diff --git a/test/complex.jl b/test/complex.jl index 54f99fd0f..6a0445b85 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -18,7 +18,7 @@ using Zygote, Test, LinearAlgebra @test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10 # https://github.com/FluxML/Zygote.jl/issues/705 -@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ im .* exp.(1:3) +@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3)) @test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3) fs_C_to_R = (real, diff --git a/test/features.jl b/test/features.jl index 766673350..b17f55b41 100644 --- a/test/features.jl +++ b/test/features.jl @@ -449,12 +449,12 @@ end @test pullback(type_test)[1] == Complex{<:Real} @testset "Pairs" begin - @test (x->10*pairs((a=x, b=2))[1])'(100) === 10 + @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0 @test (x->10*pairs((a=x, b=2))[2])'(100) === 0 foo(;kw...) = 1 @test gradient(() -> foo(a=1,b=2.0)) === () - @test (x->10*(x => 2)[1])'(100) === 10 + @test (x->10*(x => 2)[1])'(100) === 10.0 @test (x->10*(x => 2)[2])'(100) === 0 end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 08dfd45db..eab959ddd 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -81,7 +81,7 @@ end @test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],) @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,) - @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im) + @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im)) # D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate end @@ -160,7 +160,7 @@ end # https://github.com/FluxML/Zygote.jl/issues/376 _, back = Zygote._pullback(x->x[1]*im, randn(2)) - @test back(1.0)[2] == [-im, 0] + @test back(1.0)[2] == real([-im, 0]) == [0, 0] # _droplike @test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],) diff --git a/test/utils.jl b/test/utils.jl index 9a3d83ea5..70a8ebd63 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -24,6 +24,7 @@ end xs, y = randn(2,3), rand() f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments + dx, dy = diaghessian(f34, xs, y) @test size(dx) == size(xs) @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))