Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change == to ignore measure-zero branches #481

Merged
merged 22 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.32"
version = "0.10.33"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down
27 changes: 25 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,40 @@ for pred in UNARY_PREDICATES
@eval Base.$(pred)(d::Dual) = $(pred)(value(d))
end

for pred in BINARY_PREDICATES
# Before PR#481 this loop ran over this list:
# BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
# Not a minimal set, as Base defines some in terms of others.
for pred in [:isless, :<, :>, :(<=), :(>=)]
@eval begin
@define_binary_dual_op(
Base.$(pred),
$(pred)(value(x), value(y)),
$(pred)(value(x), y),
$(pred)(x, value(y))
$(pred)(x, value(y)),
)
end
end

Base.iszero(x::Dual) = iszero(value(x)) && iszero(partials(x)) # shortcut, equivalent to x == zero(x)

for pred in [:isequal, :(==)]
@eval begin
@define_binary_dual_op(
Base.$(pred),
$(pred)(value(x), value(y)) && $(pred)(partials(x), partials(y)),
$(pred)(value(x), y) && iszero(partials(x)),
$(pred)(x, value(y)) && iszero(partials(y)),
)
end
end

@define_binary_dual_op(
Base.:(!=),
(!=)(value(x), value(y)) || (!=)(partials(x), partials(y)),
(!=)(value(x), y) || !iszero(partials(x)),
(!=)(x, value(y)) || !iszero(partials(y)),
)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

########################
# Promotion/Conversion #
########################
Expand Down
2 changes: 1 addition & 1 deletion src/prelude.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Rou

const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]

const BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
const DEFAULT_CHUNK_THRESHOLD = 12

struct Chunk{N} end

Expand Down
62 changes: 46 additions & 16 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false
ForwardDiff.:≺(::Type{TestTag}, ::Type{OuterTestTag}) = true
ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false

for N in (0,3), M in (0,4), V in (Int, Float32)
@testset "Dual{Z,$V,$N} and Dual{Z,Dual{Z,$V,$M},$N}" for N in (0,3), M in (0,4), V in (Int, Float32)
println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}")

PARTIALS = Partials{N,V}(ntuple(n -> intrand(V), N))
Expand All @@ -44,6 +44,13 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
PARTIALS3 = Partials{N,V}(ntuple(n -> intrand(V), N))
PRIMAL3 = intrand(V)
FDNUM3 = Dual{TestTag()}(PRIMAL3, PARTIALS3)

if !allunique([PRIMAL, PRIMAL2, PRIMAL3])
@info "testing with non-unique primals" PRIMAL PRIMAL2 PRIMAL3
end
if N > 0 && !allunique([PARTIALS, PARTIALS2, PARTIALS3])
@info "testing with non-unique partials" PARTIALS PARTIALS2 PARTIALS3
end

M_PARTIALS = Partials{M,V}(ntuple(m -> intrand(V), M))
NESTED_PARTIALS = convert(Partials{N,Dual{TestTag(),V,M}}, PARTIALS)
Expand Down Expand Up @@ -231,15 +238,27 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test ForwardDiff.isconstant(one(NESTED_FDNUM))
@test ForwardDiff.isconstant(NESTED_FDNUM) == (N == 0)

@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2))
@test isequal(PRIMAL, PRIMAL2) == isequal(FDNUM, FDNUM2)

@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2))
@test isequal(PRIMAL, PRIMAL2) == isequal(NESTED_FDNUM, NESTED_FDNUM2)

@test FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)
@test (PRIMAL == PRIMAL2) == (FDNUM == FDNUM2)
@test (PRIMAL == PRIMAL2) == (NESTED_FDNUM == NESTED_FDNUM2)
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
# and FDNUM2 has everything with a 2, and all random numbers nonzero.
# M is the length of M_PARTIALS, which affects:
# NESTED_FDNUM = Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)

@test (FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2)) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))

if PRIMAL == PRIMAL2
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(FDNUM, FDNUM2) == (PARTIALS == PARTIALS2)

@test (FDNUM == FDNUM2) == (PARTIALS == PARTIALS2)
@test (NESTED_FDNUM == NESTED_FDNUM2) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))
else
@test !isequal(FDNUM, FDNUM2)

@test FDNUM != FDNUM2
@test NESTED_FDNUM != NESTED_FDNUM2
end

@test isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(2, PARTIALS2))
@test !(isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(1, PARTIALS2)))
Expand Down Expand Up @@ -344,7 +363,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test typeof(WIDE_NESTED_FDNUM) === Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N}

@test value(WIDE_FDNUM) == PRIMAL
@test value(WIDE_NESTED_FDNUM) == PRIMAL
@test (value(WIDE_NESTED_FDNUM) == PRIMAL) == (M == 0)

@test convert(Dual, FDNUM) === FDNUM
@test convert(Dual, NESTED_FDNUM) === NESTED_FDNUM
Expand Down Expand Up @@ -395,6 +414,8 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
#----------#

if M > 0 && N > 0
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
# all random numbers nonzero, and FDNUM2 another draw. M only affects NESTED_FDNUM.
@test Dual{1}(FDNUM) / Dual{1}(PRIMAL) === Dual{1}(FDNUM / PRIMAL)
@test Dual{1}(PRIMAL) / Dual{1}(FDNUM) === Dual{1}(PRIMAL / FDNUM)
@test_broken Dual{1}(FDNUM) / FDNUM2 === Dual{1}(FDNUM / FDNUM2)
Expand All @@ -413,6 +434,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)

# Exponentiation #
#----------------#

# If V == Int, the LHS terms are Int's. Large inputs cause integer overflow
# within the generic fallback of `isapprox`, resulting in a DomainError.
# Promote to Float64 to avoid issues.
Expand Down Expand Up @@ -442,7 +464,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test abs(NESTED_FDNUM) === NESTED_FDNUM

if V != Int
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
@testset "$f" for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if f in (:/, :rem2pi)
continue # Skip these rules
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
Expand Down Expand Up @@ -502,10 +524,14 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
else
@test dx isa Complex{<:Dual{TestTag()}}
@test dy isa Complex{<:Dual{TestTag()}}
@test real(value(dx)) == real(actualval)
@test real(value(dy)) == real(actualval)
@test imag(value(dx)) == imag(actualval)
@test imag(value(dy)) == imag(actualval)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
# @test real(value(dx)) == real(actualval)
# @test real(value(dy)) == real(actualval)
# @test imag(value(dx)) == imag(actualval)
# @test imag(value(dy)) == imag(actualval)
@test value(real(dx)) == real(actualval)
@test value(real(dy)) == real(actualval)
@test value(imag(dx)) == imag(actualval)
@test value(imag(dy)) == imag(actualval)
@test partials(real(dx), 1) ≈ real(actualdx) nans=true
@test partials(real(dy), 1) ≈ real(actualdy) nans=true
@test partials(imag(dx), 1) ≈ imag(actualdx) nans=true
Expand Down Expand Up @@ -568,6 +594,10 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
end

#############
# bug fixes #
#############

@testset "Exponentiation of zero" begin
x0 = 0.0
x1 = Dual{:t1}(x0, 1.0)
Expand Down
57 changes: 53 additions & 4 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module GradientTest
import Calculus

using Test
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
Expand All @@ -19,7 +20,7 @@ x = [0.1, 0.2, 0.3]
v = f(x)
g = [-9.4, 15.6, 52.0]

for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
@testset "Rosenbrock, chunk size = $c and tag = $(repr(tag))" for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
println(" ...running hardcoded test with chunk size = $c and tag = $(repr(tag))")
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{c}(), tag)

Expand Down Expand Up @@ -55,7 +56,7 @@ cfgx = ForwardDiff.GradientConfig(sin, x)
# test vs. Calculus.jl #
########################

for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
@testset "$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
v = f(X)
g = ForwardDiff.gradient(f, X)
@test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR)
Expand Down Expand Up @@ -83,9 +84,9 @@ end

println(" ...testing specialized StaticArray codepaths")

x = rand(3, 3)
@testset "$T" for T in (StaticArrays.SArray, StaticArrays.MArray)
x = rand(3, 3)

for T in (StaticArrays.SArray, StaticArrays.MArray)
sx = T{Tuple{3,3}}(x)

cfg = ForwardDiff.GradientConfig(nothing, x)
Expand Down Expand Up @@ -148,6 +149,10 @@ end
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
end

#############
# bug fixes #
#############

# Issue 399
@testset "chunk size zero" begin
f_const(x) = 1.0
Expand All @@ -162,11 +167,55 @@ end
@test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient
end

# Issue 548
@testset "ArithmeticStyle" begin
function f(p)
sum(collect(0.0:p[1]:p[2]))
end
@test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0]
end

@testset "det with branches" begin
# Issue 197
det2(A) = return (
A[1,1]*(A[2,2]*A[3,3]-A[2,3]*A[3,2]) -
A[1,2]*(A[2,1]*A[3,3]-A[2,3]*A[3,1]) +
A[1,3]*(A[2,1]*A[3,2]-A[2,2]*A[3,1])
)

A = [1 0 0; 0 2 0; 0 pi 3]
@test det2(A) == det(A) == 6
@test istril(A)

∇A = [6 0 0; 0 3 -pi; 0 0 2]
@test ForwardDiff.gradient(det2, A) ≈ ∇A
@test ForwardDiff.gradient(det, A) ≈ ∇A

# And issue 407
@test ForwardDiff.hessian(det, A) ≈ ForwardDiff.hessian(det2, A)

# https://discourse.julialang.org/t/forwarddiff-and-zygote-return-wrong-jacobian-for-log-det-l/77961
S = [1.0 0.8; 0.8 1.0]
L = cholesky(S).L
@test ForwardDiff.gradient(L -> log(det(L)), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667]
@test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667]
end

@testset "branches in mul!" begin
a, b = rand(3,3), rand(3,3)

# Issue 536, version with 3-arg *, Julia 1.7:
@test ForwardDiff.derivative(x -> sum(x*a*b), 0.0) ≈ sum(a * b)

if VERSION >= v"1.3"
# version with just mul!
dx = ForwardDiff.derivative(0.0) do x
c = similar(a, typeof(x))
mul!(c, a, b, x, false)
sum(c)
end
@test dx ≈ sum(a * b)
end
end

end # module
8 changes: 8 additions & 0 deletions test/HessianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module HessianTest
import Calculus

using Test
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
Expand Down Expand Up @@ -157,4 +158,11 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
@test DiffResults.hessian(sresult3) == DiffResults.hessian(result)
end

@testset "branches in dot" begin
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/551
H = [1 2 3; 4 5 6; 7 8 9];
@test ForwardDiff.hessian(x->dot(x,H,x), fill(0.00001, 3)) ≈ [2 6 10; 6 10 14; 10 14 18]
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) ≈ [2 6 10; 6 10 14; 10 14 18]
end

end # module
4 changes: 4 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
@test DiffResults.jacobian(sresult3) == DiffResults.jacobian(result)
end

#########
# misc. #
#########

@testset "dimension errors for jacobian" begin
@test_throws DimensionMismatch ForwardDiff.jacobian(identity, 2pi) # input
@test_throws DimensionMismatch ForwardDiff.jacobian(sum, fill(2pi, 2)) # vector_mode_jacobian
Expand Down
2 changes: 1 addition & 1 deletion test/PartialsTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ForwardDiff: Partials

samerng() = MersenneTwister(1)

for N in (0, 3), T in (Int, Float32, Float64)
@testset "Partials{$N,$T}" for N in (0, 3), T in (Int, Float32, Float64)
println(" ...testing Partials{$N,$T}")

VALUES = (rand(T,N)...,)
Expand Down
Loading