Skip to content

Commit

Permalink
Port tests from ReverseDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 15, 2024
1 parent ec33546 commit be6f8d2
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ Statistics = "1"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["PDMats", "Test"]
test = ["ChainRulesCore", "LinearAlgebra", "PDMats", "Test"]
9 changes: 8 additions & 1 deletion src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro grad_from_chainrules(fcall)
untrack_args = map(enumerate(xs)) do (i, x)
Meta.isexpr(x, :(::)) || return (x, nothing)
name, type = x.args
Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types
type = __strip_type(type)
type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) || return (name, nothing)
xdata = gensym(name)
return xdata, :($(xdata) = $(Tracker.data)($(name)))
Expand Down Expand Up @@ -82,3 +82,10 @@ end
@inline __no_crctangent(::CRC.ZeroTangent) = nothing
@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x)
@inline __no_crctangent(x) = x

@inline function __strip_type(type)
Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types
Meta.isexpr(type, :(.)) && (type = type.args[2]) # Strip Tracker from Tracker.<...>
type isa QuoteNode && (type = type.value) # Unwrap a QuoteNode
return type
end
183 changes: 183 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Adapted from https://github.com/JuliaDiff/Tracker.jl/blob/master/test/ChainRulesTests.jl
module ChainRulesTest # Run in isolatex environment

using LinearAlgebra
using ChainRulesCore
using Tracker
using Test

struct MyStruct end
f(::MyStruct, x) = sum(4x .+ 1)
f(x, y::MyStruct) = sum(4x .+ 1)
f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
r = f(x)
function back(d)
#=
The proper derivative of `f` is 4, but in order to
check if `ChainRulesCore.rrule` had taken over the compuation,
we define a rrule that returns 3 as `f`'s derivative.
After importing this rrule into Tracker, if we get 3
rather than 4 when we compute the derivative of `f`, it means
the importing mechanism works.
=#
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
r = f(MyStruct(), x)
function back(d)
return NoTangent(), NoTangent(), fill(3 * d, size(x))
end
return r, back
end
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
r = f(x, MyStruct())
function back(d)
return NoTangent(), fill(3 * d, size(x)), NoTangent()
end
return r, back
end

Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray)
# test arg type hygiene
Tracker.@grad_from_chainrules f(::MyStruct, x::Tracker.TrackedArray)
Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray, y::MyStruct)

g(x, y) = sum(4x .+ 4y)

function ChainRulesCore.rrule(::typeof(g), x, y)
r = g(x, y)
function back(d)
# same as above, use 3 and 5 as the derivatives
return NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
end
return r, back
end

Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y)
Tracker.@grad_from_chainrules g(x, y::Tracker.TrackedArray)
Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y::Tracker.TrackedArray)

@testset "rrule in ChainRules and Tracker" begin
## ChainRules
# function f
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, input)
_, d = back(1)
@test output == f(input)
@test d == fill(3, size(input))
# function g
inputs = rand(3, 3), rand(3, 3)
output, back = ChainRulesCore.rrule(g, inputs...)
_, d1, d2 = back(1)
@test output == g(inputs...)
@test d1 == fill(3, size(inputs[1]))
@test d2 == fill(5, size(inputs[2]))
end

@testset "custom struct input" begin
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, MyStruct(), input)
_, _, d = back(1)
@test output == f(MyStruct(), input)
@test d == fill(3, size(input))

output, back = ChainRulesCore.rrule(f, input, MyStruct())
_, d, _ = back(1)
@test output == f(input, MyStruct())
@test d == fill(3, size(input))
end

### Functions with varargs and kwargs
# Varargs
f_vararg(x, args...) = sum(4x .+ sum(args))

function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
r = f_vararg(x, args...)
function back(d)
return (NoTangent(), fill(3 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
end
return r, back
end

Tracker.@grad_from_chainrules f_vararg(x::Tracker.TrackedArray, args...)

@testset "Function with Varargs" begin
grads = Tracker.gradient(x -> f_vararg(x, 1, 2, 3) + 2, rand(3, 3))

@test grads[1] == fill(3, (3, 3))
end

# Vargs and kwargs
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))

function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
r = f_kw(x, args...; k=k, kwargs...)
function back(d)
return (NoTangent(), fill(3 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
end
return r, back
end

Tracker.@grad_from_chainrules f_kw(x::Tracker.TrackedArray, args...; k=1, kwargs...)

@testset "Function with Varargs and kwargs" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, inputs)

@test results[1] == fill(3, size(inputs))
end

### Mix @grad and @grad_from_chainrules

h(x) = 10x
h(x::Tracker.TrackedArray) = Tracker.track(h, x)
Tracker.@grad function h(x)
xv = Tracker.data(x)
return h(xv), Δ ->* 7,) # use 7 asits derivatives
end

@testset "Tracker and ChainRules Mixed" begin
t(x) = g(x, h(x))
inputs = rand(3, 3)
results = Tracker.gradient(t, inputs)
@test results[1] == fill(38, size(inputs)) # 38 = 3 + 5 * 7
end

### Isolated Scope
module IsolatedModuleForTestingScoping
using ChainRulesCore
using Tracker: Tracker, @grad_from_chainrules

f(x) = sum(4x .+ 1)

function ChainRulesCore.rrule(::typeof(f), x)
r = f(x)
function back(d)
# return a distinguishable but improper grad
return NoTangent(), fill(3 * d, size(x))
end
return r, back
end

@grad_from_chainrules f(x::Tracker.TrackedArray)

module SubModule
using Test
using Tracker: Tracker
using ..IsolatedModuleForTestingScoping: f
@testset "rrule in Isolated Scope" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f(x) + 2, inputs)

@test results[1] == fill(3, size(inputs))
end

end # end of SubModule
end # end of IsolatedModuleForTestingScoping

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Random.seed!(0)
@testset "Tracker" begin

include("tracker.jl")
include("chainrules.jl")

using Tracker: jacobian

Expand Down

0 comments on commit be6f8d2

Please sign in to comment.