Skip to content

Commit

Permalink
Tidy up testing infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 28, 2023
1 parent 375eb3d commit dd4d395
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 188 deletions.
3 changes: 1 addition & 2 deletions src/Taped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include("tracing.jl")
include("acceleration.jl")
include("tangents.jl")
include("reverse_mode_ad.jl")
include("test_utils.jl")

include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
include(joinpath("rrules", "blas.jl"))
Expand All @@ -29,8 +30,6 @@ include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "umlaut_internals_rules.jl"))
include(joinpath("rrules", "unrolled_function.jl"))

include("test_utils.jl")

export
primal,
shadow,
Expand Down
271 changes: 271 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,274 @@ function test_rule_and_type_interactions(rng::AbstractRNG, x::P) where {P}
end

end



module TestResources

using ..Taped
using ..Taped: CoDual, Tangent, MutableTangent, NoTangent

using LinearAlgebra, Setfield


#
# Types used for testing purposes
#

function equal_field(a, b, f)
(!isdefined(a, f) || !isdefined(b, f)) && return true
return getfield(a, f) == getfield(b, f)
end

mutable struct Foo
x::Real
end

Base.:(==)(a::Foo, b::Foo) = equal_field(a, b, :x)

struct StructFoo
a::Real
b::Vector{Float64}
StructFoo(a::Float64, b::Vector{Float64}) = new(a, b)
StructFoo(a::Float64) = new(a)
end

Base.:(==)(a::StructFoo, b::StructFoo) = equal_field(a, b, :a) && equal_field(a, b, :b)

mutable struct MutableFoo
a::Float64
b::AbstractVector
MutableFoo(a::Float64, b::Vector{Float64}) = new(a, b)
MutableFoo(a::Float64) = new(a)
end

Base.:(==)(a::MutableFoo, b::MutableFoo) = equal_field(a, b, :a) && equal_field(a, b, :b)

for T in [Foo, StructFoo, MutableFoo]
@eval Taped._add_to_primal(p::$T, t) = Taped._containerlike_add_to_primal(p, t)
@eval Taped._diff(p::$T, q::$T) = Taped._containerlike_diff(p, q)
end




#
# Functions for which rules are implemented. Useful for testing basic test infrastructure,
# and ensuring that any modifications to the interface do not prevent certain functions from
# having rules written for them. If a function is found for which the design of `rrule!!`
# doesn't permit a rule to be written, it should be added here to prevent future
# regressions. For example, `primitive_setfield!` was added because a particular iteration
# of the design did not allow the implementation of a correct rule.
# The hope is that this list of functions catches any issues early, before a large-scale
# re-write of the rules begins.
#

p_sin(x) = sin(x)

function Taped.rrule!!(::CoDual{typeof(p_sin)}, x::CoDual{Float64, Float64})
p_sin_pb!!(ȳ::Float64, df, dx) = df, dx +* cos(primal(x))
return CoDual(sin(primal(x)), zero(Float64)), p_sin_pb!!
end

p_mul(x, y) = x * y

function Taped.rrule!!(::CoDual{typeof(p_mul)}, x::CoDual{Float64}, y::CoDual{Float64})
p_mul_pb!!(z̄, df, dx, dy) = df, dx +* primal(y), dy +* primal(x)
return CoDual(primal(x) * primal(y), zero(Float64)), p_mul_pb!!
end

p_mat_mul!(C, A, B) = mul!(C, A, B)

function Taped.rrule!!(
::CoDual{typeof(p_mat_mul!)}, C::T, A::T, B::T
) where {T<:CoDual{Matrix{Float64}}}
C_old = copy(C)
function p_mat_mul_pb!!(C̄::Matrix{Float64}, df, _, Ā, B̄)
.+=* primal(B)'
.+= primal(A)' *
primal(C) .= primal(C_old)
shadow(C) .= shadow(C_old)
return df, C̄, Ā, B̄
end
mul!(primal(C), primal(A), primal(B))
shadow(C) .= 0
return C, p_mat_mul_pb!!
end

p_setfield!(value, name::Symbol, x) = setfield!(value, name, x)

function __setfield!(value::MutableTangent, name, x)
@set value.fields.$name = x
return x
end

function Taped.rrule!!(::CoDual{typeof(p_setfield!)}, value, name::CoDual{Symbol}, x)
_name = primal(name)
_value = primal(value)
_dvalue = shadow(value)
old_x = getfield(_value, _name)
old_dx = getfield(_dvalue.fields, _name).tangent

function p_setfield!_pb!!(dy, df, dvalue, dname, dx)

# Add all increments to dx.
dx = increment!!(dx, getfield(dvalue.fields, _name).tangent)
dx = increment!!(dx, dy)

# Restore old values.
setfield!(primal(value), _name, old_x)
# set_field_to_zero!!(shadow(value), _name) # this gives the correct answer, but
# I don't understand why, because I don't _really_ understand what I'm doing
# I need a better mental model of what is going on in order to know for certain
# whether this rule is implemented incorrectly, or if my tests are checking for
# the wrong thing. I'm quite sure that this zeroing-out line can't be correct,
# but I'm not entirely sure.
__setfield!(shadow(value), _name, old_dx)

return df, dvalue, dname, dx
end

y = CoDual(setfield!(_value, _name, primal(x)), __setfield!(_dvalue, _name, shadow(x)))
return y, p_setfield!_pb!!
end

const __A = randn(3, 3)

const PRIMITIVE_TEST_FUNCTIONS = Any[
(p_sin, 5.0),
(p_mul, 5.0, 4.0),
(p_mat_mul!, randn(4, 5), randn(4, 3), randn(3, 5)),
(p_mat_mul!, randn(3, 3), __A, __A),
(p_setfield!, Foo(5.0), :x, 4.0),
(p_setfield!, MutableFoo(5.0, randn(5)), :y, randn(6)),
]

#
# Tests for AD. There are not rules defined directly on these functions, and they require
# that most language primitives have rules defined.
#

test_sin(x) = sin(x)

test_cos_sin(x) = cos(sin(x))

test_isbits_multiple_usage(x::Float64) = Core.Intrinsics.mul_float(x, x)

test_getindex(x::AbstractArray{<:Real}) = x[1]

function test_mutation!(x::AbstractVector{<:Real})
x[1] = sin(x[2])
return x[1]
end

function test_for_loop(x)
for _ in 1:5
x = sin(x)
end
return x
end

function test_while_loop(x)
n = 3
while n > 0
x = cos(x)
n -= 1
end
return x
end

test_mutable_struct_basic(x) = Foo(x).x

test_mutable_struct_basic_sin(x) = sin(Foo(x).x)

function test_mutable_struct_setfield(x)
foo = Foo(1.0)
foo.x = x
return foo.x
end

function test_mutable_struct(x)
foo = Foo(x)
foo.x = sin(foo.x)
return foo.x
end

test_struct_partial_init(a::Float64) = StructFoo(a).a

test_mutable_partial_init(a::Float64) = MutableFoo(a).a

function test_naive_mat_mul!(C::Matrix{T}, A::Matrix{T}, B::Matrix{T}) where {T<:Real}
for p in 1:size(C, 1)
for q in 1:size(C, 2)
C[p, q] = zero(T)
for r in 1:size(A, 2)
C[p, q] += A[p, r] * B[r, q]
end
end
end
return C
end

test_diagonal_to_matrix(D::Diagonal) = Matrix(D)

relu(x) = max(x, zero(x))

test_mlp(x, W1, W2) = W2 * relu.(W1 * x)

const TEST_FUNCTIONS = [
(false, test_sin, 1.0),
(false, test_cos_sin, 2.0),
(false, test_isbits_multiple_usage, 5.0),
(false, test_getindex, [1.0, 2.0]),
(false, test_mutation!, [1.0, 2.0]),
(false, test_while_loop, 2.0),
(false, test_for_loop, 3.0),
(false, test_mutable_struct_basic, 5.0),
(false, test_mutable_struct_basic_sin, 5.0),
(false, test_mutable_struct_setfield, 4.0),
(false, test_mutable_struct, 5.0),
(false, test_struct_partial_init, 3.5),
(false, test_mutable_partial_init, 3.3),
(false, test_naive_mat_mul!, randn(2, 1), randn(2, 1), randn(1, 1)),
(false, (A, C) -> test_naive_mat_mul!(C, A, A), randn(2, 2), randn(2, 2)),
(false, sum, randn(3)),
(false, test_diagonal_to_matrix, Diagonal(randn(3))),
(false, ldiv!, randn(2, 2), Diagonal(randn(2)), randn(2, 2)),
(false, kron!, randn(4, 4), Diagonal(randn(2)), randn(2, 2)),
(false, test_mlp, randn(5, 2), randn(7, 5), randn(3, 7)),
]

function value_dependent_control_flow(x, n)
while n > 0
x = cos(x)
n -= 1
end
return x
end

my_setfield!(args...) = setfield!(args...)

function _setfield!(value::MutableTangent, name, x)
@set value.fields.$name = x
return x
end

function Taped.rrule!!(::Taped.CoDual{typeof(my_setfield!)}, value, name, x)
_name = primal(name)
old_x = isdefined(primal(value), _name) ? getfield(primal(value), _name) : nothing
function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx)
new_dx = increment!!(dx, getfield(dvalue.fields, _name).tangent)
set_field_to_zero!!(dvalue, _name)
new_dx = increment!!(new_dx, dy)
old_x !== nothing && setfield!(primal(value), _name, old_x)
return df, dvalue, NoTangent(), new_dx
end
y = Taped.CoDual(
setfield!(primal(value), _name, primal(x)),
_setfield!(shadow(value), _name, shadow(x)),
)
return y, setfield!_pullback
end

end
14 changes: 10 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ using Base: unsafe_load, pointer_from_objref
using Core: bitcast
using Core.Intrinsics: pointerref, pointerset
using FunctionWrappers: FunctionWrapper
using Taped: IntrinsicsWrappers, TestUtils, CoDual, to_reverse_mode_ad, _wrap_field

using Taped:
IntrinsicsWrappers,
TestUtils,
TestResources,
CoDual,
to_reverse_mode_ad,
_wrap_field

using .TestUtils:
test_rrule!!,
test_taped_rrule!!,
Expand All @@ -20,14 +28,12 @@ using .TestUtils:
populate_address_map!,
populate_address_map

include("test_resources.jl")

@testset "Taped.jl" begin
include("test_utils.jl")
include("tracing.jl")
include("acceleration.jl")
include("tangents.jl")
include("reverse_mode_ad.jl")
include("test_utils.jl")
@testset "rrules" begin
@info "avoiding_non_differentiable_code"
include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
Expand Down
Loading

0 comments on commit dd4d395

Please sign in to comment.