From cef527fda4a8cb3bae88ba33e2d518081191757a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 12 Sep 2024 14:26:03 +0100 Subject: [PATCH 01/62] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 255943b8f..f135556c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.49" +version = "0.2.50" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From b9c3f650ddbff926482bdb27b68efe426e829d45 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 12 Sep 2024 14:37:49 +0100 Subject: [PATCH 02/62] Fix usage with benchmarktools --- src/interpreter/s2s_reverse_mode_ad.jl | 38 ++++++++++++------------- test/interpreter/s2s_reverse_mode_ad.jl | 6 ++++ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index b07d79c8b..92659df41 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -495,9 +495,9 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke mi = stmt.args[1]::Core.MethodInstance - LazyDerivedRule(info.interp, mi, info.safety_on) # Static dispatch + LazyDerivedRule(mi, info.safety_on) # Static dispatch else - DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch + DynamicDerivedRule(info.safety_on) # Dynamic dispatch end # Wrap the raw rule in a struct which ensures that any `ZeroRData`s are stripped @@ -1420,23 +1420,20 @@ of its arguments. Stores rules in an internal cache to avoid re-deriving. This is used to implement dynamic dispatch. =# -struct DynamicDerivedRule{T, V} - interp::T +struct DynamicDerivedRule{V} cache::V safety_on::Bool end -function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool) - return DynamicDerivedRule(interp, Dict{Any, Any}(), safety_on) -end +DynamicDerivedRule(safety_on::Bool) = DynamicDerivedRule(Dict{Any, Any}(), safety_on) -_copy(x::P) where {P<:DynamicDerivedRule} = P(x.interp, Dict{Any, Any}(), x.safety_on) +_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any, Any}(), x.safety_on) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} sig = Tuple{map(_typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing - rule = build_rrule(dynamic_rule.interp, sig; safety_on=dynamic_rule.safety_on) + rule = build_rrule(get_tapir_interpreter(), sig; safety_on=dynamic_rule.safety_on) dynamic_rule.cache[sig] = rule end return rule(args...) @@ -1460,26 +1457,27 @@ reason to keep this around is for debugging -- it is very helpful to have this t in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit. =# -mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, primal_sig, Trule} - interp::Tinterp +mutable struct LazyDerivedRule{primal_sig, Trule} safety_on::Bool mi::Core.MethodInstance rule::Trule - function LazyDerivedRule(interp::A, mi::Core.MethodInstance, safety_on::Bool) where {A} - return new{A, mi.specTypes, rule_type(interp, mi; safety_on)}(interp, safety_on, mi) + function LazyDerivedRule(mi::Core.MethodInstance, safety_on::Bool) + interp = get_tapir_interpreter() + return new{mi.specTypes, rule_type(interp, mi; safety_on)}(safety_on, mi) end - function LazyDerivedRule{Tinterp, Tprimal_sig, Trule}( - interp::Tinterp, mi::Core.MethodInstance, safety_on::Bool - ) where {Tinterp, Tprimal_sig, Trule} - return new{Tinterp, Tprimal_sig, Trule}(interp, safety_on, mi) + function LazyDerivedRule{Tprimal_sig, Trule}( + mi::Core.MethodInstance, safety_on::Bool + ) where {Tprimal_sig, Trule} + return new{Tprimal_sig, Trule}(safety_on, mi) end end -_copy(x::P) where {P<:LazyDerivedRule} = P(x.interp, x.mi, x.safety_on) +_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.safety_on) -function (rule::LazyDerivedRule{T, sig, Trule})(args::Vararg{Any, N}) where {N, T, sig, Trule} +function (rule::LazyDerivedRule{sig, Trule})(args::Vararg{Any, N}) where {N, sig, Trule} if !isdefined(rule, :rule) - derived_rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on) + interp = get_tapir_interpreter() + derived_rule = build_rrule(interp, rule.mi; safety_on=rule.safety_on) if derived_rule isa Trule rule.rule = derived_rule else diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 75b5a9bac..be9cd8203 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -274,5 +274,11 @@ end Xoshiro(123456), S2SGlobals.f, S2SGlobals.A(2 * ones(3)), ones(3); interface_only=false, is_primitive=false, ) + + # BenchmarkTools not working due to world age problems. Provided that this code + # runs successfully, everything is okay -- no need to check anything specific. + f(x) = sin(cos(x)) + rule = Tapir.build_rrule(f, 0.0) + @benchmark Tapir.value_and_gradient!!($rule, $f, $(Ref(0.0))[]) end end From 8f0f75d13e35f338f203ba47863ac6d593e5d5d0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 11:41:31 +0100 Subject: [PATCH 03/62] Initial pass --- src/Tapir.jl | 1 + src/chain_rules_macro.jl | 81 +++++++++++++++++++++++++-------------- test/chain_rules_macro.jl | 17 ++++++++ 3 files changed, 71 insertions(+), 28 deletions(-) diff --git a/src/Tapir.jl b/src/Tapir.jl index f4656ce3e..89d8f61a9 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -13,6 +13,7 @@ using Random, Setfield +# There are many clashing names, so we will always qualify uses of names from CRC. import ChainRulesCore using Base: diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index de5c64a02..502feffed 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -1,5 +1,55 @@ -_to_rdata(::ChainRulesCore.NoTangent) = NoRData() -_to_rdata(dx::Float64) = dx +""" + to_cr_tangent(t) + +Convert a Tapir tangent into a type that ChainRules.jl `rrule`s expect to see. +Inverse of `to_tapir_tangent`. +""" +to_cr_tangent(t::IEEEFloat) = t +to_cr_tangent(t::Array{<:IEEEFloat}) = t +to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() + +""" + to_tapir_tangent(cr_t) + +Convert a ChainRules.jl tangent, `cr_t`, into the corresponding Tapir tangent. +Inverse of `to_cr_tangent`. +""" +to_tapir_tangent(t::IEEEFloat) = t +to_tapir_tangent(t::Array{<:IEEEFloat}) = t +to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() + +""" + rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} + +Used to implement `rrule!!`s via `ChainRulesCore.rrule`. + +Given a function `foo`, argument types `arg_types`, and a method `ChainRulesCore.rrule` of +which applies to these, you can make use of this function as follows: +```julia +Tapir.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} +function Tapir.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) + return rrule_wrapper_implementation(f, args...) +end +``` +Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such that you +can convert between the different representations of tangents that Tapir and ChainRulesCore +expect. + +Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the +amount of boilerplate code that you are required to write even further. +""" +function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} + y_primal, cr_pb = ChainRulesCore.rrule(tuple_map(primal, fargs)...) + y_fdata = fdata(zero_tangent(y_primal)) + function pb!!(y_rdata) + cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + cr_dfargs = cr_pb(cr_tangent) + dfargs = tuple_map(to_tapir_tangent, cr_dfargs) + tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) + return tuple_map(rdata, dfargs) + end + return CoDual(y_primal, y_fdata), pb!! +end @doc""" @from_rrule ctx sig @@ -32,37 +82,12 @@ macro from_rrule(ctx, sig) arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols) arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) - call_rrule = Expr( - :call, - :(Tapir.ChainRulesCore.rrule), - map(n -> :(Tapir.primal($n)), arg_names)..., - ) - - pb_output_names = map(n -> Symbol("dx_$(n)_inc"), eachindex(arg_names)) - - call_pb = Expr(:(=), Expr(:tuple, pb_output_names...), :(pb(dy))) - incrementers = Expr(:tuple, map(b -> :(Tapir._to_rdata($b)), pb_output_names)...) - - pb = ExprTools.combinedef(Dict( - :head => :function, - :name => :pb!!, - :args => [:dy], - :body => quote - $call_pb - return $incrementers - end, - )) - rule_expr = ExprTools.combinedef( Dict( :head => :function, :name => :(Tapir.rrule!!), :args => arg_exprs, - :body => quote - y, pb = $call_rrule - $pb - return Tapir.zero_fcodual(y), pb!! - end, + :body => Expr(:call, rrule_wrapper_implementation, arg_names...), ) ) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 920cd0dd3..98f071613 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -1,3 +1,5 @@ +# Test case with isbits data. + bleh(x::Float64, y::Int) = x * y function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) @@ -6,6 +8,21 @@ end Tapir.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +# Test case with heap-allocated data. + +test_sum(x) = sum(x) + +function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) + test_sum_pb(dy::Real) = ChainRulesCore.NoTangent(), fill(dy, size(x)) + return test_sum(x), test_sum_pb +end + +Tapir.@is_primitive DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEEEFloat}}) + return Tapir.rrule_wrapper_implementation(f, x) +end + @testset "chain_rules_macro" begin Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) + Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) end From e791ceff1af48d4d86dbe379c28279657a4292c3 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 11:41:42 +0100 Subject: [PATCH 04/62] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f135556c0..65294a512 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.50" +version = "0.2.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From f45456e4947760a32a94bfe457c262b7c62bcb54 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 12:42:59 +0100 Subject: [PATCH 05/62] Unit test to_tapir_tangent and to_cr_tangent --- test/chain_rules_macro.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 98f071613..ad4b5ab31 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -23,6 +23,18 @@ function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEE end @testset "chain_rules_macro" begin - Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) - Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) + @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ + (5.0, 5.0), + (ones(5), ones(5)), + (NoTangent(), ChainRulesCore.NoTangent()), + ] + @test Tapir.to_cr_tangent(t) == t_cr + @test Tapir.to_tapir_tangent(t_cr) == t + @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t + @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr + end + @testset "rules" begin + Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) + Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) + end end From bec9f06ef919be3e40c4d1bf859ff09de6df0579 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 12:43:49 +0100 Subject: [PATCH 06/62] Make use of macro --- test/chain_rules_macro.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index ad4b5ab31..4b1857323 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -17,10 +17,7 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -Tapir.@is_primitive DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} -function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEEEFloat}}) - return Tapir.rrule_wrapper_implementation(f, x) -end +Tapir.@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} @testset "chain_rules_macro" begin @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ From d03710178cec3b79804ebbf47e4fb4d09da13996 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 12:59:43 +0100 Subject: [PATCH 07/62] More testing and tidying up --- src/chain_rules_macro.jl | 2 +- test/chain_rules_macro.jl | 49 ++++++++++++++++++++++++++++++++++----- test/front_matter.jl | 2 +- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 502feffed..097830a4b 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -92,7 +92,7 @@ macro from_rrule(ctx, sig) ) ex = quote - Tapir.is_primitive(::Type{$ctx}, ::Type{$sig}) = true + Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true $rule_expr end return esc(ex) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 4b1857323..9ab1ec983 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -1,3 +1,10 @@ +module ChainRulesInteropTestResources + +using ChainRulesCore, LinearAlgebra, Tapir + +using Base: IEEEFloat +using Tapir: DefaultCtx, @from_rrule + # Test case with isbits data. bleh(x::Float64, y::Int) = x * y @@ -6,9 +13,9 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -Tapir.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} -# Test case with heap-allocated data. +# Test case with heap-allocated input. test_sum(x) = sum(x) @@ -17,7 +24,33 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -Tapir.@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} + +# Test case with heap-allocated output. + +test_scale(x::Real, y::AbstractVector{<:Real}) = x * y + +function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) + function test_scale_pb(dout::AbstractVector{<:Real}) + return ChainRulesCore.NoTangent(), dot(dout, y), dout * x + end + return x * y, test_scale_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}} + +# Test case with non-differentiable type as output. + +test_nothing() = nothing + +function ChainRulesCore.rrule(::typeof(test_nothing)) + test_nothing_pb(::ChainRulesCore.NoTangent) = (ChainRulesCore.NoTangent(),) + return nothing, test_nothing_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_nothing)} + +end @testset "chain_rules_macro" begin @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ @@ -30,8 +63,12 @@ Tapir.@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr end - @testset "rules" begin - Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) - Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) + @testset "rules: $(typeof(fargs))" for fargs in Any[ + (ChainRulesInteropTestResources.bleh, 5.0, 4), + (ChainRulesInteropTestResources.test_sum, ones(5)), + (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), + (ChainRulesInteropTestResources.test_nothing,), + ] + test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end end diff --git a/test/front_matter.jl b/test/front_matter.jl index b42f66937..f78ca47e8 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -14,7 +14,7 @@ using import ChainRulesCore -using Base: unsafe_load, pointer_from_objref +using Base: unsafe_load, pointer_from_objref, IEEEFloat using Base.Iterators: product using Core: bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument From 54947f0b69b59288abbcde7bb1ff21ccc4386faa Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 13:10:35 +0100 Subject: [PATCH 08/62] Add some basic type checking and a test --- src/chain_rules_macro.jl | 7 +++++-- test/chain_rules_macro.jl | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 097830a4b..5fab9da46 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -39,12 +39,15 @@ Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to redu amount of boilerplate code that you are required to write even further. """ function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} - y_primal, cr_pb = ChainRulesCore.rrule(tuple_map(primal, fargs)...) + primals = tuple_map(primal, fargs) + tangent_types = tuple_map(x -> tangent_type(typeof(x)), primals) + y_primal, cr_pb = ChainRulesCore.rrule(primals...) y_fdata = fdata(zero_tangent(y_primal)) function pb!!(y_rdata) cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) cr_dfargs = cr_pb(cr_tangent) - dfargs = tuple_map(to_tapir_tangent, cr_dfargs) + dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) + dfargs = tuple_map(typeassert, dfargs_unvalidated, tangent_types) tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) return tuple_map(rdata, dfargs) end diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 9ab1ec983..7512b3e98 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -50,6 +50,19 @@ end @from_rrule DefaultCtx Tuple{typeof(test_nothing)} +# Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the +# perspective of Tapir.jl. In this instance, some kind of error should be thrown, rather +# than it being possible for the error to propagate. + +test_bad_rdata(x::Real) = 5x + +function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) + test_bad_rdata_pb(dy::Float64) = ChainRulesCore.NoTangent(), Float32(dy * 5) + return 5x, test_bad_rdata_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} + end @testset "chain_rules_macro" begin @@ -71,4 +84,9 @@ end ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end + @testset "bad rdata" begin + f = ChainRulesInteropTestResources.test_bad_rdata + out, pb!! = Tapir.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) + @test_throws TypeError pb!!(5.0) + end end From bc88483ea66be7df16f1b5afcddbea4ec0c5c710 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 13:19:27 +0100 Subject: [PATCH 09/62] Improve formatting and commenting --- src/chain_rules_macro.jl | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 5fab9da46..b72f21b7e 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -38,17 +38,39 @@ expect. Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ -function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +@inline function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. primals = tuple_map(primal, fargs) - tangent_types = tuple_map(x -> tangent_type(typeof(x)), primals) y_primal, cr_pb = ChainRulesCore.rrule(primals...) y_fdata = fdata(zero_tangent(y_primal)) + + # Construct functions which, when applied to the tangent types returned on the + # reverse-pass, will check that they are of the expected type. This will pick up on + # obvious problems, but is intended to be fast / optimised away when things go well. + # As such, you should think of this as a lightweight version of "debug_mode". + tangent_type_assertions = tuple_map( + x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals + ) + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. cr_dfargs = cr_pb(cr_tangent) + + # Convert output into tangent types appropriate for Tapir. dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - dfargs = tuple_map(typeassert, dfargs_unvalidated, tangent_types) + + # Apply type assertions. + dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) + + # Increment the fdata. tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) + + # Return the rdata. return tuple_map(rdata, dfargs) end return CoDual(y_primal, y_fdata), pb!! From f29b8f31c3c7ee022f89da101d1a59cf205dfe39 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:22:10 +0100 Subject: [PATCH 10/62] Formatting --- src/chain_rules_macro.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index b72f21b7e..10316993b 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -94,8 +94,8 @@ Use this function with care. It has only been tested for `Float64` arguments and whose `tangent_type` is `NoTangent`, and it is entirely probable that it won't work for arguments which aren't `Float64` or non-differentiable. -You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule created -works as intended. +You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule +created works as intended. """ macro from_rrule(ctx, sig) From 50d7dd83ade66e4dd1ba434736bc2e6b721c3ce6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:31:34 +0100 Subject: [PATCH 11/62] Improve documentation --- src/chain_rules_macro.jl | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 10316993b..9185eb57a 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -18,8 +18,8 @@ to_tapir_tangent(t::IEEEFloat) = t to_tapir_tangent(t::Array{<:IEEEFloat}) = t to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() -""" - rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +@doc""" + rrule_wrapper_implementation(f::CoDual, args::CoDual...) Used to implement `rrule!!`s via `ChainRulesCore.rrule`. @@ -35,6 +35,10 @@ Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such can convert between the different representations of tangents that Tapir and ChainRulesCore expect. +Furthermore, it is _essential_ that +1. `f(args)` does not mutate `f` or `args`, and +2. the result of `f(args)` does not alias any data stored in `f` or `args`. + Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ @@ -79,9 +83,8 @@ end @doc""" @from_rrule ctx sig -Creates a `Tapir.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in -which this rule should apply, and `sig` is the type-tuple which specifies which primal the -rule should apply to. +Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. +This macro is a thin wrapper around [`rrule_wrapper_implementation`](@ref). For example, ```julia @@ -89,13 +92,16 @@ For example, ``` would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. -Health warning: -Use this function with care. It has only been tested for `Float64` arguments and arguments -whose `tangent_type` is `NoTangent`, and it is entirely probable that it won't work for -arguments which aren't `Float64` or non-differentiable. +Limitations: it is your responsibility to ensure that +1. calls with signature `sig` do not mutate their arguments, +2. the output of calls with signature `sig` does not alias any of the inputs, +3. `sig` is a `Tuple{...}`, not a `Tuple{...} where {...}`. + +This last point is a limitation of the current implementation, rather than something +fundamental, whereas the first two points are more basic points. -You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule -created works as intended. +As with all hand-written rules, you should definitely make use of +[`TestUtils.test_rule`](@ref) to verify correctness on some test cases. """ macro from_rrule(ctx, sig) From 1788c07d6298fd4121c2ec1a6b0cc54dfb0f1049 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:52:10 +0100 Subject: [PATCH 12/62] Explain how not to use rrule functionality --- docs/make.jl | 7 +++++-- docs/src/using_chain_rules.md | 13 +++++++++++++ src/chain_rules_macro.jl | 22 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 docs/src/using_chain_rules.md diff --git a/docs/make.jl b/docs/make.jl index 88f352785..42aa87b31 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -32,9 +32,12 @@ makedocs( "Algorithmic Differentiation" => "algorithmic_differentiation.md", "Tapir.jl's Rule System" => "mathematical_interpretation.md", ], + "Utilities" => [ + "Using ChainRules" => "using_chain_rules.md", + "Safe Mode" => "safe_mode.md", + "Debugging and MWEs" => "debugging_and_mwes.md", + ], "Known Limitations" => "known_limitations.md", - "Safe Mode" => "safe_mode.md", - "Debugging and MWEs" => "debugging_and_mwes.md", ] ) diff --git a/docs/src/using_chain_rules.md b/docs/src/using_chain_rules.md new file mode 100644 index 000000000..531d4c4cf --- /dev/null +++ b/docs/src/using_chain_rules.md @@ -0,0 +1,13 @@ +# Using ChainRules.jl + +[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. +These rules are methods of the `ChainRulesCore.rrule` function. +There are some instances where there is it most convenient to implement a `Tapir.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. + +There is enough similarity between these two systems that most of the boilerplate code can be avoided. +The docstrings below explain this functionality, and how it should / should not be used. + +```@docs +Tapir.@from_rrule +Tapir.rrule_wrapper_implementation +``` diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 9185eb57a..9a6251251 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -102,6 +102,28 @@ fundamental, whereas the first two points are more basic points. As with all hand-written rules, you should definitely make use of [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. + +# A Note On Type Constraints + +Many methods of `ChainRuleCore.rrule` are implemented with very loose type constraints. +For example, it would not be surprising to see a method of rrule with the signature +```julia +Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} +``` +There are a variety of reasons for this way of doing things, and whether it is a good idea +to write rules for such generic objects has been debated at length. + +Suffice it to say, you should not write rules for this package which are so generically +typed. +Rather, you should create rules for the subset of types for which you believe that the +`ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the +rest. +For example, in the above case you might be confident that the rule will behave correctly +for input types `Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}}`. You should therefore +only write a rule for these types: +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} +``` """ macro from_rrule(ctx, sig) From b4e80bc0ca8f5f27c8756eab103f55d8b2f714c9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 15:05:39 +0100 Subject: [PATCH 13/62] Add rules for BLAS utilities --- src/rrules/blas.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 5b15fb085..5a995485c 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -19,7 +19,21 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 1 : 0) end +# +# Utility +# + +@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) = simple_zero_adjoint(f) +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) = simple_zero_adjoint(f) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) # # LEVEL 1 @@ -793,6 +807,12 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) test_cases = vcat( + # Utility + (false, :stability, nothing, BLAS.get_num_threads), + (false, :stability, nothing, BLAS.lbt_get_num_threads), + (false, :stability, nothing, BLAS.set_num_threads, 1), + (false, :stability, nothing, BLAS.lbt_set_num_threads, 1), + # # BLAS LEVEL 1 # From 4a2b8e0890d019ad54ba636b8391a9f1d73b66a2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 16:01:27 +0100 Subject: [PATCH 14/62] Initial NNlib integration --- Project.toml | 6 +++++- ext/TapirNNlibExt.jl | 12 ++++++++++++ test/front_matter.jl | 1 + test/integration_testing/nnlib.jl | 9 +++++++++ test/runtests.jl | 1 + 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 ext/TapirNNlibExt.jl create mode 100644 test/integration_testing/nnlib.jl diff --git a/Project.toml b/Project.toml index 65294a512..3e63a0a86 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] @@ -29,6 +30,7 @@ TapirCUDAExt = "CUDA" TapirDynamicPPLExt = "DynamicPPL" TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" +TapirNNlibExt = "NNlib" TapirSpecialFunctionsExt = "SpecialFunctions" [compat] @@ -47,6 +49,7 @@ Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" MistyClosures = "1" +NNlib = "0.9" PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" @@ -67,6 +70,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -76,4 +80,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl new file mode 100644 index 000000000..2b074a7d5 --- /dev/null +++ b/ext/TapirNNlibExt.jl @@ -0,0 +1,12 @@ +module TapirNNlibExt + + using NNlib, Tapir + using Base: IEEEFloat + + import Tapir: @from_rrule, DefaultCtx + + @from_rrule( + DefaultCtx, + Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, + ) +end diff --git a/test/front_matter.jl b/test/front_matter.jl index f78ca47e8..ec9bbdab8 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -5,6 +5,7 @@ using FillArrays, JET, LinearAlgebra, + NNlib, PDMats, Random, SpecialFunctions, diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl new file mode 100644 index 000000000..b12150777 --- /dev/null +++ b/test/integration_testing/nnlib.jl @@ -0,0 +1,9 @@ +@testset "nnlib" begin + @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + (:stability, NNlib.upsample_nearest, randn(3), (2,)), + (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), + (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + ] + test_rule(sr(1), fargs...; is_primitive=true, perf_flag) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e921cc49d..73cb208f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,6 +49,7 @@ include("front_matter.jl") include(joinpath("integration_testing", "battery_tests.jl")) include(joinpath("integration_testing", "dynamic_ppl.jl")) include(joinpath("integration_testing", "logdensityproblemsad_interop.jl")) + include(joinpath("integration_testing", "nnlib.jl")) include(joinpath("integration_testing", "special_functions.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) From d1d9fae42b7e83124b07d87b3f37b7d57bd4ae83 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 16:15:13 +0100 Subject: [PATCH 15/62] Thunks and batched_mul --- src/chain_rules_macro.jl | 1 + test/chain_rules_macro.jl | 6 ++++++ test/integration_testing/nnlib.jl | 1 + 3 files changed, 8 insertions(+) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 9a6251251..a3e88f390 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -17,6 +17,7 @@ Inverse of `to_cr_tangent`. to_tapir_tangent(t::IEEEFloat) = t to_tapir_tangent(t::Array{<:IEEEFloat}) = t to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() +to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) @doc""" rrule_wrapper_implementation(f::CoDual, args::CoDual...) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 7512b3e98..e0c85d4b7 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -76,6 +76,12 @@ end @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr end + + # The fact that I'm testing this separately suggests to me that there's something that + # I've not quite gotten right about the abstractions involved here. + @testset "ChainRulesCore.thunk" begin + @test Tapir.to_tapir_tangent(ChainRulesCore.Thunk(() -> ones(5))) == ones(5) + end @testset "rules: $(typeof(fargs))" for fargs in Any[ (ChainRulesInteropTestResources.bleh, 5.0, 4), (ChainRulesInteropTestResources.test_sum, ones(5)), diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index b12150777..11644344e 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,5 +1,6 @@ @testset "nnlib" begin @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + (:none, NNlib.batched_mul, randn(3, 2, 3), randn(2, 5, 3)), (:stability, NNlib.upsample_nearest, randn(3), (2,)), (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), From 6f036adcab39bf3b18105cf2f7cd22c228e17f60 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 17:18:23 +0100 Subject: [PATCH 16/62] More rules + kwargs + rename --- ext/TapirNNlibExt.jl | 15 ++++++++++ src/chain_rules_macro.jl | 49 +++++++++++++++++++++++++++---- test/integration_testing/nnlib.jl | 39 ++++++++++++++++++++---- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 2b074a7d5..75306f6e9 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -5,6 +5,21 @@ module TapirNNlibExt import Tapir: @from_rrule, DefaultCtx + @from_rrule( + DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, + ) @from_rrule( DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index a3e88f390..5983908e8 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -20,7 +20,7 @@ to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) @doc""" - rrule_wrapper_implementation(f::CoDual, args::CoDual...) + rrule_wrapper(f::CoDual, args::CoDual...) Used to implement `rrule!!`s via `ChainRulesCore.rrule`. @@ -29,7 +29,7 @@ which applies to these, you can make use of this function as follows: ```julia Tapir.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} function Tapir.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) - return rrule_wrapper_implementation(f, args...) + return rrule_wrapper(f, args...) end ``` Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such that you @@ -43,7 +43,7 @@ Furthermore, it is _essential_ that Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ -@inline function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) @@ -81,11 +81,50 @@ amount of boilerplate code that you are required to write even further. return CoDual(y_primal, y_fdata), pb!! end +function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. + primals = tuple_map(primal, fargs) + y_primal, cr_pb = Core.kwcall(primals[1], ChainRulesCore.rrule, primals[2:end]...) + y_fdata = fdata(zero_tangent(y_primal)) + + # Construct functions which, when applied to the tangent types returned on the + # reverse-pass, will check that they are of the expected type. This will pick up on + # obvious problems, but is intended to be fast / optimised away when things go well. + # As such, you should think of this as a lightweight version of "debug_mode". + tangent_type_assertions = tuple_map( + x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals[2:end] + ) + + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. + cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. + cr_dfargs = cr_pb(cr_tangent) + + # Convert output into tangent types appropriate for Tapir. + dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) + + # Apply type assertions. + dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) + + # Increment the fdata. + tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs[2:end], dfargs) + + # Return the rdata. + kwargs_rdata = rdata(zero_tangent(fargs[1])) + return NoRData(), kwargs_rdata, tuple_map(rdata, dfargs)... + end + return CoDual(y_primal, y_fdata), pb!! +end + @doc""" @from_rrule ctx sig Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. -This macro is a thin wrapper around [`rrule_wrapper_implementation`](@ref). +This macro is a thin wrapper around [`rrule_wrapper`](@ref). For example, ```julia @@ -141,7 +180,7 @@ macro from_rrule(ctx, sig) :head => :function, :name => :(Tapir.rrule!!), :args => arg_exprs, - :body => Expr(:call, rrule_wrapper_implementation, arg_names...), + :body => Expr(:call, rrule_wrapper, arg_names...), ) ) diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index 11644344e..4a6157748 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,10 +1,39 @@ @testset "nnlib" begin @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ - (:none, NNlib.batched_mul, randn(3, 2, 3), randn(2, 5, 3)), - (:stability, NNlib.upsample_nearest, randn(3), (2,)), - (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), - (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + + # batched_mul + (:none, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + + # softmax + (:stability, Core.kwcall, (dims=1, ), softmax, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (:none, x -> softmax(5x), randn(3, 2)), + (:none, x -> softmax(x; dims=1), randn(3, 2)), + (:none, x -> softmax(x; dims=2), randn(3, 2)), + (:none, x -> softmax(x; dims=(1, 2)), randn(3, 2)), + + # logsoftmax + (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + + # logsumexp + (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + + # upsample_nearest + (:stability, upsample_nearest, randn(3), (2,)), + (:stability, upsample_nearest, randn(3, 2), (2, 2)), + (:stability, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), ] - test_rule(sr(1), fargs...; is_primitive=true, perf_flag) + test_rule(sr(1), fargs...; is_primitive=false, perf_flag) end end From e225a0adafea729179d8255598ac8bad9873c2b9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 18:33:04 +0100 Subject: [PATCH 17/62] Fix link in docs --- docs/src/using_chain_rules.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/using_chain_rules.md b/docs/src/using_chain_rules.md index 531d4c4cf..e1726aad3 100644 --- a/docs/src/using_chain_rules.md +++ b/docs/src/using_chain_rules.md @@ -9,5 +9,5 @@ The docstrings below explain this functionality, and how it should / should not ```@docs Tapir.@from_rrule -Tapir.rrule_wrapper_implementation +Tapir.rrule_wrapper ``` From 3bba38ebbc9d53dd80e43574199c8ca9570eb6e0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 18:36:35 +0100 Subject: [PATCH 18/62] Rename chain_rules_macro to chain_rules_interop --- src/{chain_rules_macro.jl => chain_rules_interop.jl} | 0 test/{chain_rules_macro.jl => chain_rules_interop.jl} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/{chain_rules_macro.jl => chain_rules_interop.jl} (100%) rename test/{chain_rules_macro.jl => chain_rules_interop.jl} (100%) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_interop.jl similarity index 100% rename from src/chain_rules_macro.jl rename to src/chain_rules_interop.jl diff --git a/test/chain_rules_macro.jl b/test/chain_rules_interop.jl similarity index 100% rename from test/chain_rules_macro.jl rename to test/chain_rules_interop.jl From 619f0ce9ed2bb069fb05b28b61884e402d6bb97b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:56:24 +0100 Subject: [PATCH 19/62] Complete rename of chain rules interop file --- src/Tapir.jl | 2 +- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Tapir.jl b/src/Tapir.jl index 89d8f61a9..1af51cb0f 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -86,7 +86,7 @@ include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) include(joinpath("rrules", "tasks.jl")) -include("chain_rules_macro.jl") +include("chain_rules_interop.jl") include("interface.jl") export diff --git a/test/runtests.jl b/test/runtests.jl index 73cb208f7..619f6ea9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("front_matter.jl") @info "tasks" include(joinpath("rrules", "tasks.jl")) end - include("chain_rules_macro.jl") + include("chain_rules_interop.jl") elseif test_group == "integration_testing/misc" include(joinpath("integration_testing", "battery_tests.jl")) include(joinpath("integration_testing", "dynamic_ppl.jl")) From 345c46a0e7d58338b56a13c999e3817ebae7c919 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:58:37 +0100 Subject: [PATCH 20/62] Refactor chain rules interop --- src/chain_rules_interop.jl | 69 ++++++++++++------------------------- test/chain_rules_interop.jl | 14 +++++++- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 5983908e8..4bf5c90c0 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -1,23 +1,25 @@ -""" + """ to_cr_tangent(t) Convert a Tapir tangent into a type that ChainRules.jl `rrule`s expect to see. -Inverse of `to_tapir_tangent`. """ to_cr_tangent(t::IEEEFloat) = t to_cr_tangent(t::Array{<:IEEEFloat}) = t to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() """ - to_tapir_tangent(cr_t) + increment_and_get_rdata!(fdata, rdata, cr_tangent) -Convert a ChainRules.jl tangent, `cr_t`, into the corresponding Tapir tangent. -Inverse of `to_cr_tangent`. """ -to_tapir_tangent(t::IEEEFloat) = t -to_tapir_tangent(t::Array{<:IEEEFloat}) = t -to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() -to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) +increment_and_get_rdata!(::NoFData, r::T, t::T) where {T<:IEEEFloat} = r + t +function increment_and_get_rdata!(f::Array{P}, ::NoRData, t::Array{P}) where {P<:IEEEFloat} + increment!!(f, t) + return NoRData() +end +increment_and_get_rdata!(::Any, r, ::ChainRulesCore.NoTangent) = r +function increment_and_get_rdata!(f, r, t::ChainRulesCore.Thunk) + return increment_and_get_rdata!(f, r, ChainRulesCore.unthunk(t)) +end @doc""" rrule_wrapper(f::CoDual, args::CoDual...) @@ -47,17 +49,10 @@ function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(Tapir.lazy_zero_rdata, primals) y_primal, cr_pb = ChainRulesCore.rrule(primals...) y_fdata = fdata(zero_tangent(y_primal)) - # Construct functions which, when applied to the tangent types returned on the - # reverse-pass, will check that they are of the expected type. This will pick up on - # obvious problems, but is intended to be fast / optimised away when things go well. - # As such, you should think of this as a lightweight version of "debug_mode". - tangent_type_assertions = tuple_map( - x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals - ) - function pb!!(y_rdata) # Construct tangent w.r.t. output. @@ -66,17 +61,10 @@ function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run reverse-pass using ChainRules. cr_dfargs = cr_pb(cr_tangent) - # Convert output into tangent types appropriate for Tapir. - dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - - # Apply type assertions. - dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) - - # Increment the fdata. - tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) - - # Return the rdata. - return tuple_map(rdata, dfargs) + # Increment fdata and get rdata. + return map(fargs, lazy_rdata, cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end end return CoDual(y_primal, y_fdata), pb!! end @@ -85,17 +73,10 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) # Run forwards-pass. primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(lazy_zero_rdata, primals) y_primal, cr_pb = Core.kwcall(primals[1], ChainRulesCore.rrule, primals[2:end]...) y_fdata = fdata(zero_tangent(y_primal)) - # Construct functions which, when applied to the tangent types returned on the - # reverse-pass, will check that they are of the expected type. This will pick up on - # obvious problems, but is intended to be fast / optimised away when things go well. - # As such, you should think of this as a lightweight version of "debug_mode". - tangent_type_assertions = tuple_map( - x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals[2:end] - ) - function pb!!(y_rdata) # Construct tangent w.r.t. output. @@ -104,18 +85,12 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) # Run reverse-pass using ChainRules. cr_dfargs = cr_pb(cr_tangent) - # Convert output into tangent types appropriate for Tapir. - dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - - # Apply type assertions. - dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) - - # Increment the fdata. - tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs[2:end], dfargs) - - # Return the rdata. + # Increment fdata and compute rdata. kwargs_rdata = rdata(zero_tangent(fargs[1])) - return NoRData(), kwargs_rdata, tuple_map(rdata, dfargs)... + args_rdata = map(fargs[2:end], lazy_rdata[2:end], cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end + return NoRData(), kwargs_rdata, args_rdata... end return CoDual(y_primal, y_fdata), pb!! end diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index e0c85d4b7..66c56bec4 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -63,6 +63,16 @@ end @from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} +# Test case for rule with kwargs. +test_kwargs(x; y::Bool) = y ? x : 2x + +function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) + test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz + return y ? x : 2x, test_kwargs_pb +end + +@from_rrule DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(test_kwargs), Float64} + end @testset "chain_rules_macro" begin @@ -87,12 +97,14 @@ end (ChainRulesInteropTestResources.test_sum, ones(5)), (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), (ChainRulesInteropTestResources.test_nothing,), + (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end @testset "bad rdata" begin f = ChainRulesInteropTestResources.test_bad_rdata out, pb!! = Tapir.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) - @test_throws TypeError pb!!(5.0) + @test_throws MethodError pb!!(5.0) end end From 8e87d116983c0c8d3d0bc874ecdafb56b3faa1ad Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:58:46 +0100 Subject: [PATCH 21/62] Add more nnlib functionality --- ext/TapirNNlibExt.jl | 82 +++++++++++++++++++++++- test/integration_testing/nnlib.jl | 101 ++++++++++++++++++++++-------- 2 files changed, 157 insertions(+), 26 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 75306f6e9..8edd69f44 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -1,6 +1,6 @@ module TapirNNlibExt - using NNlib, Tapir + using NNlib, Random, Tapir using Base: IEEEFloat import Tapir: @from_rrule, DefaultCtx @@ -8,14 +8,28 @@ module TapirNNlibExt @from_rrule( DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(dropout), + AbstractRNG, + Array{<:IEEEFloat}, + IEEEFloat, + }, + ) + @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, ) + @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, ) + @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, @@ -24,4 +38,70 @@ module TapirNNlibExt DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(NNlib.fold), + Array{<:IEEEFloat}, + NTuple{N, Int} where {N}, + DenseConvDims, + }, + ) + @from_rrule(DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}) + @from_rrule( + DefaultCtx, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(NNlib.scatter), + Any, + Array, + Array{<:Union{Integer, Tuple}}, + }, + ) + + for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(NNlib.$(Symbol("$name$(backend)"))), + Array{<:IEEEFloat}, + Array{<:IEEEFloat}, + ConvDims, + }, + ) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(NNlib.$(Symbol("$name$(backend)"))), + Array{<:IEEEFloat}, + Array{<:IEEEFloat}, + ConvDims, + }, + ) + end + for pool in [:maxpool, :meanpool] + @eval @from_rrule(DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof($pool), + Array{<:IEEEFloat}, + PoolDims, + }, + ) + end + @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(pad_constant), Array, Any, Any}, + ) end diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index 4a6157748..0cfdbb4bf 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,39 +1,90 @@ @testset "nnlib" begin - @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + x = randn(5, 4, 3, 2) + w = randn(2, 2, 3, 3) + dense_cdims = DenseConvDims(x, w) + sep_cdims = DepthwiseConvDims(x, w) + pool_dims = PoolDims(size(x), 2) + + grid = Array{Float64}(undef, 2, 2, 2, 1) + grid[:, 1, 1, 1] .= (-1, -1) + grid[:, 2, 1, 1] .= (1, -1) + grid[:, 1, 2, 1] .= (-1, 1) + grid[:, 2, 2, 1] .= (1, 1) + + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ # batched_mul - (:none, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + (false, :none, true, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + + # dropout + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=1), randn(2, 2), 0.5), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=2), randn(2, 2), 0.1), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=(1, 2)), randn(2, 2), 0.4), # softmax - (:stability, Core.kwcall, (dims=1, ), softmax, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), - (:none, x -> softmax(5x), randn(3, 2)), - (:none, x -> softmax(x; dims=1), randn(3, 2)), - (:none, x -> softmax(x; dims=2), randn(3, 2)), - (:none, x -> softmax(x; dims=(1, 2)), randn(3, 2)), + (false, :stability, true, softmax, randn(2)), + (false, :stability, true, softmax, randn(2, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (false, :none, false, x -> softmax(5x), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=1), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=2), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=(1, 2)), randn(3, 2)), # logsoftmax - (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + (false, :stability, true, logsoftmax, randn(2)), + (false, :stability, true, logsoftmax, randn(2, 3)), + (false, :stability, true, logsoftmax, randn(2, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), # logsumexp - (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + (false, :stability, true, logsumexp, randn(2,)), + (false, :stability, true, logsumexp, randn(3, 3)), + (false, :stability, true, logsumexp, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), # upsample_nearest - (:stability, upsample_nearest, randn(3), (2,)), - (:stability, upsample_nearest, randn(3, 2), (2, 2)), - (:stability, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + (false, :stability, true, upsample_nearest, randn(3), (2,)), + (false, :stability, true, upsample_nearest, randn(3, 2), (2, 2)), + (false, :stability, true, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + + # fold + (false, :none, true, NNlib.fold, randn(12, 12, 2), size(x), dense_cdims), + + # unfold + (false, :none, true, NNlib.unfold, x, dense_cdims), + + # scatter + (false, :stability, true, NNlib.scatter, +, randn(2), [1, 3]), + (false, :stability, true, Core.kwcall, (;), NNlib.scatter, +, randn(2), [1, 3]), + + # conv + (false, :none, true, Core.kwcall, (;), conv, x, w, dense_cdims), + (false, :none, true, conv, x, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), depthwiseconv, x, w, sep_cdims), + (false, :none, true, depthwiseconv, x, w, sep_cdims), + + # pooling + (false, :none, true, maxpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), maxpool, x, pool_dims), + (false, :none, true, meanpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), meanpool, x, pool_dims), + + # padding + (false, :none, false, x -> pad_constant(x, 1, 2.0), x), + (false, :none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), ] - test_rule(sr(1), fargs...; is_primitive=false, perf_flag) + test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end From d3459782d6cc66d03cc5581694a524d5ddcce4b2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 14:00:13 +0100 Subject: [PATCH 22/62] Remove old tests --- test/chain_rules_interop.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 66c56bec4..368d683e9 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -76,15 +76,12 @@ end end @testset "chain_rules_macro" begin - @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ + @testset "to_cr_tangent" for (t, t_cr) in Any[ (5.0, 5.0), (ones(5), ones(5)), (NoTangent(), ChainRulesCore.NoTangent()), ] @test Tapir.to_cr_tangent(t) == t_cr - @test Tapir.to_tapir_tangent(t_cr) == t - @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t - @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr end # The fact that I'm testing this separately suggests to me that there's something that From 0f3fe90af1280eb2f3c3b0aff81211067fa0a943 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 19:12:20 +0100 Subject: [PATCH 23/62] Some work --- Project.toml | 7 ++- ext/TapirLuxLibExt.jl | 21 +++++++++ src/interpreter/s2s_reverse_mode_ad.jl | 2 + test/{integration_testing => ext}/cuda.jl | 0 .../dynamic_ppl.jl | 0 .../logdensityproblemsad.jl} | 0 test/ext/luxlib.jl | 8 ++++ test/{integration_testing => ext}/nnlib.jl | 0 .../special_functions.jl | 0 test/front_matter.jl | 2 + test/integration_testing/lux.jl | 44 +++++++++++++++++++ test/runtests.jl | 11 ++--- 12 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 ext/TapirLuxLibExt.jl rename test/{integration_testing => ext}/cuda.jl (100%) rename test/{integration_testing => ext}/dynamic_ppl.jl (100%) rename test/{integration_testing/logdensityproblemsad_interop.jl => ext/logdensityproblemsad.jl} (100%) create mode 100644 test/ext/luxlib.jl rename test/{integration_testing => ext}/nnlib.jl (100%) rename test/{integration_testing => ext}/special_functions.jl (100%) create mode 100644 test/integration_testing/lux.jl diff --git a/Project.toml b/Project.toml index 3e63a0a86..70b7646c4 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -30,6 +31,7 @@ TapirCUDAExt = "CUDA" TapirDynamicPPLExt = "DynamicPPL" TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" +TapirLuxLibExt = "LuxLib" TapirNNlibExt = "NNlib" TapirSpecialFunctionsExt = "SpecialFunctions" @@ -48,6 +50,7 @@ FillArrays = "1" Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" +LuxLib = "1.2" MistyClosures = "1" NNlib = "0.9" PDMats = "0.11" @@ -70,6 +73,8 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -80,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/ext/TapirLuxLibExt.jl b/ext/TapirLuxLibExt.jl new file mode 100644 index 000000000..83fe62307 --- /dev/null +++ b/ext/TapirLuxLibExt.jl @@ -0,0 +1,21 @@ +module TapirLuxLibExt + + using LuxLib, Random, Tapir + using Base: IEEEFloat + + import LuxLib.Impl: matmul, matmuladd, fused_dense + import Tapir: @from_rrule, DefaultCtx + + @from_rrule DefaultCtx Tuple{typeof(matmul), Array{<:IEEEFloat}, Array{<:IEEEFloat}} + @from_rrule( + DefaultCtx, + Tuple{typeof(matmuladd), Array{<:IEEEFloat}, Array{<:IEEEFloat}, Vector{<:IEEEFloat}}, + ) + + # The implementations of rrules for fused operations are not straightforward to + # incorporate into Tapir.jl, because they call back into AD. + # We take a simple appoach to their implementation: differentiate an un-fused version + # of their implementation. This will likely hit performance, but it makes implementing + # rules much more straightforward, in that we only have to be able to implement their + # constituent parts, rather than the entire thing. +end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 92659df41..217750aaa 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -843,6 +843,8 @@ function build_rrule( interp::TapirInterpreter{C}, sig_or_mi; safety_on=false, silence_safety_messages=true ) where {C} + @show sig_or_mi + # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. if Base.get_world_counter() > interp.world diff --git a/test/integration_testing/cuda.jl b/test/ext/cuda.jl similarity index 100% rename from test/integration_testing/cuda.jl rename to test/ext/cuda.jl diff --git a/test/integration_testing/dynamic_ppl.jl b/test/ext/dynamic_ppl.jl similarity index 100% rename from test/integration_testing/dynamic_ppl.jl rename to test/ext/dynamic_ppl.jl diff --git a/test/integration_testing/logdensityproblemsad_interop.jl b/test/ext/logdensityproblemsad.jl similarity index 100% rename from test/integration_testing/logdensityproblemsad_interop.jl rename to test/ext/logdensityproblemsad.jl diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl new file mode 100644 index 000000000..ce0e15c5e --- /dev/null +++ b/test/ext/luxlib.jl @@ -0,0 +1,8 @@ +@testset "luxlib" begin + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ + (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), + (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), + ] + test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) + end +end diff --git a/test/integration_testing/nnlib.jl b/test/ext/nnlib.jl similarity index 100% rename from test/integration_testing/nnlib.jl rename to test/ext/nnlib.jl diff --git a/test/integration_testing/special_functions.jl b/test/ext/special_functions.jl similarity index 100% rename from test/integration_testing/special_functions.jl rename to test/ext/special_functions.jl diff --git a/test/front_matter.jl b/test/front_matter.jl index ec9bbdab8..8bb01a516 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -5,6 +5,8 @@ using FillArrays, JET, LinearAlgebra, + Lux, + LuxLib, NNlib, PDMats, Random, diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl new file mode 100644 index 000000000..3f4778e4b --- /dev/null +++ b/test/integration_testing/lux.jl @@ -0,0 +1,44 @@ +@testset "lux" begin + @testset "$(typeof(f))" for (f, x_f32) in Any[ + (Dense(2, 4), randn(Float32, 2, 3)), + (Dense(2, 4, gelu), randn(Float32, 2, 3)), + # (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + # (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + # (Scale(2), randn(Float32, 2, 3)), + # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule + # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule + # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow + # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. Probably task again + # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again + # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again + # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + # (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + ] + @info "$(_typeof((f, x_f32...)))" + ps, st = f64(Lux.setup(sr(123456), f)) + x = f64(x_f32) + test_rule(sr(123456), f, x, ps, st; is_primitive=false) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 619f6ea9d..fc1681882 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,10 +47,11 @@ include("front_matter.jl") include("chain_rules_interop.jl") elseif test_group == "integration_testing/misc" include(joinpath("integration_testing", "battery_tests.jl")) - include(joinpath("integration_testing", "dynamic_ppl.jl")) - include(joinpath("integration_testing", "logdensityproblemsad_interop.jl")) - include(joinpath("integration_testing", "nnlib.jl")) - include(joinpath("integration_testing", "special_functions.jl")) + include(joinpath("ext", "dynamic_ppl.jl")) + include(joinpath("ext", "logdensityproblemsad.jl")) + include(joinpath("ext", "luxlib.jl")) + include(joinpath("ext", "nnlib.jl")) + include(joinpath("ext", "special_functions.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) elseif test_group == "integration_testing/diff_tests" @@ -66,7 +67,7 @@ include("front_matter.jl") elseif test_group == "integration_testing/temporalgps" include(joinpath("integration_testing", "temporalgps.jl")) elseif test_group == "gpu" - include(joinpath("integration_testing", "cuda.jl")) + include(joinpath("ext", "cuda.jl")) else throw(error("test_group=$(test_group) is not recognised")) end From ae93a27fcf1b105b69145e875e2082743f780b10 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:23:39 +0100 Subject: [PATCH 24/62] Remove errant show statment --- src/interpreter/s2s_reverse_mode_ad.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 217750aaa..92659df41 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -843,8 +843,6 @@ function build_rrule( interp::TapirInterpreter{C}, sig_or_mi; safety_on=false, silence_safety_messages=true ) where {C} - @show sig_or_mi - # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. if Base.get_world_counter() > interp.world From 82ecd82e41758debbac91c7f6fef78ac49d6e8cb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:36:23 +0100 Subject: [PATCH 25/62] Remove redundant test --- test/chain_rules_interop.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 368d683e9..5b2f72b85 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -83,12 +83,6 @@ end ] @test Tapir.to_cr_tangent(t) == t_cr end - - # The fact that I'm testing this separately suggests to me that there's something that - # I've not quite gotten right about the abstractions involved here. - @testset "ChainRulesCore.thunk" begin - @test Tapir.to_tapir_tangent(ChainRulesCore.Thunk(() -> ones(5))) == ones(5) - end @testset "rules: $(typeof(fargs))" for fargs in Any[ (ChainRulesInteropTestResources.bleh, 5.0, 4), (ChainRulesInteropTestResources.test_sum, ones(5)), From ca93535d1f055f3ee461780518423f8d59daff23 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:43:05 +0100 Subject: [PATCH 26/62] Support where --- src/chain_rules_interop.jl | 36 +++++++++++++++++++++--------------- test/chain_rules_interop.jl | 8 ++++++++ 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 4bf5c90c0..8abc3e767 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -109,11 +109,7 @@ would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCo Limitations: it is your responsibility to ensure that 1. calls with signature `sig` do not mutate their arguments, -2. the output of calls with signature `sig` does not alias any of the inputs, -3. `sig` is a `Tuple{...}`, not a `Tuple{...} where {...}`. - -This last point is a limitation of the current implementation, rather than something -fundamental, whereas the first two points are more basic points. +2. the output of calls with signature `sig` does not alias any of the inputs. As with all hand-written rules, you should definitely make use of [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. @@ -142,22 +138,32 @@ only write a rule for these types: """ macro from_rrule(ctx, sig) - @assert sig.head == :curly - @assert sig.args[1] == :Tuple - arg_type_symbols = sig.args[2:end] + if sig.head == :curly + @assert sig.args[1] == :Tuple + arg_type_symbols = sig.args[2:end] + where_params = nothing + elseif sig.head == :where + @assert sig.args[1].args[1] == :Tuple + arg_type_symbols = sig.args[1].args[2:end] + where_params = sig.args[2:end] + else + throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) + end arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols) arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) - rule_expr = ExprTools.combinedef( - Dict( - :head => :function, - :name => :(Tapir.rrule!!), - :args => arg_exprs, - :body => Expr(:call, rrule_wrapper, arg_names...), - ) + def = Dict( + :head => :function, + :name => :(Tapir.rrule!!), + :args => arg_exprs, + :body => Expr(:call, rrule_wrapper, arg_names...), ) + if where_params !== nothing + def[:whereparams] = where_params + end + rule_expr = ExprTools.combinedef(def) ex = quote Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 5b2f72b85..039398bcb 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -63,6 +63,14 @@ end @from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} +# Test case for rule with diagonal dispatch. +test_add(x, y) = x + y +function ChainRulesCore.rrule(::typeof(test_add), x, y) + test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout + return x + y, test_add_pb +end +@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} + # Test case for rule with kwargs. test_kwargs(x; y::Bool) = y ? x : 2x From fc6c00fcf9a4ea1646a667f1057c1b44b1aaf5ae Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:58:23 +0100 Subject: [PATCH 27/62] Make use of where params --- ext/TapirNNlibExt.jl | 23 +++++++++++------------ test/ext/nnlib.jl | 1 + test/front_matter.jl | 2 ++ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 8edd69f44..4a1f2675e 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -2,11 +2,13 @@ module TapirNNlibExt using NNlib, Random, Tapir using Base: IEEEFloat + using NNlib: dropout import Tapir: @from_rrule, DefaultCtx @from_rrule( - DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} + DefaultCtx, + Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) @from_rrule( DefaultCtx, @@ -15,9 +17,9 @@ module TapirNNlibExt NamedTuple, typeof(dropout), AbstractRNG, - Array{<:IEEEFloat}, - IEEEFloat, - }, + Array{P}, + P, + } where {P<:IEEEFloat}, ) @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) @from_rrule( @@ -71,19 +73,16 @@ module TapirNNlibExt typeof(Core.kwcall), NamedTuple, typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{<:IEEEFloat}, - Array{<:IEEEFloat}, + Array{P}, + Array{P}, ConvDims, - }, + } where {P<:IEEEFloat}, ) @eval @from_rrule( DefaultCtx, Tuple{ - typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{<:IEEEFloat}, - Array{<:IEEEFloat}, - ConvDims, - }, + typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, + } where {P<:IEEEFloat}, ) end for pool in [:maxpool, :meanpool] diff --git a/test/ext/nnlib.jl b/test/ext/nnlib.jl index 0cfdbb4bf..1f00d9f52 100644 --- a/test/ext/nnlib.jl +++ b/test/ext/nnlib.jl @@ -85,6 +85,7 @@ (false, :none, false, x -> pad_constant(x, 1, 2.0), x), (false, :none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), ] + @info "$(typeof(fargs))" test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/front_matter.jl b/test/front_matter.jl index 8bb01a516..56ac344e4 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -23,6 +23,8 @@ using Core: bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument using Core.Intrinsics: pointerref, pointerset +using NNlib: dropout + using Tapir: CC, IntrinsicsWrappers, From 473bc0288cd5d50882a9ad99a09b65688c3cf4f6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 14:53:31 +0100 Subject: [PATCH 28/62] Improve kwarg interface --- ext/TapirLuxLibExt.jl | 4 +- ext/TapirNNlibExt.jl | 76 ++++++------------------------------- src/chain_rules_interop.jl | 49 ++++++++++++++++++------ test/chain_rules_interop.jl | 16 ++++---- 4 files changed, 61 insertions(+), 84 deletions(-) diff --git a/ext/TapirLuxLibExt.jl b/ext/TapirLuxLibExt.jl index 83fe62307..37eb7272b 100644 --- a/ext/TapirLuxLibExt.jl +++ b/ext/TapirLuxLibExt.jl @@ -6,10 +6,10 @@ module TapirLuxLibExt import LuxLib.Impl: matmul, matmuladd, fused_dense import Tapir: @from_rrule, DefaultCtx - @from_rrule DefaultCtx Tuple{typeof(matmul), Array{<:IEEEFloat}, Array{<:IEEEFloat}} + @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( DefaultCtx, - Tuple{typeof(matmuladd), Array{<:IEEEFloat}, Array{<:IEEEFloat}, Vector{<:IEEEFloat}}, + Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, ) # The implementations of rrules for fused operations are not straightforward to diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 4a1f2675e..5bc4b1f19 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -12,30 +12,12 @@ module TapirNNlibExt ) @from_rrule( DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(dropout), - AbstractRNG, - Array{P}, - P, - } where {P<:IEEEFloat}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, + Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, + true, ) + @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) + @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) + @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) @from_rrule( DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, @@ -43,64 +25,30 @@ module TapirNNlibExt @from_rrule( DefaultCtx, Tuple{ - typeof(NNlib.fold), - Array{<:IEEEFloat}, - NTuple{N, Int} where {N}, - DenseConvDims, + typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, }, ) - @from_rrule(DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}) @from_rrule( - DefaultCtx, - Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} ) @from_rrule( DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(NNlib.scatter), - Any, - Array, - Array{<:Union{Integer, Tuple}}, - }, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + true, ) - for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) - @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{P}, - Array{P}, - ConvDims, - } where {P<:IEEEFloat}, - ) @eval @from_rrule( DefaultCtx, Tuple{ typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, } where {P<:IEEEFloat}, + true, ) end for pool in [:maxpool, :meanpool] - @eval @from_rrule(DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}) @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof($pool), - Array{<:IEEEFloat}, - PoolDims, - }, + DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true ) end - @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(pad_constant), Array, Any, Any}, - ) + @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) end diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 8abc3e767..8beba59d9 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -96,7 +96,7 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) end @doc""" - @from_rrule ctx sig + @from_rrule ctx sig [has_kwargs=false] Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. This macro is a thin wrapper around [`rrule_wrapper`](@ref). @@ -107,6 +107,11 @@ For example, ``` would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), Float64} true +``` +would define a method of `Tapir.rrule!!` which can handle keyword arguments. + Limitations: it is your responsibility to ensure that 1. calls with signature `sig` do not mutate their arguments, 2. the output of calls with signature `sig` does not alias any of the inputs. @@ -136,8 +141,9 @@ only write a rule for these types: @from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} ``` """ -macro from_rrule(ctx, sig) +macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) + # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. if sig.head == :curly @assert sig.args[1] == :Tuple arg_type_symbols = sig.args[2:end] @@ -152,8 +158,35 @@ macro from_rrule(ctx, sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols) - arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) + rule_expr = construct_def(arg_names, arg_types, where_params) + + if has_kwargs + kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_symbols...) + kw_sig = where_params === nothing ? kw_sig : Expr(:where, kw_sig, where_params...) + kw_is_primitive = :(Tapir.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) + kwcall_type = :(Tapir.CoDual{typeof(Core.kwcall)}) + nt_type = :(Tapir.CoDual{<:NamedTuple}) + kwargs_rule_expr = construct_def( + vcat(:_kwcall, :kwargs, arg_names), + vcat(kwcall_type, nt_type, arg_types), + where_params, + ) + else + kw_is_primitive = nothing + kwargs_rule_expr = nothing + end + ex = quote + Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + $rule_expr + $kw_is_primitive + $kwargs_rule_expr + end + return esc(ex) +end + +function construct_def(arg_names, arg_types, where_params) + arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) def = Dict( :head => :function, :name => :(Tapir.rrule!!), @@ -163,11 +196,5 @@ macro from_rrule(ctx, sig) if where_params !== nothing def[:whereparams] = where_params end - rule_expr = ExprTools.combinedef(def) - - ex = quote - Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true - $rule_expr - end - return esc(ex) -end + return ExprTools.combinedef(def) +end \ No newline at end of file diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 039398bcb..bdb10fd16 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -13,7 +13,7 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} false # Test case with heap-allocated input. @@ -24,7 +24,7 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} false # Test case with heap-allocated output. @@ -37,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{< return x * y, test_scale_pb end -@from_rrule DefaultCtx Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}} +@from_rrule( + DefaultCtx, Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}}, false +) # Test case with non-differentiable type as output. @@ -48,7 +50,7 @@ function ChainRulesCore.rrule(::typeof(test_nothing)) return nothing, test_nothing_pb end -@from_rrule DefaultCtx Tuple{typeof(test_nothing)} +@from_rrule DefaultCtx Tuple{typeof(test_nothing)} false # Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the # perspective of Tapir.jl. In this instance, some kind of error should be thrown, rather @@ -61,7 +63,7 @@ function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) return 5x, test_bad_rdata_pb end -@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} false # Test case for rule with diagonal dispatch. test_add(x, y) = x + y @@ -69,7 +71,7 @@ function ChainRulesCore.rrule(::typeof(test_add), x, y) test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout return x + y, test_add_pb end -@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false # Test case for rule with kwargs. test_kwargs(x; y::Bool) = y ? x : 2x @@ -79,7 +81,7 @@ function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) return y ? x : 2x, test_kwargs_pb end -@from_rrule DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(test_kwargs), Float64} +@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs), Float64}, true) end From 1cfbfcca31b5d49161732c252a40f081343d9e64 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 14:54:41 +0100 Subject: [PATCH 29/62] Default kwargs test --- test/chain_rules_interop.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index bdb10fd16..1f77d54a9 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -74,9 +74,9 @@ end @from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false # Test case for rule with kwargs. -test_kwargs(x; y::Bool) = y ? x : 2x +test_kwargs(x; y::Bool=false) = y ? x : 2x -function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) +function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz return y ? x : 2x, test_kwargs_pb end @@ -100,6 +100,7 @@ end (ChainRulesInteropTestResources.test_nothing,), (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (ChainRulesInteropTestResources.test_kwargs, 5.0), ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end From 8ac290342d75f86b7ea963f0c4ac52a87adb4fa1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 14:59:38 +0100 Subject: [PATCH 30/62] Improve docstring --- src/chain_rules_interop.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 8beba59d9..4b3fc7e71 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -8,8 +8,10 @@ to_cr_tangent(t::Array{<:IEEEFloat}) = t to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() """ - increment_and_get_rdata!(fdata, rdata, cr_tangent) + increment_and_get_rdata!(fdata, zero_rdata, cr_tangent) +Increment `fdata` by the fdata component of the ChainRules.jl-style tangent, `cr_tangent`, +and return the rdata component of `cr_tangent` by adding it to `zero_rdata`. """ increment_and_get_rdata!(::NoFData, r::T, t::T) where {T<:IEEEFloat} = r + t function increment_and_get_rdata!(f::Array{P}, ::NoRData, t::Array{P}) where {P<:IEEEFloat} From ce5afd9e61f692a47cedcdcf7bd256ea65e24f44 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 25 Sep 2024 09:24:00 +0100 Subject: [PATCH 31/62] Some work --- ext/MooncakeLuxLibExt.jl | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 3bf398709..a7ac94e1c 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -1,21 +1,30 @@ module MooncakeLuxLibExt - using LuxLib, Random, Mooncake - using Base: IEEEFloat +using LuxLib, Random, Mooncake +using Base: IEEEFloat - import LuxLib.Impl: matmul, matmuladd, fused_dense - import Mooncake: @from_rrule, DefaultCtx +import LuxLib.Impl: matmul, matmuladd, fused_dense +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter - @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) - @from_rrule( - DefaultCtx, - Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, - ) +@from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) +@from_rrule( + DefaultCtx, + Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, +) + +# Unfused version of `fused_dense`, which `build_rrule` makes use of. +function unfused_dense( + opmode, + act::F, + weight::AbstractMatrix, + x::AbstractMatrix, + b::LuxLib.Optional{<:AbstractVector}, +) where {F} + return bias_activation(act, matmul(opmode, weight, x), b) +end + +function Mooncake.build_rrule(interp::MooncakeInterpreter, sig_or_mi; kwargs...) + return Mooncake.build +end - # The implementations of rrules for fused operations are not straightforward to - # incorporate into Mooncake.jl, because they call back into AD. - # We take a simple appoach to their implementation: differentiate an un-fused version - # of their implementation. This will likely hit performance, but it makes implementing - # rules much more straightforward, in that we only have to be able to implement their - # constituent parts, rather than the entire thing. end From 6edc9a4677e5686605c353397ebe28326820616f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 30 Sep 2024 18:52:17 +0100 Subject: [PATCH 32/62] Some work --- ext/MooncakeLuxLibExt.jl | 38 +++++++++++++++++----- src/interpreter/abstract_interpretation.jl | 23 +++++++++++-- src/interpreter/ir_utils.jl | 38 +++++++++++++++++----- src/rrules/fastmath.jl | 18 +++++----- test/ext/luxlib.jl | 25 +++++++++++--- test/integration_testing/lux.jl | 32 +++++++++--------- test/runtests.jl | 1 + 7 files changed, 126 insertions(+), 49 deletions(-) diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 3add54e06..5ce8ba8fa 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -2,9 +2,10 @@ module MooncakeLuxLibExt using LuxLib, Random, Mooncake using Base: IEEEFloat +using Base.Experimental: @overlay -import LuxLib.Impl: matmul, matmuladd, fused_dense -import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter +import LuxLib.Impl: matmul, matmuladd +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( @@ -12,19 +13,40 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, ) -# Unfused version of `fused_dense`, which `build_rrule` makes use of. -function unfused_dense( +# Re-implement a bunch of methods to ensure that Mooncake can differentiate them. +@overlay mooncake_method_table function LuxLib.Impl.fused_dense( opmode, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::LuxLib.Optional{<:AbstractVector}, ) where {F} - return bias_activation(act, matmul(opmode, weight, x), b) + return bias_activation(act, matmul(weight, x), b) end -# function Mooncake.build_rrule(interp::MooncakeInterpreter, sig_or_mi; kwargs...) -# return Mooncake.build -# end +@overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( + y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector +) where {F, xT, yT} + return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias) +end + +@overlay mooncake_method_table function LuxLib.Impl.activation_loop!( + y::AbstractArray, σ::F, x::AbstractArray +) where {F} + return LuxLib.Impl.activation_simd_loop!(y, σ, x) +end + +@overlay mooncake_method_table function LuxLib.Impl.fused_conv( + ::LuxLib.Impl.AbstractInternalArrayOpMode, + act::F, + weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, + bias::LuxLib.Optional{<:AbstractVector}, + cdims::LuxLib.Impl.ConvDims, +) where {F, wT, xT, N} + return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) +end + +# IMPORT SLEEFPirates RULES! Use a loop. end diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 745ed649a..90f716219 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -5,7 +5,6 @@ # The most important bit of this code is `inlining_policy` -- the rest is copy + pasted # boiler plate, largely taken from https://github.com/JuliaLang/julia/blob/2fe4190b3d26b4eee52b2b1b1054ddd6e38a941e/test/compiler/newinterp.jl#L11 - struct ClosureCacheKey world_age::UInt key::Any @@ -17,6 +16,8 @@ end MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance, Core.CodeInstance}()) +Base.Experimental.@MethodTable mooncake_method_table + struct MooncakeInterpreter{C} <: CC.AbstractInterpreter meta # additional information world::UInt @@ -25,6 +26,7 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey, Any} + method_table_to_overlay::CC.MethodTable function MooncakeInterpreter( ::Type{C}; meta=nothing, @@ -34,8 +36,18 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(), + method_table_to_overlay::CC.MethodTable=mooncake_method_table, ) where {C} - return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) + return new{C}( + meta, + world, + inf_params, + opt_params, + inf_cache, + code_cache, + oc_cache, + method_table_to_overlay, + ) end end @@ -91,6 +103,9 @@ function CC.setindex!( ) return setindex!(wvc.cache.dict, ci, mi) end +function CC.method_table(interp::MooncakeInterpreter) + return CC.OverlayMethodTable(interp.world, interp.method_table_to_overlay) +end _type(x) = x _type(x::CC.Const) = _typeof(x.val) @@ -108,7 +123,9 @@ function CC.inlining_policy( # Do not inline away primitives. argtype_tuple = Tuple{map(_type, argtypes)...} - is_primitive(C, argtype_tuple) && return nothing + if is_primitive(C, argtype_tuple) + return nothing + end # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. return @invoke CC.inlining_policy( diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index bae39e46c..9370f11c7 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -170,6 +170,9 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) return ir end +Base.iterate(x::CC.MethodLookupResult) = CC.iterate(x) +Base.iterate(x::CC.MethodLookupResult, n::Int) = CC.iterate(x, n) + """ lookup_ir( interp::AbstractInterpreter, @@ -181,18 +184,35 @@ there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. """ -function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) - output = Base.code_ircode_by_type(sig; interp) - if isempty(output) - throw(ArgumentError("No methods found for signature $sig")) - elseif length(output) > 1 - throw(ArgumentError("$(length(output)) methods found for signature $sig")) +function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_until=nothing) + matches = CC.findall(tt, CC.method_table(interp)) + asts = [] + for match in matches.matches + match = match::Core.MethodMatch + meth = Base.func_for_method_checked(match.method, tt, match.sparams) + (code, ty) = CC.typeinf_ircode( + interp, + meth, + match.spec_types, + match.sparams, + optimize_until, + ) + if code === nothing + push!(asts, match.method => Any) + else + push!(asts, code => ty) + end + end + if isempty(asts) + throw(ArgumentError("No methods found for signature $asts")) + elseif length(asts) > 1 + throw(ArgumentError("$(length(asts)) methods found for signature $sig")) end - return only(output) + return only(asts) end -function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance) - return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing) +function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance; optimize_until=nothing) + return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, optimize_until) end """ diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 93c7e17aa..26811f6bc 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -1,21 +1,21 @@ -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp_fast(primal(x)) - exp_fast_pb!!(dy::Float64) = NoRData(), dy * yp + exp_fast_pb!!(dy::P) = NoRData(), dy * yp return CoDual(yp, NoFData()), exp_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp2_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(2) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2) return CoDual(yp, NoFData()), exp2_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp10_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(10) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10) return CoDual(yp, NoFData()), exp2_fast_pb!! end diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index ce0e15c5e..46c6396a6 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -1,8 +1,25 @@ @testset "luxlib" begin - @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ - (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), - (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), - ] + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in vcat( + Any[ + (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), + (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), + (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), + ( + false, :none, false, + LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3), + ), + ], + vec(map(Iterators.product( + [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + [randn(5), nothing], + [Lux.relu, tanh, NNlib.gelu], + )) do (opmode, bias, activation) + ( + false, :none, false, + LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, + ) + end), + ) test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl index 3f4778e4b..e191751ba 100644 --- a/test/integration_testing/lux.jl +++ b/test/integration_testing/lux.jl @@ -2,18 +2,18 @@ @testset "$(typeof(f))" for (f, x_f32) in Any[ (Dense(2, 4), randn(Float32, 2, 3)), (Dense(2, 4, gelu), randn(Float32, 2, 3)), - # (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), - # (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), - # (Scale(2), randn(Float32, 2, 3)), - # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + (Scale(2), randn(Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule - # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow @@ -26,8 +26,8 @@ # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression - # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. Probably task again - # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # another task problem + # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # task again # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), @@ -37,8 +37,8 @@ # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] @info "$(_typeof((f, x_f32...)))" - ps, st = f64(Lux.setup(sr(123456), f)) - x = f64(x_f32) - test_rule(sr(123456), f, x, ps, st; is_primitive=false) + ps, st = f32(Lux.setup(sr(123456), f)) + x = f32(x_f32) + test_rule(sr(123456), f, x, ps, st; is_primitive=false, interface_only=true) end end diff --git a/test/runtests.jl b/test/runtests.jl index 244030955..75d73aa53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ include("front_matter.jl") include(joinpath("ext", "luxlib.jl")) include(joinpath("ext", "nnlib.jl")) include(joinpath("ext", "special_functions.jl")) + include(joinpath("integration_testing", "lux.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) elseif test_group == "integration_testing/diff_tests" From f66cc9cdb1e5bb45a65657c6e41991fe4587cbe6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 10:15:37 +0100 Subject: [PATCH 33/62] Better conv support in nnlib rules --- ext/MooncakeNNlibExt.jl | 45 +++++++++++++++++++++++++---------------- test/ext/nnlib.jl | 13 ++++++++++++ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl index 706a627b2..fbd1b2fa7 100644 --- a/ext/MooncakeNNlibExt.jl +++ b/ext/MooncakeNNlibExt.jl @@ -4,51 +4,62 @@ module MooncakeNNlibExt using Base: IEEEFloat using NNlib: dropout - import Mooncake: @from_rrule, DefaultCtx + using NNlib: conv, depthwiseconv + import Mooncake: @from_rrule, DefaultCtx, MinimalCtx @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, true, ) - @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) - @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) - @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{ typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, }, ) @from_rrule( - DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} + MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, true, ) - for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) + for conv in [:conv, :depthwiseconv] + local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) + + @eval @from_rrule( + MinimalCtx, + Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, - } where {P<:IEEEFloat}, + MinimalCtx, + Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, true, ) end + @eval @from_rrule( + MinimalCtx, + Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) for pool in [:maxpool, :meanpool] @eval @from_rrule( - DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true + MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true ) end - @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) + @from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) end diff --git a/test/ext/nnlib.jl b/test/ext/nnlib.jl index 1f00d9f52..2a3c3fde5 100644 --- a/test/ext/nnlib.jl +++ b/test/ext/nnlib.jl @@ -3,6 +3,9 @@ w = randn(2, 2, 3, 3) dense_cdims = DenseConvDims(x, w) sep_cdims = DepthwiseConvDims(x, w) + y = conv(x, w, dense_cdims) + y_sep = depthwiseconv(x, w, sep_cdims) + pool_dims = PoolDims(size(x), 2) grid = Array{Float64}(undef, 2, 2, 2, 1) @@ -75,6 +78,16 @@ (false, :none, true, Core.kwcall, (;), depthwiseconv, x, w, sep_cdims), (false, :none, true, depthwiseconv, x, w, sep_cdims), + # ∇conv_data + (false, :none, true, Core.kwcall, (;), ∇conv_data, y, w, dense_cdims), + (false, :none, true, ∇conv_data, y, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), ∇depthwiseconv_data, y_sep, w, sep_cdims), + (false, :none, true, ∇depthwiseconv_data, y_sep, w, sep_cdims), + + # ∇conv_filter + (false, :none, true, Core.kwcall, (;), ∇conv_filter, x, y, dense_cdims), + (false, :none, true, ∇conv_filter, x, y, dense_cdims), + # pooling (false, :none, true, maxpool, x, pool_dims), (false, :none, true, Core.kwcall, (;), maxpool, x, pool_dims), From f865fde5f0292ca53c29f941cd7491e1d6ff1cdd Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:27:10 +0100 Subject: [PATCH 34/62] More LuxLib rules --- ext/MooncakeLuxLibExt.jl | 136 ++++++++++++++++++++++++++++++-- test/ext/luxlib.jl | 47 ++++++++--- test/front_matter.jl | 1 + test/integration_testing/lux.jl | 45 +++++------ 4 files changed, 191 insertions(+), 38 deletions(-) diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 5ce8ba8fa..1bfa63db7 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -4,13 +4,18 @@ using LuxLib, Random, Mooncake using Base: IEEEFloat using Base.Experimental: @overlay -import LuxLib.Impl: matmul, matmuladd -import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table +import LuxLib: Impl +import LuxLib.Utils: static_training_mode_check +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table, CoDual -@from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) +@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( DefaultCtx, - Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, + Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, +) +@from_rrule( + DefaultCtx, + Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) # Re-implement a bunch of methods to ensure that Mooncake can differentiate them. @@ -21,7 +26,7 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_t x::AbstractMatrix, b::LuxLib.Optional{<:AbstractVector}, ) where {F} - return bias_activation(act, matmul(weight, x), b) + return bias_activation(act, Impl.matmul(weight, x), b) end @overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( @@ -47,6 +52,125 @@ end return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) end -# IMPORT SLEEFPirates RULES! Use a loop. +for f in [ + Impl.SLEEFActivations.sigmoid_fast, + Impl.SLEEFActivations.softplus, + Impl.SLEEFActivations.logsigmoid, + Impl.SLEEFActivations.swish, + Impl.SLEEFActivations.lisht, + Impl.SLEEFActivations.tanh, + Impl.SLEEFActivations.tanh_fast, +] + @from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat} + @from_rrule( + DefaultCtx, + Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}}, + ) +end + +Mooncake.@is_primitive(DefaultCtx, Tuple{typeof(static_training_mode_check), Vararg}) +function Mooncake.rrule!!(f::CoDual{typeof(static_training_mode_check)}, x::CoDual...) + return Mooncake.simple_zero_adjoint(f, x...) +end + + + + +# This is a really horrible hack that we need to do until Mooncake is able to support the +# call-back-into-ad interface that ChainRules exposes. + +import LuxLib.Impl: + safe_eltype, + batchnorm_affine_normalize_internal, + batchnorm_affine_normalize_internal!, + ∇batchnorm_affine_normalize, + AbstractInternalArrayOpMode + +import ChainRulesCore as CRC + +function CRC.rrule( + ::typeof(batchnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, + ::typeof(identity), + x::AbstractArray{T, N}, + μ::AbstractVector, + σ²::AbstractVector, + γ::LuxLib.Optional{<:AbstractVector}, + β::LuxLib.Optional{<:AbstractVector}, + ϵ::Real, +) where {T, N} + y = similar( + x, + promote_type( + safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) + ) + ) + γ′ = similar( + x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1) + ) + + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) + + ∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′) + ∂∅ = CRC.NoTangent() + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return y, ∇batchnorm_affine_normalize_internal +end + +@from_rrule( + DefaultCtx, + Tuple{ + typeof(batchnorm_affine_normalize_internal), + AbstractInternalArrayOpMode, + typeof(identity), + AbstractArray, + AbstractVector, + AbstractVector, + LuxLib.Optional{<:AbstractVector}, + LuxLib.Optional{<:AbstractVector}, + Real, + }, +) + +@overlay mooncake_method_table function batchnorm_affine_normalize_internal( + opmode::LuxLib.AbstractInternalArrayOpMode, + act::F, + x::AbstractArray{xT, 3}, + μ::AbstractVector, + σ²::AbstractVector, + γ::Union{Nothing, AbstractVector}, + β::Union{Nothing, AbstractVector}, + ϵ::Real, +) where {F, xT} + y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ) + LuxLib.Impl.activation!(y, opmode, act, y) + return y +end + +@overlay mooncake_method_table function batchnorm_affine_normalize_internal( + opmode::LuxLib.AbstractInternalArrayOpMode, + ::typeof(identity), + x::AbstractArray{xT, 3}, + μ::AbstractVector, + σ²::AbstractVector, + γ::Union{Nothing, AbstractVector}, + β::Union{Nothing, AbstractVector}, + ϵ::Real, +) where {xT} + y = similar(x, + promote_type( + safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) + ) + ) + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) + return y +end end diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index 46c6396a6..1befaed81 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -3,22 +3,49 @@ Any[ (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), + (false, :none, true, LuxLib.Impl.batched_matmul, randn(5, 4, 3), randn(4, 3, 3)), (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), ( false, :none, false, LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3), ), - ], - vec(map(Iterators.product( - [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], - [randn(5), nothing], - [Lux.relu, tanh, NNlib.gelu], - )) do (opmode, bias, activation) + (false, :stability_and_allocs, true, SLEEFActivations.sigmoid_fast, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.softplus, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.logsigmoid, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.swish, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.lisht, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.tanh, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.tanh_fast, randn()), ( - false, :none, false, - LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, - ) - end), + false, :stability_and_allocs, true, + LuxLib.Utils.static_training_mode_check, + nothing, + LuxLib.Utils.True(), + LuxLib.Utils.True(), + ), + ( + false, :none, true, + LuxLib.Impl.batchnorm_affine_normalize_internal, + LuxLib.LoopedArrayOp(), + identity, + randn(5, 4, 3), + randn(4), + ones(4), + nothing, + nothing, + 1.1, + ), + ], + # vec(map(Iterators.product( + # [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + # [randn(5), nothing], + # [Lux.relu, tanh, NNlib.gelu], + # )) do (opmode, bias, activation) + # ( + # false, :none, false, + # LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, + # ) + # end), ) test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end diff --git a/test/front_matter.jl b/test/front_matter.jl index 0cf19d693..78fe0a3b6 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -24,6 +24,7 @@ using Core: using Core.Intrinsics: pointerref, pointerset using NNlib: dropout +using LuxLib.Impl: SLEEFActivations using Mooncake: CC, diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl index e191751ba..15a90245b 100644 --- a/test/integration_testing/lux.jl +++ b/test/integration_testing/lux.jl @@ -12,29 +12,30 @@ (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), - # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), - # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule - # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow - # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow - # (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow - # (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow - # (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow - # (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow - # (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow - # (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow - # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression - # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression - # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression - # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # another task problem - # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # task again - # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again - # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again - # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), - # (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), - # (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), - # (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), - # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), + (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (InstanceNorm(6), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] @info "$(_typeof((f, x_f32...)))" ps, st = f32(Lux.setup(sr(123456), f)) From 149e7b4d1941496fc70bd455c19da39e403094f2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:27:24 +0100 Subject: [PATCH 35/62] Permit :meta nodes in IR --- src/interpreter/s2s_reverse_mode_ad.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 039d94a00..472c8c0d0 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -617,6 +617,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) :leave, :pop_exception, :throw_undef_if_not, + :meta, ] # Expressions which do not require any special treatment. return ad_stmt_info(line, nothing, stmt, nothing) From 2dcd5350496e87225ba486cf9fbaa62065a8baa0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:35:01 +0100 Subject: [PATCH 36/62] Remove redundant test --- test/ext/luxlib.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index 1befaed81..5edc96101 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -23,18 +23,6 @@ LuxLib.Utils.True(), LuxLib.Utils.True(), ), - ( - false, :none, true, - LuxLib.Impl.batchnorm_affine_normalize_internal, - LuxLib.LoopedArrayOp(), - identity, - randn(5, 4, 3), - randn(4), - ones(4), - nothing, - nothing, - 1.1, - ), ], # vec(map(Iterators.product( # [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], From 0933f37f61ccacd8baebf233a2dc23e53af5ebd7 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:46:09 +0100 Subject: [PATCH 37/62] Uncomment some tests --- test/ext/luxlib.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index 5edc96101..1748d37fe 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -24,16 +24,16 @@ LuxLib.Utils.True(), ), ], - # vec(map(Iterators.product( - # [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], - # [randn(5), nothing], - # [Lux.relu, tanh, NNlib.gelu], - # )) do (opmode, bias, activation) - # ( - # false, :none, false, - # LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, - # ) - # end), + vec(map(Iterators.product( + [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + [randn(5), nothing], + [Lux.relu, tanh, NNlib.gelu], + )) do (opmode, bias, activation) + ( + false, :none, false, + LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, + ) + end), ) test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end From d217102f18a9a355fc47c668d0ba8c6d05ef6382 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:52:42 +0100 Subject: [PATCH 38/62] Rename chain rules doc --- docs/make.jl | 2 +- docs/src/{using_chain_rules.md => tools_for_rules.md} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename docs/src/{using_chain_rules.md => tools_for_rules.md} (100%) diff --git a/docs/make.jl b/docs/make.jl index 1a7d11797..f23ec6094 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -33,7 +33,7 @@ makedocs( "Mooncake.jl's Rule System" => "mathematical_interpretation.md", ], "Utilities" => [ - "Using ChainRules" => "using_chain_rules.md", + "Tools for Rules" => "tools_for_rules.md", "Debug Mode" => "debug_mode.md", "Debugging and MWEs" => "debugging_and_mwes.md", ], diff --git a/docs/src/using_chain_rules.md b/docs/src/tools_for_rules.md similarity index 100% rename from docs/src/using_chain_rules.md rename to docs/src/tools_for_rules.md From c6f8cf01518b23d6a7328b26f587d01413f69a7e Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 14:29:46 +0100 Subject: [PATCH 39/62] Add notes to docs on rule writing strategies --- docs/src/tools_for_rules.md | 39 ++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index a7a077090..a3740e99e 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -1,4 +1,41 @@ -# Using ChainRules.jl +# Tools for Rules + +Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. +However, this does not always necessitate writing your own `rrule!!` from scratch. +In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. + +## Simplfiying Code via Overlays + +Suppose you have a function +```julia +foo(x::Float64) = bar(x) +``` +where Mooncake.jl fails to differentiate `bar` for some reason. +If you have access to another function `baz`, which does the same thing as `bar`, but does so in a way which Mooncake.jl can differentiate, you can simply write: +```julia +Base.Experimental.@overlay Mooncake.mooncake_method_table foo(x::Float64) = baz(x) +``` +When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than the original, and should successfully differentiate it. +If you search for `@overlay` in the Mooncake.jl source code, you will see a variety of instances where this is used in practice. + +This approach is often very straightforward, and we recommend you try this first before going down the path of writing rules. + +## Functions with Zero Derivative + +If the above strategy does not work, but you find yourself in the surprisingly common situation that the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following: +```@docs +Mooncake.simple_zero_adjoint +``` +Suppose you have a function `foo(x, y, z)` whose derivative is zero, you would write an `rrule!!` as follows: +```julia +function Mooncake.rrule!!(f::CoDual{typeof(foo)}, x::CoDual, y::CoDual, z::CoDual) + return Mooncake.simple_zero_adjoint(f, x, y, z) +end +``` +Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`. +This approach is utilised often in Mooncake.jl's codebase. + +## Using ChainRules.jl [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the `ChainRulesCore.rrule` function. From d12afa4d070ad870ce5dbe4da19f9f87d07431cb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 14:39:38 +0100 Subject: [PATCH 40/62] Add mooncake_overlay --- docs/src/tools_for_rules.md | 24 +++++----- ext/MooncakeLuxLibExt.jl | 19 ++++---- src/Mooncake.jl | 1 + src/interpreter/method_overlays.jl | 68 +++++++++++++++++++++++++++++ test/interpreter/method_overlays.jl | 7 +++ test/runtests.jl | 1 + 6 files changed, 100 insertions(+), 20 deletions(-) create mode 100644 src/interpreter/method_overlays.jl create mode 100644 test/interpreter/method_overlays.jl diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index a3740e99e..abc6b0533 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -1,24 +1,20 @@ # Tools for Rules +```@meta +DocTestSetup = quote + using Mooncake +end +``` + Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own `rrule!!` from scratch. In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. ## Simplfiying Code via Overlays -Suppose you have a function -```julia -foo(x::Float64) = bar(x) -``` -where Mooncake.jl fails to differentiate `bar` for some reason. -If you have access to another function `baz`, which does the same thing as `bar`, but does so in a way which Mooncake.jl can differentiate, you can simply write: -```julia -Base.Experimental.@overlay Mooncake.mooncake_method_table foo(x::Float64) = baz(x) +```@docs +Mooncake.@mooncake_overlay ``` -When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than the original, and should successfully differentiate it. -If you search for `@overlay` in the Mooncake.jl source code, you will see a variety of instances where this is used in practice. - -This approach is often very straightforward, and we recommend you try this first before going down the path of writing rules. ## Functions with Zero Derivative @@ -48,3 +44,7 @@ The docstrings below explain this functionality, and how it should / should not Mooncake.@from_rrule Mooncake.rrule_wrapper ``` + +```@meta +DocTestSetup = nothing +``` diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 1bfa63db7..bd706ea1c 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -2,11 +2,14 @@ module MooncakeLuxLibExt using LuxLib, Random, Mooncake using Base: IEEEFloat -using Base.Experimental: @overlay import LuxLib: Impl import LuxLib.Utils: static_training_mode_check -import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table, CoDual +import Mooncake: + @from_rrule, + DefaultCtx, + @mooncake_overlay, + CoDual @from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( @@ -19,7 +22,7 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_t ) # Re-implement a bunch of methods to ensure that Mooncake can differentiate them. -@overlay mooncake_method_table function LuxLib.Impl.fused_dense( +@mooncake_overlay function LuxLib.Impl.fused_dense( opmode, act::F, weight::AbstractMatrix, @@ -29,19 +32,19 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_t return bias_activation(act, Impl.matmul(weight, x), b) end -@overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( +@mooncake_overlay function LuxLib.Impl.bias_activation_loop!( y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector ) where {F, xT, yT} return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias) end -@overlay mooncake_method_table function LuxLib.Impl.activation_loop!( +@mooncake_overlay function LuxLib.Impl.activation_loop!( y::AbstractArray, σ::F, x::AbstractArray ) where {F} return LuxLib.Impl.activation_simd_loop!(y, σ, x) end -@overlay mooncake_method_table function LuxLib.Impl.fused_conv( +@mooncake_overlay function LuxLib.Impl.fused_conv( ::LuxLib.Impl.AbstractInternalArrayOpMode, act::F, weight::AbstractArray{wT, N}, @@ -139,7 +142,7 @@ end }, ) -@overlay mooncake_method_table function batchnorm_affine_normalize_internal( +@mooncake_overlay function batchnorm_affine_normalize_internal( opmode::LuxLib.AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3}, @@ -154,7 +157,7 @@ end return y end -@overlay mooncake_method_table function batchnorm_affine_normalize_internal( +@mooncake_overlay function batchnorm_affine_normalize_internal( opmode::LuxLib.AbstractInternalArrayOpMode, ::typeof(identity), x::AbstractArray{xT, 3}, diff --git a/src/Mooncake.jl b/src/Mooncake.jl index c9abae09f..1a260669b 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -71,6 +71,7 @@ include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) +include(joinpath("interpreter", "method_overlays.jl")) include("test_utils.jl") diff --git a/src/interpreter/method_overlays.jl b/src/interpreter/method_overlays.jl new file mode 100644 index 000000000..9334a6161 --- /dev/null +++ b/src/interpreter/method_overlays.jl @@ -0,0 +1,68 @@ +""" + @mooncake_overlay method_expr + +Define a method of a function which only Mooncake can see. This can be used to write +versions of methods which can be successfully differentiated by Mooncake if the original +cannot be. + +For example, suppose that you have a function +```jldoctest overlay +julia> foo(x::Float64) = bar(x) +foo (generic function with 1 method) +``` +where Mooncake.jl fails to differentiate `bar` for some reason. +If you have access to another function `baz`, which does the same thing as `bar`, but does + so in a way which Mooncake.jl can differentiate, you can simply write: +```jldoctest overlay +julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x) + +``` +When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than +the original, and differentiate it instead. + +# A Worked Example + +To demonstrate how to use `@mooncake_overlay`s in practice, we here demonstrate how the +answer that Mooncake.jl gives changes if you change the definition of a function using a +`@mooncake_overlay`. +Do not do this in practice -- this is just a simple way to demonostrate how to use overlays! + +First, consider a simple example: +```jldoctest overlay-doctest +julia> scale(x) = 2x +scale (generic function with 1 method) + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(10.0, (NoTangent(), 2.0)) +``` + +We can use `@mooncake_overlay` to change the definition which Mooncake.jl sees: +```jldoctest overlay-doctest +julia> Mooncake.@mooncake_overlay scale(x) = 3x + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(15.0, (NoTangent(), 3.0)) +``` +As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method. + +Additionally, it is possible to use the usual multi-line syntax to declare an overlay: +```jldoctest overlay-doctest +julia> Mooncake.@mooncake_overlay function scale(x) + return 4x + end + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(20.0, (NoTangent(), 4.0)) +``` +""" +macro mooncake_overlay(method_expr) + def = splitdef(method_expr) + def[:name] = Expr(:overlay, :(Mooncake.mooncake_method_table), def[:name]) + return esc(combinedef(def)) +end diff --git a/test/interpreter/method_overlays.jl b/test/interpreter/method_overlays.jl new file mode 100644 index 000000000..306e9b2e6 --- /dev/null +++ b/test/interpreter/method_overlays.jl @@ -0,0 +1,7 @@ +overlay_tester(x) = 2x +Mooncake.@mooncake_overlay overlay_tester(x) = 3x + +@testset "method_overlays" begin + rule = Mooncake.build_rrule(Tuple{typeof(overlay_tester), Float64}) + @test value_and_gradient!!(rule, overlay_tester, 5.0) == (15.0, (NoTangent(), 3.0)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 75d73aa53..c99426cdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ include("front_matter.jl") include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) + include(joinpath("interpreter", "method_overlays.jl")) end include("interface.jl") include("config.jl") From fe1999d42059516b377fbf5c21f33410c3a1880a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 14:40:30 +0100 Subject: [PATCH 41/62] Add simpler method of build_rrule --- src/interpreter/s2s_reverse_mode_ad.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 472c8c0d0..933285ba2 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -830,6 +830,13 @@ function build_rrule(args...; debug_mode=false) return build_rrule(interp, _typeof(TestUtils.__get_primals(args)); debug_mode) end +""" + build_rrule(sig_or_mi) + +Equivalent to `build_rrule(Mooncake.get_interpreter(), sig_or_mi)`. +""" +build_rrule(sig_or_mi) = build_rrule(get_interpreter(), sig_or_mi) + const MOONCAKE_INFERENCE_LOCK = ReentrantLock() """ From e9dce9dce542892e460d2aeb960febae885a101c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 14:57:41 +0100 Subject: [PATCH 42/62] Fix dispatch problem --- src/interpreter/s2s_reverse_mode_ad.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 933285ba2..b0ce364d8 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -831,11 +831,11 @@ function build_rrule(args...; debug_mode=false) end """ - build_rrule(sig_or_mi) + build_rrule(sig::Type{<:Tuple}) -Equivalent to `build_rrule(Mooncake.get_interpreter(), sig_or_mi)`. +Equivalent to `build_rrule(Mooncake.get_interpreter(), sig)`. """ -build_rrule(sig_or_mi) = build_rrule(get_interpreter(), sig_or_mi) +build_rrule(sig::Type{<:Tuple}) = build_rrule(get_interpreter(), sig) const MOONCAKE_INFERENCE_LOCK = ReentrantLock() From 6149fd5a68485175947784aef4b08cc1283bd190 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 17:45:03 +0100 Subject: [PATCH 43/62] Tidy up --- docs/src/tools_for_rules.md | 27 +-- ext/MooncakeDynamicPPLExt.jl | 14 +- ext/MooncakeLuxLibExt.jl | 8 +- src/Mooncake.jl | 3 +- src/codual.jl | 17 -- src/interpreter/method_overlays.jl | 68 ------ .../avoiding_non_differentiable_code.jl | 13 +- src/rrules/blas.jl | 15 +- src/rrules/builtins.jl | 29 +-- src/rrules/foreigncall.jl | 54 +---- src/rrules/misc.jl | 61 ++--- src/rrules/tasks.jl | 3 +- ...in_rules_interop.jl => tools_for_rules.jl} | 214 +++++++++++++++--- test/interpreter/method_overlays.jl | 7 - test/runtests.jl | 3 +- ...in_rules_interop.jl => tools_for_rules.jl} | 68 ++++-- 16 files changed, 294 insertions(+), 310 deletions(-) delete mode 100644 src/interpreter/method_overlays.jl rename src/{chain_rules_interop.jl => tools_for_rules.jl} (53%) delete mode 100644 test/interpreter/method_overlays.jl rename test/{chain_rules_interop.jl => tools_for_rules.jl} (58%) diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index abc6b0533..db46a23b9 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -1,11 +1,5 @@ # Tools for Rules -```@meta -DocTestSetup = quote - using Mooncake -end -``` - Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own `rrule!!` from scratch. In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. @@ -16,20 +10,14 @@ In this section, we detail some useful strategies which can help you avoid havin Mooncake.@mooncake_overlay ``` -## Functions with Zero Derivative +## Functions with Zero Adjoint -If the above strategy does not work, but you find yourself in the surprisingly common situation that the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following: +If the above strategy does not work, but you find yourself in the surprisingly common +situation that the adjoint of the derivative of your function is always zero, you can very +straightforwardly write a rule by making use of the following: ```@docs -Mooncake.simple_zero_adjoint -``` -Suppose you have a function `foo(x, y, z)` whose derivative is zero, you would write an `rrule!!` as follows: -```julia -function Mooncake.rrule!!(f::CoDual{typeof(foo)}, x::CoDual, y::CoDual, z::CoDual) - return Mooncake.simple_zero_adjoint(f, x, y, z) -end +Mooncake.@zero_adjoint ``` -Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`. -This approach is utilised often in Mooncake.jl's codebase. ## Using ChainRules.jl @@ -42,9 +30,4 @@ The docstrings below explain this functionality, and how it should / should not ```@docs Mooncake.@from_rrule -Mooncake.rrule_wrapper -``` - -```@meta -DocTestSetup = nothing ``` diff --git a/ext/MooncakeDynamicPPLExt.jl b/ext/MooncakeDynamicPPLExt.jl index 84fef9598..c8184728e 100644 --- a/ext/MooncakeDynamicPPLExt.jl +++ b/ext/MooncakeDynamicPPLExt.jl @@ -1,17 +1,9 @@ module MooncakeDynamicPPLExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL, istrans - using Mooncake: Mooncake -else - using ..DynamicPPL: DynamicPPL, istrans - using ..Mooncake: Mooncake -end - -using Mooncake: DefaultCtx, CoDual, simple_zero_adjoint +using DynamicPPL: DynamicPPL, istrans +using Mooncake: Mooncake # This is purely an optimisation. -Mooncake.@is_primitive DefaultCtx Tuple{typeof(istrans), Vararg} -Mooncake.rrule!!(f::CoDual{typeof(istrans)}, x::CoDual...) = simple_zero_adjoint(f, x...) +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg} end # module diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index bd706ea1c..10ae97bba 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -71,13 +71,7 @@ for f in [ ) end -Mooncake.@is_primitive(DefaultCtx, Tuple{typeof(static_training_mode_check), Vararg}) -function Mooncake.rrule!!(f::CoDual{typeof(static_training_mode_check)}, x::CoDual...) - return Mooncake.simple_zero_adjoint(f, x...) -end - - - +Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg} # This is a really horrible hack that we need to do until Mooncake is able to support the # call-back-into-ad interface that ChainRules exposes. diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 1a260669b..67c8f5459 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -71,8 +71,8 @@ include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) -include(joinpath("interpreter", "method_overlays.jl")) +include("tools_for_rules.jl") include("test_utils.jl") include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) @@ -87,7 +87,6 @@ include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) include(joinpath("rrules", "tasks.jl")) -include("chain_rules_interop.jl") include("interface.jl") include("config.jl") diff --git a/src/codual.jl b/src/codual.jl index 44af2a605..dac1b2c1b 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -65,23 +65,6 @@ end @inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r) -""" - simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} - -Utility functionality for constructing `rrule!!`s for functions which produce adjoints which -always return zero. Equivalent to: -```julia -zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) -``` - -WARNING: this is only correct if the output of `primal(f)(map(primal, x)...)` does not alias -anything in `f` or `x`. This is always the case if the result is a bits type, but more care -may be required if it is not. -""" -@inline function simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} - return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) -end - to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x))) to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData()) diff --git a/src/interpreter/method_overlays.jl b/src/interpreter/method_overlays.jl deleted file mode 100644 index 9334a6161..000000000 --- a/src/interpreter/method_overlays.jl +++ /dev/null @@ -1,68 +0,0 @@ -""" - @mooncake_overlay method_expr - -Define a method of a function which only Mooncake can see. This can be used to write -versions of methods which can be successfully differentiated by Mooncake if the original -cannot be. - -For example, suppose that you have a function -```jldoctest overlay -julia> foo(x::Float64) = bar(x) -foo (generic function with 1 method) -``` -where Mooncake.jl fails to differentiate `bar` for some reason. -If you have access to another function `baz`, which does the same thing as `bar`, but does - so in a way which Mooncake.jl can differentiate, you can simply write: -```jldoctest overlay -julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x) - -``` -When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than -the original, and differentiate it instead. - -# A Worked Example - -To demonstrate how to use `@mooncake_overlay`s in practice, we here demonstrate how the -answer that Mooncake.jl gives changes if you change the definition of a function using a -`@mooncake_overlay`. -Do not do this in practice -- this is just a simple way to demonostrate how to use overlays! - -First, consider a simple example: -```jldoctest overlay-doctest -julia> scale(x) = 2x -scale (generic function with 1 method) - -julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); - -julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) -(10.0, (NoTangent(), 2.0)) -``` - -We can use `@mooncake_overlay` to change the definition which Mooncake.jl sees: -```jldoctest overlay-doctest -julia> Mooncake.@mooncake_overlay scale(x) = 3x - -julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); - -julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) -(15.0, (NoTangent(), 3.0)) -``` -As can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method. - -Additionally, it is possible to use the usual multi-line syntax to declare an overlay: -```jldoctest overlay-doctest -julia> Mooncake.@mooncake_overlay function scale(x) - return 4x - end - -julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); - -julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) -(20.0, (NoTangent(), 4.0)) -``` -""" -macro mooncake_overlay(method_expr) - def = splitdef(method_expr) - def[:name] = Expr(:overlay, :(Mooncake.mooncake_method_table), def[:name]) - return esc(combinedef(def)) -end diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index a63b5f60a..d3a1800ae 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -6,16 +6,9 @@ function rrule!!(f::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Int return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback(f, x, y) end -@is_primitive MinimalCtx Tuple{typeof(randn), AbstractRNG, Vararg} -function rrule!!(f::CoDual{typeof(randn)}, rng::CoDual{<:AbstractRNG}, args::CoDual...) - return simple_zero_adjoint(f, rng, args...) -end - -@is_primitive MinimalCtx Tuple{typeof(string), Vararg} -rrule!!(f::CoDual{typeof(string)}, x::CoDual...) = simple_zero_adjoint(f, x...) - -@is_primitive MinimalCtx Tuple{Type{Symbol}, Vararg} -rrule!!(f::CoDual{Type{Symbol}}, x::CoDual...) = simple_zero_adjoint(f, x...) +@zero_adjoint MinimalCtx Tuple{typeof(randn), AbstractRNG, Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(string), Vararg} +@zero_adjoint MinimalCtx Tuple{Type{Symbol}, Vararg} function generate_hand_written_rrule!!_test_cases( rng_ctor, ::Val{:avoiding_non_differentiable_code} diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 1aaf51add..830127000 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -25,17 +25,10 @@ const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} # Utility # -@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} -rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) = simple_zero_adjoint(f) - -@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} -rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) = simple_zero_adjoint(f) - -@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} -rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) - -@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} -rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} # # LEVEL 1 diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 9b1d9f0cc..8d58960f6 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -418,9 +418,8 @@ end end # IntrinsicsWrappers -rrule!!(f::CoDual{typeof(<:)}, T1, T2) = simple_zero_adjoint(f, T1, T2) - -rrule!!(f::CoDual{typeof(===)}, x, y) = simple_zero_adjoint(f, x, y) +@zero_adjoint MinimalCtx Tuple{typeof(<:), Any, Any} +@zero_adjoint MinimalCtx Tuple{typeof(===), Any, Any} # Core._abstracttype @@ -463,9 +462,7 @@ end # Core._call_latest # Doesn't do anything differentiable. -function rrule!!(f::CoDual{typeof(Core._compute_sparams)}, args::CoDual...) - return simple_zero_adjoint(f, args...) -end +@zero_adjoint MinimalCtx Tuple{typeof(Core._compute_sparams), Vararg} # Core._equiv_typedef # Core._expr @@ -615,15 +612,12 @@ end # Core.set_binding_type! -rrule!!(f::CoDual{typeof(Core.sizeof)}, x) = simple_zero_adjoint(f, x) +@zero_adjoint MinimalCtx Tuple{typeof(Core.sizeof), Any} # Core.svec -rrule!!(_f::CoDual{typeof(applicable)}, f, args...) = simple_zero_adjoint(_f, f, args...) - -function rrule!!(f::CoDual{typeof(Core.fieldtype)}, args::Vararg{Any, N}) where {N} - return simple_zero_adjoint(f, args...) -end +@zero_adjoint MinimalCtx Tuple{typeof(applicable), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(fieldtype), Vararg} function rrule!!(f::CoDual{typeof(getfield)}, x::CoDual{P}, name::CoDual) where {P} if tangent_type(P) == NoTangent @@ -680,17 +674,16 @@ is_homogeneous_and_immutable(::Any) = false # return y, pb!! # end -rrule!!(f::CoDual{typeof(getglobal)}, a, b) = simple_zero_adjoint(f, a, b) +@zero_adjoint MinimalCtx Tuple{typeof(getglobal), Any, Any} # invoke -rrule!!(f::CoDual{typeof(isa)}, x, T) = simple_zero_adjoint(f, x, T) - -rrule!!(f::CoDual{typeof(isdefined)}, args...) = simple_zero_adjoint(f, args...) +@zero_adjoint MinimalCtx Tuple{typeof(isa), Any, Any} +@zero_adjoint MinimalCtx Tuple{typeof(isdefined), Vararg} # modifyfield! -rrule!!(f::CoDual{typeof(nfields)}, x) = simple_zero_adjoint(f, x) +@zero_adjoint MinimalCtx Tuple{typeof(nfields), Any} # replacefield! @@ -732,7 +725,7 @@ function rrule!!(::CoDual{typeof(typeassert)}, x::CoDual, type::CoDual) return CoDual(typeassert(primal(x), primal(type)), tangent(x)), typeassert_pullback end -rrule!!(f::CoDual{typeof(typeof)}, x::CoDual) = simple_zero_adjoint(f, x) +@zero_adjoint MinimalCtx Tuple{typeof(typeof), Any} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index 4df6ceb40..a00c98034 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -65,30 +65,10 @@ end # Rules to handle / avoid foreigncall nodes # -@is_primitive MinimalCtx Tuple{typeof(Base.allocatedinline), Type} -function rrule!!(f::CoDual{typeof(Base.allocatedinline)}, T::CoDual{<:Type}) - return simple_zero_adjoint(f, T) -end - -@is_primitive MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Vararg} where {T, N} -function rrule!!( - f::CoDual{Type{Array{T, N}}}, u::CoDual{typeof(undef)}, m::Vararg{CoDual} -) where {T, N} - return simple_zero_adjoint(f, u, m...) -end - -function rrule!!( - f::CoDual{Type{Array{T, 0}}}, u::CoDual{typeof(undef)}, m::CoDual{Tuple{}} -) where {T} - return simple_zero_adjoint(f, u, m) -end - -@is_primitive MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N} -function rrule!!( - ::CoDual{<:Type{<:Array{T, N}}}, ::CoDual{typeof(undef)}, m::CoDual{NTuple{N}}, -) where {T, N} - return rrule!!(zero_fcodual(Array{T, N}), zero_fcodual(undef), m) -end +@zero_adjoint MinimalCtx Tuple{typeof(Base.allocatedinline), Type} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Vararg} where {T, N} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Tuple{}} where {T, N} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N} @is_primitive MinimalCtx Tuple{typeof(copy), Array} function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) @@ -260,8 +240,7 @@ function rrule!!(f::CoDual{typeof(sizehint!)}, x::CoDual{<:Vector}, sz::CoDual{< return x, NoPullback(f, x, sz) end -@is_primitive MinimalCtx Tuple{typeof(objectid), Any} -rrule!!(f::CoDual{typeof(objectid)}, @nospecialize(x)) = simple_zero_adjoint(f, x) +@zero_adjoint MinimalCtx Tuple{typeof(objectid), Any} @is_primitive MinimalCtx Tuple{typeof(pointer_from_objref), Any} function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) @@ -272,10 +251,7 @@ function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) return y, NoPullback(f, x) end -@is_primitive MinimalCtx Tuple{typeof(CC.return_type), Vararg} -function rrule!!(f::CoDual{typeof(Core.Compiler.return_type)}, args...) - return simple_zero_adjoint(f, args...) -end +@zero_adjoint MinimalCtx Tuple{typeof(CC.return_type), Vararg} @is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref), Ptr} function rrule!!(f::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr}) @@ -283,11 +259,8 @@ function rrule!!(f::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:P return y, NoPullback(f, x) end -@is_primitive MinimalCtx Tuple{typeof(Threads.threadid)} -rrule!!(f::CoDual{typeof(Threads.threadid)}) = simple_zero_adjoint(f) - -@is_primitive MinimalCtx Tuple{typeof(typeintersect), Any, Any} -rrule!!(f::CoDual{typeof(typeintersect)}, a, b) = simple_zero_adjoint(f, a, b) +@zero_adjoint MinimalCtx Tuple{typeof(Threads.threadid)} +@zero_adjoint MinimalCtx Tuple{typeof(typeintersect), Any, Any} function _increment_pointer!(x::Ptr{T}, y::Ptr{T}, N::Integer) where {T} increment!!(unsafe_wrap(Vector{T}, x, N), unsafe_wrap(Vector{T}, y, N)) @@ -476,14 +449,9 @@ function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) return y, deepcopy_pb!! end -@is_primitive MinimalCtx Tuple{Type{UnionAll}, TypeVar, Any} -@is_primitive MinimalCtx Tuple{Type{UnionAll}, TypeVar, Type} -function rrule!!(f::CoDual{<:Type{UnionAll}}, x::CoDual{<:TypeVar}, y::CoDual{<:Type}) - return simple_zero_adjoint(f, x, y) -end - -@is_primitive MinimalCtx Tuple{typeof(hash), Vararg} -rrule!!(f::CoDual{typeof(hash)}, x::CoDual...) = simple_zero_adjoint(f, x...) +@zero_adjoint MinimalCtx Tuple{Type{UnionAll}, TypeVar, Any} +@zero_adjoint MinimalCtx Tuple{Type{UnionAll}, TypeVar, Type} +@zero_adjoint MinimalCtx Tuple{typeof(hash), Vararg} function rrule!!( f::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual, N} diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index a73f594aa..37c1761d9 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -6,42 +6,31 @@ # deduce that these bits of code are inactive though. # -for name in [ - :size, - :(LinearAlgebra.lapack_size), - :(Base.require_one_based_indexing), - :in, - :iszero, - :isempty, - :isbitstype, - :sizeof, - :promote_type, - :(Base.elsize), - :(Core.Compiler.sizeof_nothrow), - :(Base.datatype_haspadding), - :(Base.datatype_nfields), - :(Base.datatype_pointerfree), - :(Base.datatype_alignment), - :(Base.datatype_fielddesc_type), - :(LinearAlgebra.chkstride1), - :(Threads.nthreads), - :(Base.depwarn), - :(Base.reduced_indices), - :(Base.check_reducedims), - :(Base.throw_boundserror), - :(Base.Broadcast.eltypes), - :(Base.eltype), -] - @eval @is_primitive DefaultCtx Tuple{typeof($name), Vararg} - @eval function rrule!!(f::CoDual{_typeof($name)}, args::Vararg{CoDual, N}) where {N} - return simple_zero_adjoint(f, args...) - end -end - -@is_primitive MinimalCtx Tuple{Type, TypeVar, Type} -function rrule!!(x::CoDual{<:Type}, y::CoDual{<:TypeVar}, z::CoDual{<:Type}) - return simple_zero_adjoint(x, y, z) -end +@zero_adjoint DefaultCtx Tuple{typeof(size), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(LinearAlgebra.lapack_size), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.require_one_based_indexing), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(in), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(iszero), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(isempty), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(isbitstype), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(sizeof), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(promote_type), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.elsize), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Core.Compiler.sizeof_nothrow), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_haspadding), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_nfields), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_pointerfree), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_alignment), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_fielddesc_type), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(LinearAlgebra.chkstride1), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Threads.nthreads), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.depwarn), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.reduced_indices), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.check_reducedims), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.throw_boundserror), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.Broadcast.eltypes), Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.eltype), Vararg} +@zero_adjoint MinimalCtx Tuple{Type, TypeVar, Type} """ lgetfield(x, f::Val) diff --git a/src/rrules/tasks.jl b/src/rrules/tasks.jl index b872fb383..5b5016c15 100644 --- a/src/rrules/tasks.jl +++ b/src/rrules/tasks.jl @@ -55,8 +55,7 @@ end set_tangent_field!(t::TaskTangent, f, ::NoTangent) = NoTangent() -@is_primitive MinimalCtx Tuple{typeof(current_task)} -rrule!!(f::CoDual{typeof(current_task)}) = simple_zero_adjoint(f) +@zero_adjoint MinimalCtx Tuple{typeof(current_task)} _verify_fdata_value(::Task, ::TaskTangent) = nothing diff --git a/src/chain_rules_interop.jl b/src/tools_for_rules.jl similarity index 53% rename from src/chain_rules_interop.jl rename to src/tools_for_rules.jl index ed027081f..6f8d8a9e5 100644 --- a/src/chain_rules_interop.jl +++ b/src/tools_for_rules.jl @@ -1,4 +1,178 @@ - """ +""" + @mooncake_overlay method_expr + +Define a method of a function which only Mooncake can see. This can be used to write +versions of methods which can be successfully differentiated by Mooncake if the original +cannot be. + +For example, suppose that you have a function +```jldoctest overlay +julia> foo(x::Float64) = bar(x) +foo (generic function with 1 method) +``` +where Mooncake.jl fails to differentiate `bar` for some reason. +If you have access to another function `baz`, which does the same thing as `bar`, but does + so in a way which Mooncake.jl can differentiate, you can simply write: +```jldoctest overlay +julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x) + +``` +When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than +the original, and differentiate it instead. + +# A Worked Example + +To demonstrate how to use `@mooncake_overlay`s in practice, we here demonstrate how the +answer that Mooncake.jl gives changes if you change the definition of a function using a +`@mooncake_overlay`. +Do not do this in practice -- this is just a simple way to demonostrate how to use overlays! + +First, consider a simple example: +```jldoctest overlay-doctest +julia> scale(x) = 2x +scale (generic function with 1 method) + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(10.0, (NoTangent(), 2.0)) +``` + +We can use `@mooncake_overlay` to change the definition which Mooncake.jl sees: +```jldoctest overlay-doctest +julia> Mooncake.@mooncake_overlay scale(x) = 3x + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(15.0, (NoTangent(), 3.0)) +``` +As can be seen from the output, the result of differentiating using Mooncake.jl has changed +to reflect the overlay-ed definition of the method. + +Additionally, it is possible to use the usual multi-line syntax to declare an overlay: +```jldoctest overlay-doctest +julia> Mooncake.@mooncake_overlay function scale(x) + return 4x + end + +julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); + +julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) +(20.0, (NoTangent(), 4.0)) +``` +""" +macro mooncake_overlay(method_expr) + def = splitdef(method_expr) + def[:name] = Expr(:overlay, :(Mooncake.mooncake_method_table), def[:name]) + return esc(combinedef(def)) +end + +function parse_signature_expr(sig::Expr) + # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. + if sig.head == :curly + @assert sig.args[1] == :Tuple + arg_type_symbols = sig.args[2:end] + where_params = nothing + elseif sig.head == :where + @assert sig.args[1].args[1] == :Tuple + arg_type_symbols = sig.args[1].args[2:end] + where_params = sig.args[2:end] + else + throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) + end + return arg_type_symbols, where_params +end + +function construct_def(arg_names, arg_types, where_params, body) + name = :(Mooncake.rrule!!) + arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) + def = Dict(:head => :function, :name => name, :args => arg_exprs, :body => body) + where_params !== nothing && setindex!(def, where_params, :whereparams) + return ExprTools.combinedef(def) +end + +""" + @zero_adjoint ctx sig + +Defines `is_primitive(context_type, sig) = true`, and defines a method of +`Mooncake.rrule!!` which returns zero for all inputs. +Users of ChainRules.jl should be familiar with this functionality -- it is morally the same +as `ChainRulesCore.@non_differentiable`. + +For example: +```jldoctest +julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive + +julia> foo(x) = 5 +foo (generic function with 1 method) + +julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any} + +julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any}) +true + +julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData()) +(NoRData(), 0.0) +``` + +WARNING: this is only correct if the output of the function does not alias any fields of the +function, or any of its arguments. For example, applying this macro to the function `x -> x` +will yield incorrect results. + +As always, you should use [`Mooncake.TestUtils.test_rule`](@ref) to ensure that you've not +made a mistake. +""" +macro zero_adjoint(ctx, sig) + + # Parse the signature, and construct the rule definition. If it is a vararg definition, + # then the last argument requires special treatment. + arg_type_symbols, where_params = parse_signature_expr(sig) + arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) + is_vararg = arg_type_symbols[end] === :Vararg + if is_vararg + arg_types = vcat( + map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols[1:end-1]), + :(Vararg{Mooncake.CoDual}), + ) + splat_symbol = Expr(Symbol("..."), arg_names[end]) + body = Expr( + :call, Mooncake.simple_zero_adjoint, arg_names[1:end-1]..., splat_symbol, + ) + else + arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) + body = Expr(:call, Mooncake.simple_zero_adjoint, arg_names...) + end + + # Return code to create a method of is_primitive and a rule. + ex = quote + Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + $(construct_def(arg_names, arg_types, where_params, body)) + end + return esc(ex) +end + +""" + simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} + +Utility functionality for constructing `rrule!!`s for functions which produce adjoints which +always return zero. Equivalent to: +```julia +zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) +``` + +WARNING: this is only correct if the output of `primal(f)(map(primal, x)...)` does not alias +anything in `f` or `x`. This is always the case if the result is a bits type, but more care +may be required if it is not. + +Note: you should generally not call this function. Rather, you should make use of +`@zero_adjoint`. +""" +@inline function simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} + return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) +end + +""" to_cr_tangent(t) Convert a Mooncake tangent into a type that ChainRules.jl `rrule`s expect to see. @@ -97,11 +271,15 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) return CoDual(y_primal, y_fdata), pb!! end +function construct_rrule_wrapper_def(arg_names, arg_types, where_params) + body = Expr(:call, rrule_wrapper, arg_names...) + return construct_def(arg_names, arg_types, where_params, body) +end + @doc""" @from_rrule ctx sig [has_kwargs=false] Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. -This macro is a thin wrapper around [`rrule_wrapper`](@ref). For example, ```julia @@ -145,22 +323,10 @@ only write a rule for these types: """ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) - # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. - if sig.head == :curly - @assert sig.args[1] == :Tuple - arg_type_symbols = sig.args[2:end] - where_params = nothing - elseif sig.head == :where - @assert sig.args[1].args[1] == :Tuple - arg_type_symbols = sig.args[1].args[2:end] - where_params = sig.args[2:end] - else - throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) - end - + arg_type_symbols, where_params = parse_signature_expr(sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) - rule_expr = construct_def(arg_names, arg_types, where_params) + rule_expr = construct_rrule_wrapper_def(arg_names, arg_types, where_params) if has_kwargs kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_symbols...) @@ -168,7 +334,7 @@ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) kw_is_primitive = :(Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) kwcall_type = :(Mooncake.CoDual{typeof(Core.kwcall)}) nt_type = :(Mooncake.CoDual{<:NamedTuple}) - kwargs_rule_expr = construct_def( + kwargs_rule_expr = construct_rrule_wrapper_def( vcat(:_kwcall, :kwargs, arg_names), vcat(kwcall_type, nt_type, arg_types), where_params, @@ -186,17 +352,3 @@ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) end return esc(ex) end - -function construct_def(arg_names, arg_types, where_params) - arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) - def = Dict( - :head => :function, - :name => :(Mooncake.rrule!!), - :args => arg_exprs, - :body => Expr(:call, rrule_wrapper, arg_names...), - ) - if where_params !== nothing - def[:whereparams] = where_params - end - return ExprTools.combinedef(def) -end \ No newline at end of file diff --git a/test/interpreter/method_overlays.jl b/test/interpreter/method_overlays.jl deleted file mode 100644 index 306e9b2e6..000000000 --- a/test/interpreter/method_overlays.jl +++ /dev/null @@ -1,7 +0,0 @@ -overlay_tester(x) = 2x -Mooncake.@mooncake_overlay overlay_tester(x) = 3x - -@testset "method_overlays" begin - rule = Mooncake.build_rrule(Tuple{typeof(overlay_tester), Float64}) - @test value_and_gradient!!(rule, overlay_tester, 5.0) == (15.0, (NoTangent(), 3.0)) -end diff --git a/test/runtests.jl b/test/runtests.jl index c99426cdd..fc3213f4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,8 +16,8 @@ include("front_matter.jl") include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) - include(joinpath("interpreter", "method_overlays.jl")) end + include("tools_for_rules.jl") include("interface.jl") include("config.jl") elseif test_group == "rrules" @@ -46,7 +46,6 @@ include("front_matter.jl") @info "tasks" include(joinpath("rrules", "tasks.jl")) end - include("chain_rules_interop.jl") elseif test_group == "integration_testing/misc" include(joinpath("integration_testing", "battery_tests.jl")) include(joinpath("ext", "dynamic_ppl.jl")) diff --git a/test/chain_rules_interop.jl b/test/tools_for_rules.jl similarity index 58% rename from test/chain_rules_interop.jl rename to test/tools_for_rules.jl index 23f7cf526..1799ff9d9 100644 --- a/test/chain_rules_interop.jl +++ b/test/tools_for_rules.jl @@ -1,3 +1,12 @@ +overlay_tester(x) = 2x +Mooncake.@mooncake_overlay overlay_tester(x) = 3x + +zero_tester(x) = 0 +Mooncake.@zero_adjoint MinimalCtx Tuple{typeof(zero_tester), Float64} + +vararg_zero_tester(x...) = 0 +Mooncake.@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester), Vararg} + module ChainRulesInteropTestResources using ChainRulesCore, LinearAlgebra, Mooncake @@ -85,28 +94,41 @@ end end -@testset "chain_rules_macro" begin - @testset "to_cr_tangent" for (t, t_cr) in Any[ - (5.0, 5.0), - (ones(5), ones(5)), - (NoTangent(), ChainRulesCore.NoTangent()), - ] - @test Mooncake.to_cr_tangent(t) == t_cr +@testset "tools_for_rules" begin + @testset "mooncake_overlay" begin + rule = Mooncake.build_rrule(Tuple{typeof(overlay_tester), Float64}) + @test value_and_gradient!!(rule, overlay_tester, 5.0) == (15.0, (NoTangent(), 3.0)) end - @testset "rules: $(typeof(fargs))" for fargs in Any[ - (ChainRulesInteropTestResources.bleh, 5.0, 4), - (ChainRulesInteropTestResources.test_sum, ones(5)), - (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), - (ChainRulesInteropTestResources.test_nothing,), - (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), - (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), - (ChainRulesInteropTestResources.test_kwargs, 5.0), - ] - test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) + @testset "zero_adjoint" begin + test_rule(sr(123), zero_tester, 5.0; is_primitive=true, perf_flag=:stability_and_allocs) + test_rule( + sr(123), vararg_zero_tester, 5.0, 4.0; + is_primitive=true, perf_flag=:stability_and_allocs, + ) + end + @testset "chain_rules_macro" begin + @testset "to_cr_tangent" for (t, t_cr) in Any[ + (5.0, 5.0), + (ones(5), ones(5)), + (NoTangent(), ChainRulesCore.NoTangent()), + ] + @test Mooncake.to_cr_tangent(t) == t_cr + end + @testset "rules: $(typeof(fargs))" for fargs in Any[ + (ChainRulesInteropTestResources.bleh, 5.0, 4), + (ChainRulesInteropTestResources.test_sum, ones(5)), + (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), + (ChainRulesInteropTestResources.test_nothing,), + (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (ChainRulesInteropTestResources.test_kwargs, 5.0), + ] + test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) + end + @testset "bad rdata" begin + f = ChainRulesInteropTestResources.test_bad_rdata + out, pb!! = Mooncake.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) + @test_throws MethodError pb!!(5.0) + end end - @testset "bad rdata" begin - f = ChainRulesInteropTestResources.test_bad_rdata - out, pb!! = Mooncake.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) - @test_throws MethodError pb!!(5.0) - end -end +end \ No newline at end of file From d386cab61b535b799a6e0f756a1434a139bd555c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 17:56:57 +0100 Subject: [PATCH 44/62] Tidy up build_rrule calls --- docs/src/known_limitations.md | 2 +- test/interpreter/s2s_reverse_mode_ad.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/known_limitations.md b/docs/src/known_limitations.md index dd0c4fc86..9d72c55a6 100644 --- a/docs/src/known_limitations.md +++ b/docs/src/known_limitations.md @@ -131,7 +131,7 @@ function foo(x::Vector{Float64}) return unsafe_load(p) end -rule = build_rrule(get_interpreter(), Tuple{typeof(foo), Vector{Float64}}) +rule = build_rrule(Tuple{typeof(foo), Vector{Float64}}) Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0]) # output diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 6255a35ff..d99792047 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -260,7 +260,6 @@ end @test_throws( Mooncake.UnhandledLanguageFeatureException, Mooncake.build_rrule( - Mooncake.get_interpreter(), Tuple{typeof(Mooncake.TestResources.non_const_global_ref), Float64}, ) ) From 4cf73de80cabb97f98280f74f493b78a5bc06e08 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 18:29:42 +0100 Subject: [PATCH 45/62] Improve zero_adjoint docs --- docs/src/tools_for_rules.md | 1 + src/rrules/builtins.jl | 2 +- src/tools_for_rules.jl | 145 +++++++++++++++++++++++++----------- 3 files changed, 104 insertions(+), 44 deletions(-) diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index db46a23b9..45eefcbc3 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -17,6 +17,7 @@ situation that the adjoint of the derivative of your function is always zero, yo straightforwardly write a rule by making use of the following: ```@docs Mooncake.@zero_adjoint +Mooncake.zero_adjoint ``` ## Using ChainRules.jl diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 8d58960f6..626ea2c80 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -75,7 +75,7 @@ macro inactive_intrinsic(name) (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name), Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any, N}) where {N} - return Mooncake.simple_zero_adjoint(f, args...) + return Mooncake.zero_adjoint(f, args...) end end return esc(expr) diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index 6f8d8a9e5..f7d590add 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -1,3 +1,35 @@ +# +# General utilities +# + +function parse_signature_expr(sig::Expr) + # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. + if sig.head == :curly + @assert sig.args[1] == :Tuple + arg_type_symbols = sig.args[2:end] + where_params = nothing + elseif sig.head == :where + @assert sig.args[1].args[1] == :Tuple + arg_type_symbols = sig.args[1].args[2:end] + where_params = sig.args[2:end] + else + throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) + end + return arg_type_symbols, where_params +end + +function construct_def(arg_names, arg_types, where_params, body) + name = :(Mooncake.rrule!!) + arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) + def = Dict(:head => :function, :name => name, :args => arg_exprs, :body => body) + where_params !== nothing && setindex!(def, where_params, :whereparams) + return ExprTools.combinedef(def) +end + +# +# Functionality supporting @mooncake_overlay +# + """ @mooncake_overlay method_expr @@ -68,28 +100,42 @@ macro mooncake_overlay(method_expr) return esc(combinedef(def)) end -function parse_signature_expr(sig::Expr) - # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. - if sig.head == :curly - @assert sig.args[1] == :Tuple - arg_type_symbols = sig.args[2:end] - where_params = nothing - elseif sig.head == :where - @assert sig.args[1].args[1] == :Tuple - arg_type_symbols = sig.args[1].args[2:end] - where_params = sig.args[2:end] - else - throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) - end - return arg_type_symbols, where_params -end +# +# Functionality supporting @zero_adjoint +# -function construct_def(arg_names, arg_types, where_params, body) - name = :(Mooncake.rrule!!) - arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) - def = Dict(:head => :function, :name => name, :args => arg_exprs, :body => body) - where_params !== nothing && setindex!(def, where_params, :whereparams) - return ExprTools.combinedef(def) +""" + zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} + +Utility functionality for constructing `rrule!!`s for functions which produce adjoints which +always return zero. + +NOTE: you should only make use of this function if you cannot make use of the +[`@zero_adjoint`](@ref) macro. + +You make use of this functionality by writing a method of `Mooncake.rrule!!`, and +passing all of its arguments (including the function itself) to this function. For example: +```jldoctest +julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual + +julia> foo(x::Vararg{Int}) = 5 +foo (generic function with 1 method) + +julia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true; + +julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...); + +julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData()) +(NoRData(), NoRData(), NoRData()) +``` + +WARNING: this is only correct if the output of `primal(f)(map(primal, x)...)` does not alias +anything in `f` or `x`. This is always the case if the result is a bits type, but more care +may be required if it is not. +``` +""" +@inline function zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} + return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) end """ @@ -116,12 +162,37 @@ julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData()) (NoRData(), 0.0) ``` +Limited support for `Vararg`s is also available. For example +```jldoctest +julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive + +julia> foo_varargs(x...) = 5 +foo_varargs (generic function with 1 method) + +julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg} + +julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int}) +true + +julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData()) +(NoRData(), 0.0, NoRData()) +``` +Be aware that it is not currently possible to specify any of the type parameters of the +`Vararg`. For example, the signature `Tuple{typeof(foo), Vararg{Float64, 5}}` will not work +with this macro. + WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function `x -> x` will yield incorrect results. -As always, you should use [`Mooncake.TestUtils.test_rule`](@ref) to ensure that you've not +As always, you should use `Mooncake.TestUtils.test_rule` to ensure that you've not made a mistake. + +# Signatures Unsupported By This Macro + +If the signature you wish to apply `@zero_adjoint` to is not supported, for example because +it uses a `Vararg` with a type parameter, you can still make use of +[`zero_adjoint`](@ref). """ macro zero_adjoint(ctx, sig) @@ -137,11 +208,11 @@ macro zero_adjoint(ctx, sig) ) splat_symbol = Expr(Symbol("..."), arg_names[end]) body = Expr( - :call, Mooncake.simple_zero_adjoint, arg_names[1:end-1]..., splat_symbol, + :call, Mooncake.zero_adjoint, arg_names[1:end-1]..., splat_symbol, ) else arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) - body = Expr(:call, Mooncake.simple_zero_adjoint, arg_names...) + body = Expr(:call, Mooncake.zero_adjoint, arg_names...) end # Return code to create a method of is_primitive and a rule. @@ -152,25 +223,9 @@ macro zero_adjoint(ctx, sig) return esc(ex) end -""" - simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} - -Utility functionality for constructing `rrule!!`s for functions which produce adjoints which -always return zero. Equivalent to: -```julia -zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) -``` - -WARNING: this is only correct if the output of `primal(f)(map(primal, x)...)` does not alias -anything in `f` or `x`. This is always the case if the result is a bits type, but more care -may be required if it is not. - -Note: you should generally not call this function. Rather, you should make use of -`@zero_adjoint`. -""" -@inline function simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} - return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) -end +# +# Functionality supporting @from_rrule +# """ to_cr_tangent(t) @@ -320,6 +375,10 @@ only write a rule for these types: ```julia @from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} ``` + +# Extended Help + +Under the hood, this functionality relies on two functions: """ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) From 6ffc0769732492f32f1257b3b3dce4655983fb27 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 2 Oct 2024 18:38:03 +0100 Subject: [PATCH 46/62] Improve documentation of from_rrule --- src/tools_for_rules.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index f7d590add..e3d326b9a 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -185,7 +185,7 @@ WARNING: this is only correct if the output of the function does not alias any f function, or any of its arguments. For example, applying this macro to the function `x -> x` will yield incorrect results. -As always, you should use `Mooncake.TestUtils.test_rule` to ensure that you've not +As always, you should use [`TestUtils.test_rule`](@ref) to ensure that you've not made a mistake. # Signatures Unsupported By This Macro @@ -376,9 +376,14 @@ only write a rule for these types: @from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} ``` -# Extended Help +# Conversions Between Different Tangent Type Systems -Under the hood, this functionality relies on two functions: +Under the hood, this functionality relies on two functions: `Mooncake.to_cr_tangent`, and +`Mooncake.increment_and_get_rdata!`. These two functions handle conversion to / from +`Mooncake` tangent types and `ChainRulesCore` tangent types. This functionality is known to +work well for simple types, but has not been tested to a great extent on complicated +composite types. If `@from_rrule` does not work in your case because the required method of +either of these functions does not exist, please open an issue. """ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) From d2e0764f1a0769b3ea4ea8cb7221bbeab785c76a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 09:04:40 +0100 Subject: [PATCH 47/62] Fix formatting --- src/tools_for_rules.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index e3d326b9a..fb4c6aad2 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -387,13 +387,13 @@ either of these functions does not exist, please open an issue. """ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) - arg_type_symbols, where_params = parse_signature_expr(sig) - arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) - arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) + arg_type_syms, where_params = parse_signature_expr(sig) + arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_syms)) + arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_syms) rule_expr = construct_rrule_wrapper_def(arg_names, arg_types, where_params) if has_kwargs - kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_symbols...) + kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_syms...) kw_sig = where_params === nothing ? kw_sig : Expr(:where, kw_sig, where_params...) kw_is_primitive = :(Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) kwcall_type = :(Mooncake.CoDual{typeof(Core.kwcall)}) From 36ec2765682de4611bc12620de8625b239810b59 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 09:46:06 +0100 Subject: [PATCH 48/62] Explain what is new --- docs/src/index.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index fd72a2a0b..6a47bdc8d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,6 +2,11 @@ Documentation for Mooncake.jl is on its way! +Note (03/10/2024): Various bits of utility functionality are now carefully documented. This +includes how to change the code which Mooncake sees, declare that the derivative of a +function is zero, make use of existing `ChainRules.rrule`s to quicky create new rules in +Mooncake, and more. + Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl -- you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl. From ee838f485cebe2d69576fe4ed0d853f9782e5e2a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 09:49:45 +0100 Subject: [PATCH 49/62] Improve from_rrule documentation --- docs/Project.toml | 1 + docs/src/tools_for_rules.md | 1 - src/tools_for_rules.jl | 91 ++++++++++++++++++++++++++++++------- 3 files changed, 76 insertions(+), 17 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 92285210d..d606397e8 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index 45eefcbc3..52715adad 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -27,7 +27,6 @@ These rules are methods of the `ChainRulesCore.rrule` function. There are some instances where there is it most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. There is enough similarity between these two systems that most of the boilerplate code can be avoided. -The docstrings below explain this functionality, and how it should / should not be used. ```@docs Mooncake.@from_rrule diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index fb4c6aad2..0d521acef 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -336,25 +336,86 @@ end Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. -For example, -```julia -@from_rrule DefaultCtx Tuple{typeof(sin), Float64} +# Arguments + +- `ctx`: A Mooncake context type +- `sig`: the signature which you wish to assert should be a primitive in `Mooncake.jl`, and + use an existing `ChainRulesCore.rrule` to implement this functionality. +- `has_kwargs`: a `Bool` state whether or not the function has keyword arguments. This + feature has the same limitations as `ChainRulesCore.rrule` -- the derivative w.r.t. all + kwargs must be zero. + +# Example Usage + +## A Basic Example + +```jldoctest +julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils + +julia> using ChainRulesCore + +julia> foo(x::Real) = 5x; + +julia> function ChainRulesCore.rrule(::typeof(foo), x::Real) + foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω + return foo(x), foo_pb + end; + +julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} + +julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0) +(NoRData(), 5.0) + +julia> # Check that the rule works as intended. + TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true) +Test Passed ``` -would define a `Mooncake.rrule!!` for `sin` of `Float64`s by calling `ChainRulesCore.rrule`. -```julia -@from_rrule DefaultCtx Tuple{typeof(foo), Float64} true +## An Example with Keyword Arguments + +```jldoctest +julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils + +julia> using ChainRulesCore + +julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x; + +julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool) + foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω + return foo(x; cond), foo_pb + end; + +julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true + +julia> _, pb = rrule!!( + zero_fcodual(Core.kwcall), + zero_fcodual((cond=false, )), + zero_fcodual(foo), + zero_fcodual(5.0), + ); + +julia> pb(3.0) +(NoRData(), NoRData(), NoRData(), 12.0) + +julia> # Check that the rule works as intended. + TestUtils.test_rule( + Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true + ) +Test Passed ``` -would define a method of `Mooncake.rrule!!` which can handle keyword arguments. +Notice that, in order to access the kwarg method we must call the method of `Core.kwcall`, +as Mooncake's `rrule!!` does not itself permit the use of kwargs. -Limitations: it is your responsibility to ensure that +# Limitations + +It is your responsibility to ensure that 1. calls with signature `sig` do not mutate their arguments, 2. the output of calls with signature `sig` does not alias any of the inputs. As with all hand-written rules, you should definitely make use of [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. -# A Note On Type Constraints +# Argument Type Constraints Many methods of `ChainRuleCore.rrule` are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature @@ -364,17 +425,15 @@ Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length. -Suffice it to say, you should not write rules for this package which are so generically +Suffice it to say, you should not write rules for _this_ package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the `ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the rest. -For example, in the above case you might be confident that the rule will behave correctly -for input types `Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}}`. You should therefore -only write a rule for these types: -```julia -@from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} -``` +For example, it is quite common to be confident that a given rule will work correctly for +any `Base.IEEEFloat` argument, i.e. `Union{Float16, Float32, Float64}`, but it is usually +not possible to know that the rule is correct for all possible subtypes of `Real` that +someone might define. # Conversions Between Different Tangent Type Systems From 98d484099215fb04e160b9cef1285745f2ae7a08 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 09:53:58 +0100 Subject: [PATCH 50/62] Formatting --- test/tools_for_rules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index 1799ff9d9..3a85010fd 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -100,7 +100,9 @@ end @test value_and_gradient!!(rule, overlay_tester, 5.0) == (15.0, (NoTangent(), 3.0)) end @testset "zero_adjoint" begin - test_rule(sr(123), zero_tester, 5.0; is_primitive=true, perf_flag=:stability_and_allocs) + test_rule( + sr(123), zero_tester, 5.0; is_primitive=true, perf_flag=:stability_and_allocs + ) test_rule( sr(123), vararg_zero_tester, 5.0, 4.0; is_primitive=true, perf_flag=:stability_and_allocs, From 45f1a3841fe015a8f89a7c8883d5b2b0144b7d83 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:05:34 +0100 Subject: [PATCH 51/62] Fix formatting --- Project.toml | 2 +- ext/MooncakeNNlibExt.jl | 105 ++++++++++++++++++------------------ src/interpreter/ir_utils.jl | 6 +-- test/tools_for_rules.jl | 2 +- 4 files changed, 56 insertions(+), 59 deletions(-) diff --git a/Project.toml b/Project.toml index bb42013af..ebdb481fa 100644 --- a/Project.toml +++ b/Project.toml @@ -82,4 +82,4 @@ TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"] \ No newline at end of file +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"] diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl index fbd1b2fa7..dabc8729c 100644 --- a/ext/MooncakeNNlibExt.jl +++ b/ext/MooncakeNNlibExt.jl @@ -1,65 +1,66 @@ module MooncakeNNlibExt - using NNlib, Random, Mooncake - using Base: IEEEFloat - using NNlib: dropout +using NNlib, Random, Mooncake +using Base: IEEEFloat +using NNlib: dropout - using NNlib: conv, depthwiseconv - import Mooncake: @from_rrule, DefaultCtx, MinimalCtx +using NNlib: conv, depthwiseconv +import Mooncake: @from_rrule, DefaultCtx, MinimalCtx - @from_rrule( - MinimalCtx, - Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, - ) - @from_rrule( +@from_rrule( + MinimalCtx, + Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, +) +@from_rrule( + MinimalCtx, + Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, + true, +) +@from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) +@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) +@from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) +@from_rrule( + MinimalCtx, + Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, +) +@from_rrule( + MinimalCtx, + Tuple{ + typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, + }, +) +@from_rrule( + MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} +) +@from_rrule( + MinimalCtx, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + true, +) +for conv in [:conv, :depthwiseconv] + local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) + + @eval @from_rrule( MinimalCtx, - Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, + Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, true, ) - @from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) - @from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) - @from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) - @from_rrule( - MinimalCtx, - Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, - ) - @from_rrule( - MinimalCtx, - Tuple{ - typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, - }, - ) - @from_rrule( - MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} - ) - @from_rrule( + @eval @from_rrule( MinimalCtx, - Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, true, ) - for conv in [:conv, :depthwiseconv] - local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) - - @eval @from_rrule( - MinimalCtx, - Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, - true, - ) - @eval @from_rrule( - MinimalCtx, - Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, - true, - ) - end +end +@eval @from_rrule( + MinimalCtx, + Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, +) +for pool in [:maxpool, :meanpool] @eval @from_rrule( - MinimalCtx, - Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, - true, + MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true ) - for pool in [:maxpool, :meanpool] - @eval @from_rrule( - MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true - ) - end - @from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) +end +@from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) + end diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 9370f11c7..2a00cd315 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -191,11 +191,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u match = match::Core.MethodMatch meth = Base.func_for_method_checked(match.method, tt, match.sparams) (code, ty) = CC.typeinf_ircode( - interp, - meth, - match.spec_types, - match.sparams, - optimize_until, + interp, meth, match.spec_types, match.sparams, optimize_until ) if code === nothing push!(asts, match.method => Any) diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index 3a85010fd..e4fead18c 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -133,4 +133,4 @@ end @test_throws MethodError pb!!(5.0) end end -end \ No newline at end of file +end From 4b1fff1ea69e006b79a6ad1045265c2b360eb980 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:05:57 +0100 Subject: [PATCH 52/62] Add compat for ChainRulesCore in docs --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index d606397e8..9a898dfa2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,5 +6,6 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [compat] +ChainRulesCore = "1" Documenter = "1" Mooncake = "0.4.0" From 3a98c22d5f91f826507dbf5cb78898a029ce0bd5 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:06:16 +0100 Subject: [PATCH 53/62] Tidy up mooncake_method_table usage --- src/interpreter/abstract_interpretation.jl | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 90f716219..69ac3e2c7 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -16,6 +16,7 @@ end MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance, Core.CodeInstance}()) +# The method table used by `Mooncake.@mooncake_overlay`. Base.Experimental.@MethodTable mooncake_method_table struct MooncakeInterpreter{C} <: CC.AbstractInterpreter @@ -26,7 +27,6 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey, Any} - method_table_to_overlay::CC.MethodTable function MooncakeInterpreter( ::Type{C}; meta=nothing, @@ -36,18 +36,8 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(), - method_table_to_overlay::CC.MethodTable=mooncake_method_table, ) where {C} - return new{C}( - meta, - world, - inf_params, - opt_params, - inf_cache, - code_cache, - oc_cache, - method_table_to_overlay, - ) + return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) end end @@ -104,7 +94,7 @@ function CC.setindex!( return setindex!(wvc.cache.dict, ci, mi) end function CC.method_table(interp::MooncakeInterpreter) - return CC.OverlayMethodTable(interp.world, interp.method_table_to_overlay) + return CC.OverlayMethodTable(interp.world, mooncake_method_table) end _type(x) = x @@ -123,9 +113,7 @@ function CC.inlining_policy( # Do not inline away primitives. argtype_tuple = Tuple{map(_type, argtypes)...} - if is_primitive(C, argtype_tuple) - return nothing - end + is_primitive(C, argtype_tuple) && return nothing # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. return @invoke CC.inlining_policy( From 2d54e51c9fb192f8d1102f0fe7092fcb77f412da Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:13:06 +0100 Subject: [PATCH 54/62] Add extra luxlib test --- test/ext/luxlib.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index 1748d37fe..5d356c07f 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -5,6 +5,14 @@ (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), (false, :none, true, LuxLib.Impl.batched_matmul, randn(5, 4, 3), randn(4, 3, 3)), (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), + ( + false, :none, false, + LuxLib.Impl.bias_activation_loop!, + randn(5, 4, 3), + Lux.relu, + randn(5, 4, 3), + randn(4), + ), ( false, :none, false, LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3), From 31daf96eecb3dd80d658cd6acbdc2b4284b71abb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:17:41 +0100 Subject: [PATCH 55/62] Add another luxlib test --- test/ext/luxlib.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index 5d356c07f..4e5331dea 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -31,6 +31,21 @@ LuxLib.Utils.True(), LuxLib.Utils.True(), ), + ( + false, :none, false, + function(opmode, act, x, m, sigma2, gamma, beta) + LuxLib.Impl.batchnorm_affine_normalize_internal( + opmode, act, x, m, sigma2, gamma, beta, 1e-3 + ) + end, + LuxLib.LoopedArrayOp(), + Lux.relu, + randn(5, 4, 3), + randn(4), + rand(4) .+ 1.0, + nothing, + nothing, + ), ], vec(map(Iterators.product( [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], From 40c8f1763c067ddb64d18f28b12c74f86e797ddf Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 10:20:44 +0100 Subject: [PATCH 56/62] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ebdb481fa..793bca4d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.5" +version = "0.4.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 672bf69c3ba0cf9e50217b1bd2fedb012cfe6349 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 3 Oct 2024 11:35:05 +0100 Subject: [PATCH 57/62] Update ext/MooncakeNNlibExt.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- ext/MooncakeNNlibExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl index dabc8729c..5fedbe7b4 100644 --- a/ext/MooncakeNNlibExt.jl +++ b/ext/MooncakeNNlibExt.jl @@ -51,7 +51,7 @@ for conv in [:conv, :depthwiseconv] true, ) end -@eval @from_rrule( +@from_rrule( MinimalCtx, Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, true, From b9be673b48a9a9250fa41636414f7c44b234a22d Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Oct 2024 12:11:40 +0100 Subject: [PATCH 58/62] Restrict CI to 1.10 for now --- .buildkite/pipeline.yml | 2 +- .github/workflows/CI.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index af00de5d7..b77011ae5 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -5,7 +5,7 @@ steps: - label: "Julia v1" plugins: - JuliaCI/julia#v1: - version: "1" + version: "1.10" - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: dirs: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1cc3f4109..d25a29a7e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,7 +35,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: '1' + version: '1.10' arch: x64 include-all-prereleases: false - uses: julia-actions/cache@v2 From 36b3cba1c46f8c390caaadb623cfbdded7d13eb9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Oct 2024 12:16:33 +0100 Subject: [PATCH 59/62] Apply suggestions from code review Co-authored-by: Markus Hauru --- docs/src/tools_for_rules.md | 2 +- src/tools_for_rules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md index 52715adad..73fc48f4f 100644 --- a/docs/src/tools_for_rules.md +++ b/docs/src/tools_for_rules.md @@ -24,7 +24,7 @@ Mooncake.zero_adjoint [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the `ChainRulesCore.rrule` function. -There are some instances where there is it most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. +There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. There is enough similarity between these two systems that most of the boilerplate code can be avoided. diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index 0d521acef..712db722b 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -257,7 +257,7 @@ end Used to implement `rrule!!`s via `ChainRulesCore.rrule`. -Given a function `foo`, argument types `arg_types`, and a method `ChainRulesCore.rrule` of +Given a function `foo`, argument types `arg_types`, and a method of `ChainRulesCore.rrule` which applies to these, you can make use of this function as follows: ```julia Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} From 191843126252de8f24f7b6fd79c0a76bea87e30b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Oct 2024 12:29:59 +0100 Subject: [PATCH 60/62] Restrict version consistently --- .github/workflows/CI.yml | 4 ++-- .github/workflows/documentation.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d25a29a7e..6044f287d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -62,7 +62,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: '1' + version: '1.10' arch: x64 include-all-prereleases: false - uses: julia-actions/cache@v2 @@ -81,7 +81,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: '1' + version: '1.10' arch: x64 include-all-prereleases: false - uses: julia-actions/cache@v2 diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 0ec2baa23..aeea6ad77 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: '1' + version: '1.10' arch: x64 include-all-prereleases: false - name: Install dependencies From 57fac4809c029a12e66609e5259adfc6aa586be7 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Oct 2024 12:30:09 +0100 Subject: [PATCH 61/62] Fix typo in docstring --- src/interpreter/ir_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 2a00cd315..915ce0a6c 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -179,7 +179,7 @@ Base.iterate(x::CC.MethodLookupResult, n::Int) = CC.iterate(x, n) sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance}, )::Tuple{IRCode, T} -Get the IR unique IR associated to `sig_or_mi` under `interp`. Throws `ArgumentError`s if +Get the unique IR associated to `sig_or_mi` under `interp`. Throws `ArgumentError`s if there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. From 62ba1f2ab81da1707955e48851b5a679fd086df1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Oct 2024 12:32:21 +0100 Subject: [PATCH 62/62] Shove all testing functionality inside module --- test/tools_for_rules.jl | 46 +++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index e4fead18c..d7602ba69 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -1,18 +1,17 @@ +module ToolsForRulesResources + +using ChainRulesCore, LinearAlgebra, Mooncake +using Base: IEEEFloat +using Mooncake: @mooncake_overlay, @zero_adjoint, @from_rrule, MinimalCtx, DefaultCtx + overlay_tester(x) = 2x -Mooncake.@mooncake_overlay overlay_tester(x) = 3x +@mooncake_overlay overlay_tester(x) = 3x zero_tester(x) = 0 -Mooncake.@zero_adjoint MinimalCtx Tuple{typeof(zero_tester), Float64} +@zero_adjoint MinimalCtx Tuple{typeof(zero_tester), Float64} vararg_zero_tester(x...) = 0 -Mooncake.@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester), Vararg} - -module ChainRulesInteropTestResources - -using ChainRulesCore, LinearAlgebra, Mooncake - -using Base: IEEEFloat -using Mooncake: DefaultCtx, @from_rrule +@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester), Vararg} # Test case with isbits data. @@ -96,15 +95,18 @@ end @testset "tools_for_rules" begin @testset "mooncake_overlay" begin - rule = Mooncake.build_rrule(Tuple{typeof(overlay_tester), Float64}) - @test value_and_gradient!!(rule, overlay_tester, 5.0) == (15.0, (NoTangent(), 3.0)) + f = ToolsForRulesResources.overlay_tester + rule = Mooncake.build_rrule(Tuple{typeof(f), Float64}) + @test value_and_gradient!!(rule, f, 5.0) == (15.0, (NoTangent(), 3.0)) end @testset "zero_adjoint" begin + f_zero = ToolsForRulesResources test_rule( - sr(123), zero_tester, 5.0; is_primitive=true, perf_flag=:stability_and_allocs + sr(123), ToolsForRulesResources.zero_tester, 5.0; + is_primitive=true, perf_flag=:stability_and_allocs, ) test_rule( - sr(123), vararg_zero_tester, 5.0, 4.0; + sr(123), ToolsForRulesResources.vararg_zero_tester, 5.0, 4.0; is_primitive=true, perf_flag=:stability_and_allocs, ) end @@ -117,18 +119,18 @@ end @test Mooncake.to_cr_tangent(t) == t_cr end @testset "rules: $(typeof(fargs))" for fargs in Any[ - (ChainRulesInteropTestResources.bleh, 5.0, 4), - (ChainRulesInteropTestResources.test_sum, ones(5)), - (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), - (ChainRulesInteropTestResources.test_nothing,), - (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), - (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), - (ChainRulesInteropTestResources.test_kwargs, 5.0), + (ToolsForRulesResources.bleh, 5.0, 4), + (ToolsForRulesResources.test_sum, ones(5)), + (ToolsForRulesResources.test_scale, 5.0, randn(3)), + (ToolsForRulesResources.test_nothing,), + (Core.kwcall, (y=true, ), ToolsForRulesResources.test_kwargs, 5.0), + (Core.kwcall, (y=false, ), ToolsForRulesResources.test_kwargs, 5.0), + (ToolsForRulesResources.test_kwargs, 5.0), ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end @testset "bad rdata" begin - f = ChainRulesInteropTestResources.test_bad_rdata + f = ToolsForRulesResources.test_bad_rdata out, pb!! = Mooncake.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) @test_throws MethodError pb!!(5.0) end