Skip to content

Commit

Permalink
Merge pull request #168 from LuxDL/ap/crc_macro
Browse files Browse the repository at this point in the history
Add an import from ChainRules macro
  • Loading branch information
ChrisRackauckas authored Apr 19, 2024
2 parents 596f1c5 + dd29f0f commit e384881
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 52 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.33"
version = "0.2.34"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -27,6 +28,7 @@ TrackerPDMatsExt = "PDMats"

[compat]
Adapt = "3, 4"
ChainRulesCore = "1.23"
DiffRules = "1.4"
ForwardDiff = "0.10"
Functors = "0.3, 0.4"
Expand Down
2 changes: 2 additions & 0 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using MacroTools: @q, @forward

using DiffRules
using ForwardDiff
import ChainRulesCore as CRC
import LogExpFunctions
import NaNMath
import SpecialFunctions
Expand Down Expand Up @@ -71,6 +72,7 @@ end

include("idset.jl")
include("params.jl")
include("macros.jl")
include("lib/real.jl")
include("lib/array.jl")
include("back.jl")
Expand Down
60 changes: 9 additions & 51 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,59 +560,17 @@ dims)
return Y, dropout_back
end

depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)

@grad depthwiseconv(x, w, cdims::DepthwiseConvDims; kw...) =
depthwiseconv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇depthwiseconv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

conv(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)

@grad conv(x, w, cdims::DenseConvDims; kw...) =
conv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

∇conv_data(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)

@grad function ∇conv_data(y, w, cdims::DenseConvDims; kw...)
return (
∇conv_data(data(y), data(w), cdims; kw...),
Δ -> begin
return nobacksies(:conv,
(NNlib.conv(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((Δ, y))..., cdims; kw...),
nothing)
)
end
)
end

maxpool(x::TrackedArray, pdims::PoolDims; kw...) = track(maxpool, x, pdims; kw...)

@grad function maxpool(x, pdims::PoolDims; kw...)
y = maxpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
for (xType, wType) in [(:TrackedArray, :TrackedArray), (:AbstractArray, :TrackedArray),
(:TrackedArray, :AbstractArray)]
@eval begin
@grad_from_chainrules depthwiseconv(::$xType, ::$wType, ::DepthwiseConvDims; kw...)
@grad_from_chainrules conv(::$xType, ::$wType, ::DenseConvDims; kw...)
@grad_from_chainrules ∇conv_data(::$xType, ::$wType, ::DenseConvDims; kw...)
end
end

meanpool(x::TrackedArray, pdims::PoolDims; kw...) = track(meanpool, x, pdims; kw...)


@grad function meanpool(x, pdims::PoolDims; kw...)
y = meanpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
end
@grad_from_chainrules maxpool(::TrackedArray, ::PoolDims; kw...)
@grad_from_chainrules meanpool(::TrackedArray, ::PoolDims; kw...)

# Broadcasting

Expand Down
91 changes: 91 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
@grad_from_chainrules f(args...; kwargs...)
The `@grad_from_chainrules` macro provides a way to import adjoints(rrule) defined in
ChainRules to Tracker. One must provide a method signature to import the corresponding
rrule. In the provided method signature, one should replace the types of arguments to which
one wants to take derivatives with respect with Tracker.TrackedReal and Tracker.TrackedArray
respectively. For example, we can import rrule of `f(x::Real, y::Array)`` like below:
Tracker.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
Tracker.@grad_from_chainrules f(x::TrackedReal, y::Array)
Tracker.@grad_from_chainrules f(x::Real, y::TrackedArray)
Acceptable type annotations are `TrackedReal`, `TrackedArray`, `TrackedVector`, and
`TrackedMatrix`. These can have parameters like `TrackedArray{Float32}`.
"""
macro grad_from_chainrules(fcall)
@assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`."
Meta.isexpr(fcall, :call) && length(fcall.args) 2 ||
error("`@grad_from_chainrules` has to be applied to a function signature")

f = fcall.args[1]
# Check if kwargs... splatting is present
kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] :
nothing
rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] :
fcall.args[2:end]
xs = map(rem_args) do x
Meta.isexpr(x, :(::)) || return x
length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name
@assert length(x.args) == 2
return :($(x.args[1])::$(x.args[2])) # x::T
end
xs_untyped = map(xs) do x
Meta.isexpr(x, :(::)) || return x
return x.args[1]
end

untrack_args = map(enumerate(xs)) do (i, x)
Meta.isexpr(x, :(::)) || return (x, nothing)
name, type = x.args
type = __strip_type(type)
type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) || return (name, nothing)
xdata = gensym(name)
return xdata, :($(xdata) = $(Tracker.data)($(name)))
end
untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args))
@assert length(untrack_calls) > 0 "No tracked arguments found."
var_names = first.(untrack_args)

f_sym = Meta.quot(Symbol(f))

if kws_var === nothing
return esc(quote
$(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...))
function Tracker._forward(::typeof($(f)), $(xs...))
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...))
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> return Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end
return esc(quote
function $(f)($(xs...); $(kws_var)...)
return Tracker.track($(f), $(xs_untyped...); $(kws_var)...)
end
function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...)
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...)
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end

@inline __no_crctangent(::CRC.NoTangent) = nothing
@inline __no_crctangent(::CRC.ZeroTangent) = nothing
@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x)
@inline __no_crctangent(x) = x

@inline function __strip_type(type)
Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types
Meta.isexpr(type, :(.)) && (type = type.args[2]) # Strip Tracker from Tracker.<...>
type isa QuoteNode && (type = type.value) # Unwrap a QuoteNode
return type
end
193 changes: 193 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Adapted from https://github.com/JuliaDiff/Tracker.jl/blob/master/test/ChainRulesTests.jl
module ChainRulesTest # Run in isolatex environment

using LinearAlgebra
using ChainRulesCore
using Tracker
using Test

struct MyStruct end
f(::MyStruct, x) = sum(4x .+ 1)
f(x, y::MyStruct) = sum(4x .+ 1)
f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)
rrule_f_mystruct_x = Ref(0)
rrule_f_x_mystruct = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
rrule_f_mystruct_x[] += 1
r = f(MyStruct(), x)
back(d) = NoTangent(), NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
rrule_f_x_mystruct[] += 1
r = f(x, MyStruct())
back(d) = NoTangent(), fill(4 * d, size(x)), NoTangent()
return r, back
end

Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray)
# test arg type hygiene
Tracker.@grad_from_chainrules f(::MyStruct, x::Tracker.TrackedArray)
Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray, y::MyStruct)

g(x, y) = sum(4x .+ 4y)

rrule_g_x_y = Ref(0)

function ChainRulesCore.rrule(::typeof(g), x, y)
rrule_g_x_y[] += 1
r = g(x, y)
back(d) = NoTangent(), fill(4 * d, size(x)), fill(4 * d, size(x))
return r, back
end

Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y)
Tracker.@grad_from_chainrules g(x, y::Tracker.TrackedArray)
Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y::Tracker.TrackedArray)

@testset "rrule in ChainRules and Tracker" begin
## ChainRules
# function f
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, input)
_, d = back(1)
@test output == f(input)
@test d == fill(4, size(input))
@test rrule_f_singleargs[] == 1
# function g
inputs = rand(3, 3), rand(3, 3)
output, back = ChainRulesCore.rrule(g, inputs...)
_, d1, d2 = back(1)
@test output == g(inputs...)
@test d1 == fill(4, size(inputs[1]))
@test d2 == fill(4, size(inputs[2]))
@test rrule_g_x_y[] == 1
end

@testset "custom struct input" begin
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, MyStruct(), input)
_, _, d = back(1)
@test output == f(MyStruct(), input)
@test d == fill(4, size(input))
@test rrule_f_mystruct_x[] == 1

output, back = ChainRulesCore.rrule(f, input, MyStruct())
_, d, _ = back(1)
@test output == f(input, MyStruct())
@test d == fill(4, size(input))
@test rrule_f_x_mystruct[] == 1
end

### Functions with varargs and kwargs
# Varargs
f_vararg(x, args...) = sum(4x .+ sum(args))

rrule_f_vararg = Ref(0)

function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
rrule_f_vararg[] += 1
r = f_vararg(x, args...)
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Tracker.@grad_from_chainrules f_vararg(x::Tracker.TrackedArray, args...)

@testset "Function with Varargs" begin
grads = Tracker.gradient(x -> f_vararg(x, 1, 2, 3) + 2, rand(3, 3))

@test grads[1] == fill(4, (3, 3))
@test rrule_f_vararg[] == 1
end

# Vargs and kwargs
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))

rrule_f_kw = Ref(0)

function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
rrule_f_kw[] += 1
r = f_kw(x, args...; k=k, kwargs...)
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Tracker.@grad_from_chainrules f_kw(x::Tracker.TrackedArray, args...; k=1, kwargs...)

@testset "Function with Varargs and kwargs" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, inputs)

@test results[1] == fill(4, size(inputs))
@test rrule_f_kw[] == 1
end

### Mix @grad and @grad_from_chainrules

h(x) = 10x
h(x::Tracker.TrackedArray) = Tracker.track(h, x)

grad_hcalls = Ref(0)

Tracker.@grad function h(x)
grad_hcalls[] += 1
xv = Tracker.data(x)
return h(xv), Δ ->* 10,) # use 7 asits derivatives
end

@testset "Tracker and ChainRules Mixed" begin
t(x) = g(x, h(x))
inputs = rand(3, 3)
results = Tracker.gradient(t, inputs)
@test results[1] == fill(44, size(inputs)) # 44 = 4 + 4 * 10
@test rrule_g_x_y[] == 2
@test grad_hcalls[] == 1
end

### Isolated Scope
module IsolatedModuleForTestingScoping

using ChainRulesCore, Test
using Tracker: Tracker, @grad_from_chainrules

f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end

@grad_from_chainrules f(x::Tracker.TrackedArray)

module SubModule
using Test
using Tracker: Tracker
using ..IsolatedModuleForTestingScoping: f, rrule_f_singleargs

@testset "rrule in Isolated Scope" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f(x) + 2, inputs)

@test results[1] == fill(4, size(inputs))
@test rrule_f_singleargs[] == 1
end

end # end of SubModule

end # end of IsolatedModuleForTestingScoping

end
Loading

0 comments on commit e384881

Please sign in to comment.