diff --git a/Project.toml b/Project.toml index 11158846d..9788cf345 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.59" +version = "0.4.60" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 2c2fd5d95..f74cf02c1 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -720,7 +720,6 @@ function blas_vectors(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int) xs = Any[ randn(rng, P, p), view(randn(rng, P, p + 5), 3:(p + 2)), - view(randn(rng, P, 3p), 1:2:(2p)), view(randn(rng, P, 3p, 3), 1:2:(2p), 2), ] @assert all(x -> length(x) == p, xs) @@ -732,72 +731,52 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) t_flags = ['N', 'T', 'C'] alphas = [1.0, -0.25] betas = [0.0, 0.33] + Ps = [Float64, Float32] rng = rng_ctor(123456) test_cases = vcat( # gemv! - vec( - reduce( - vcat, - map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N) - t = tA == 'N' - As = blas_matrices(rng, Float64, t ? M : N, t ? N : M) - xs = blas_vectors(rng, Float64, N) - ys = blas_vectors(rng, Float64, M) - flags = (false, :stability, (lb=1e-3, ub=10.0)) - return map(product(As, xs, ys)) do (A, x, y) - return (flags..., BLAS.gemv!, tA, randn(), A, x, randn(), y) - end - end, - ), - ), + map_prod(t_flags, [1, 3], [1, 2], Ps) do (tA, M, N, P) + As = blas_matrices(rng, P, tA == 'N' ? M : N, tA == 'N' ? N : M) + xs = blas_vectors(rng, P, N) + ys = blas_vectors(rng, P, M) + flags = (false, :stability, (lb=1e-3, ub=10.0)) + return map(As, xs, ys) do A, x, y + return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y) + end + end..., # symv! - vec( - reduce( - vcat, - map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) - As = blas_matrices(rng, Float64, 5, 5) - ys = blas_vectors(rng, Float64, 5) - xs = blas_vectors(rng, Float64, 5) - return map(product(As, xs, ys)) do (A, x, y) - (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y) - end - end, - ), - ), + map_prod(['L', 'U'], alphas, betas, Ps) do (uplo, α, β, P) + As = blas_matrices(rng, P, 5, 5) + ys = blas_vectors(rng, P, 5) + xs = blas_vectors(rng, P, 5) + return map(As, xs, ys) do A, x, y + (false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y) + end + end..., # gemm! - vec( - reduce( - vcat, - map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b) - As = blas_matrices(rng, Float64, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) - Bs = blas_matrices(rng, Float64, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) - Cs = blas_matrices(rng, Float64, 3, 5) - return map(product(As, Bs, Cs)) do (A, B, C) - (false, :none, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C) - end - end, - ), - ), + map_prod(t_flags, t_flags, alphas, betas, Ps) do (tA, tB, a, b, P) + As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) + Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) + Cs = blas_matrices(rng, P, 3, 5) + return map(As, Bs, Cs) do A, B, C + (false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C) + end + end..., # symm! - vec( - reduce( - vcat, - map(product(['L', 'R'], ['L', 'U'], alphas, betas)) do (side, uplo, α, β) - nA = side == 'L' ? 5 : 7 - As = blas_matrices(rng, Float64, nA, nA) - Bs = blas_matrices(rng, Float64, 5, 7) - Cs = blas_matrices(rng, Float64, 5, 7) - return map(product(As, Bs, Cs)) do (A, B, C) - (false, :stability, nothing, BLAS.symm!, side, uplo, α, A, B, β, C) - end - end, - ), - ), + map_prod(['L', 'R'], ['L', 'U'], alphas, betas, Ps) do (side, ul, α, β, P) + nA = side == 'L' ? 5 : 7 + As = blas_matrices(rng, P, nA, nA) + Bs = blas_matrices(rng, P, 5, 7) + Cs = blas_matrices(rng, P, 5, 7) + return map(As, Bs, Cs) do A, B, C + (false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C) + end + end..., ) memory = Any[] @@ -807,6 +786,10 @@ end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) t_flags = ['N', 'T', 'C'] aliased_gemm! = (tA, tB, a, b, A, C) -> BLAS.gemm!(tA, tB, a, A, A, b, C) + Ps = [Float32, Float64] + uplos = ['L', 'U'] + dAs = ['N', 'U'] + rng = rng_ctor(123) test_cases = vcat( @@ -820,91 +803,80 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # BLAS LEVEL 1 # - Any[ - (false, :none, nothing, BLAS.dot, 3, randn(5), 1, randn(4), 1), - (false, :none, nothing, BLAS.dot, 3, randn(6), 2, randn(4), 1), - (false, :none, nothing, BLAS.dot, 3, randn(6), 1, randn(9), 3), - (false, :none, nothing, BLAS.dot, 3, randn(12), 3, randn(9), 2), - (false, :none, nothing, BLAS.scal!, 10, 2.4, randn(30), 2), - ], + map(Ps) do P + Any[ + (false, :none, nothing, BLAS.dot, 3, randn(P, 5), 1, randn(P, 4), 1), + (false, :none, nothing, BLAS.dot, 3, randn(P, 6), 2, randn(P, 4), 1), + (false, :none, nothing, BLAS.dot, 3, randn(P, 6), 1, randn(P, 9), 3), + (false, :none, nothing, BLAS.dot, 3, randn(P, 12), 3, randn(P, 9), 2), + (false, :none, nothing, BLAS.scal!, 10, P(2.4), randn(P, 30), 2), + ] + end..., # # BLAS LEVEL 2 # # trmv! - vec( - reduce( - vcat, - map(product(['L', 'U'], t_flags, ['N', 'U'], [1, 3])) do (ul, tA, dA, N) - As = [randn(N, N), view(randn(15, 15), 3:(N + 2), 4:(N + 3))] - bs = [randn(N), view(randn(14), 4:(N + 3))] - return map(product(As, bs)) do (A, b) - (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) - end - end, - ), - ), + map_prod(uplos, t_flags, dAs, [1, 3], Ps) do (ul, tA, dA, N, P) + As = blas_matrices(rng, P, N, N) + bs = blas_vectors(rng, P, N) + return map(As, bs) do A, b + (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) + end + end..., # # BLAS LEVEL 3 # # aliased gemm! - vec( - map(product(t_flags, t_flags)) do (tA, tB) - A = randn(5, 5) - B = randn(5, 5) - (false, :none, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B) - end, - ), + map_prod(t_flags, t_flags, Ps) do (tA, tB, P) + As = blas_matrices(rng, P, 5, 5) + Bs = blas_matrices(rng, P, 5, 5) + return map_prod(As, Bs) do (A, B) + (false, :none, nothing, aliased_gemm!, tA, tB, randn(P), randn(P), A, B) + end + end..., # syrk! - vec( - map(product(['U', 'L'], t_flags)) do (uplo, t) - A = t == 'N' ? randn(3, 4) : randn(4, 3) - C = randn(3, 3) - return (false, :none, nothing, BLAS.syrk!, uplo, t, randn(), A, randn(), C) - end, - ), + map_prod(uplos, t_flags, Ps) do (uplo, t, P) + As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3) + C = randn(P, 3, 3) + flags = (false, :none, nothing) + return map(As) do A + return (flags..., BLAS.syrk!, uplo, t, randn(P), A, randn(P), C) + end + end..., # trmm! - vec( - reduce( - vcat, - map( - product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]) - ) do (side, ul, tA, dA, M, N) - t = tA == 'N' - R = side == 'L' ? M : N - As = [randn(R, R), view(randn(15, 15), 3:(R + 2), 4:(R + 3))] - Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] - flags = (false, :none, nothing) - return map(product(As, Bs)) do (A, B) - (flags..., BLAS.trmm!, side, ul, tA, dA, randn(), A, B) - end - end, - ), - ), + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps + ) do (side, ul, tA, dA, M, N, P) + t = tA == 'N' + R = side == 'L' ? M : N + As = blas_matrices(rng, P, R, R) + Bs = blas_matrices(rng, P, M, N) + return map(As, Bs) do A, B + (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B) + end + end..., # trsm! - vec( - reduce( - vcat, - map( - product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]) - ) do (side, ul, tA, dA, M, N) - t = tA == 'N' - R = side == 'L' ? M : N - As = [randn(R, R) + 5I, view(randn(15, 15), 3:(R + 2), 4:(R + 3)) + 5I] - Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] - flags = (false, :none, nothing) - return map(product(As, Bs)) do (A, B) - (flags..., BLAS.trsm!, side, ul, tA, dA, randn(), A, B) - end - end, - ), - ), + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps + ) do (side, ul, tA, dA, M, N, P) + t = tA == 'N' + R = side == 'L' ? M : N + As = map(blas_matrices(rng, P, R, R)) do A + A[diagind(A)] .+= 1 + return A + end + Bs = blas_matrices(rng, P, M, N) + return map(As, Bs) do A, B + (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B) + end + end..., ) memory = Any[] return test_cases, memory diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 783e5d2a2..837278906 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -744,8 +744,11 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) # Core.Intrinsics: (false, :stability, nothing, IntrinsicsWrappers.abs_float, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.abs_float, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.add_float, 4.0, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.add_float, 4.0f0, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.add_float_fast, 4.0, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.add_float_fast, 4.0f0, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.add_int, 1, 2), (false, :stability, nothing, IntrinsicsWrappers.and_int, 2, 3), ( @@ -793,28 +796,40 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.checked_usub_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, 4.0), (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0, -3.0), + (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0f0, 4.0f0), + (false, :stability, nothing, IntrinsicsWrappers.copysign_float, 5.0f0, -3.0f0), (false, :stability, nothing, IntrinsicsWrappers.ctlz_int, 5), (false, :stability, nothing, IntrinsicsWrappers.ctpop_int, 5), (false, :stability, nothing, IntrinsicsWrappers.cttz_int, 5), (false, :stability, nothing, IntrinsicsWrappers.div_float, 5.0, 3.0), (false, :stability, nothing, IntrinsicsWrappers.div_float_fast, 5.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.div_float, 5.0f0, 3.0f0), + (false, :stability, nothing, IntrinsicsWrappers.div_float_fast, 5.0f0, 3.0f0), (false, :stability, nothing, IntrinsicsWrappers.eq_float, 5.0, 4.0), (false, :stability, nothing, IntrinsicsWrappers.eq_float, 4.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float, 5.0f0, 4.0f0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float, 4.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 5.0, 4.0), (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 4.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 5.0f0, 4.0f0), + (false, :stability, nothing, IntrinsicsWrappers.eq_float_fast, 4.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.eq_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.eq_int, 4, 4), (false, :stability, nothing, IntrinsicsWrappers.flipsign_int, 4, -3), (false, :stability, nothing, IntrinsicsWrappers.floor_llvm, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0f0, 4.0f0, 3.0f0), (true, :stability_and_allocs, nothing, IntrinsicsWrappers.fpext, Float64, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.0f1, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1), (true, :stability, nothing, IntrinsicsWrappers.fptrunc, Float32, 5.0), (true, :stability, nothing, IntrinsicsWrappers.have_fma, Float64), (false, :stability, nothing, IntrinsicsWrappers.le_float, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.le_float, 4.0f1, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.le_float_fast, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.le_float_fast, 4.0f1, 4.0f0), # llvm_call -- NEEDS IMPLEMENTING AND TESTING ( false, @@ -825,17 +840,26 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) 0x0000000000000018, ), (false, :stability, nothing, IntrinsicsWrappers.lt_float, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.lt_float, 4.0f1, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.lt_float_fast, 4.1, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.lt_float_fast, 4.0f1, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.mul_float, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.mul_float, 5.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.mul_float_fast, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.mul_float_fast, 5.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.mul_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.muladd_float, 5.0, 4.0, 3.0), + (false, :stability, nothing, IntrinsicsWrappers.muladd_float, 5.0f0, 4.0f0, 3.0f0), (false, :stability, nothing, IntrinsicsWrappers.ne_float, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.ne_float, 5.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.ne_float_fast, 5.0, 4.0), + (false, :stability, nothing, IntrinsicsWrappers.ne_float_fast, 5.0f0, 4.0f0), (false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.ne_int, 5, 5), (false, :stability, nothing, IntrinsicsWrappers.neg_float, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.neg_float, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.neg_float_fast, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.neg_float_fast, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.neg_int, 5), (false, :stability, nothing, IntrinsicsWrappers.not_int, 5), (false, :stability, nothing, IntrinsicsWrappers.or_int, 5, 5), @@ -869,10 +893,14 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.sle_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.slt_int, 4, 5), (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm_fast, 5.0), + (false, :stability, nothing, IntrinsicsWrappers.sqrt_llvm_fast, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.srem_int, 4, 1), (false, :stability, nothing, IntrinsicsWrappers.sub_float, 4.0, 1.0), + (false, :stability, nothing, IntrinsicsWrappers.sub_float, 4.0f0, 1.0f0), (false, :stability, nothing, IntrinsicsWrappers.sub_float_fast, 4.0, 1.0), + (false, :stability, nothing, IntrinsicsWrappers.sub_float_fast, 4.0f0, 1.0f0), (false, :stability, nothing, IntrinsicsWrappers.sub_int, 4, 1), (false, :stability, nothing, IntrinsicsWrappers.trunc_int, UInt8, 78), (false, :stability, nothing, IntrinsicsWrappers.trunc_llvm, 5.1), diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 935c0947f..6cfa99751 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -12,7 +12,7 @@ function rrule!!( ::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P} ) where {P<:IEEEFloat} yp = Base.FastMath.exp2_fast(primal(x)) - exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(2)) return CoDual(yp, NoFData()), exp2_fast_pb!! end @@ -21,7 +21,7 @@ function rrule!!( ::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P} ) where {P<:IEEEFloat} yp = Base.FastMath.exp10_fast(primal(x)) - exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(10)) return CoDual(yp, NoFData()), exp2_fast_pb!! end @@ -36,96 +36,107 @@ end @zero_adjoint MinimalCtx Tuple{typeof(log),Int} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) - test_cases = Any[ - (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5), - (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, 0.5), - (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, 5.0), - (false, :stability_and_allocs, nothing, Base.FastMath.sincos, 3.0), - ] + test_cases = reduce( + vcat, + map([Float64, Float32]) do P + return Any[ + (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)), + (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)), + (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)), + (false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)), + ] + end, + ) memory = Any[] return test_cases, memory end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) - test_cases = Any[ - (false, :allocs, nothing, Base.FastMath.abs2_fast, -5.0), - (false, :allocs, nothing, Base.FastMath.abs_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.acos_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.acosh_fast, 1.2), - (false, :allocs, nothing, Base.FastMath.add_fast, 1.0, 2.0), - (false, :allocs, nothing, Base.FastMath.angle_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.asin_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.asinh_fast, 1.3), - (false, :allocs, nothing, Base.FastMath.atan_fast, 5.4), - (false, :allocs, nothing, Base.FastMath.atanh_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.cbrt_fast, 0.4), - (false, :allocs, nothing, Base.FastMath.cis_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.cmp_fast, 0.5, 0.4), - (false, :allocs, nothing, Base.FastMath.conj_fast, 0.4), - (false, :allocs, nothing, Base.FastMath.conj_fast, ComplexF64(0.5, 0.4)), - (false, :allocs, nothing, Base.FastMath.cos_fast, 0.4), - (false, :allocs, nothing, Base.FastMath.cosh_fast, 0.3), - (false, :allocs, nothing, Base.FastMath.div_fast, 5.0, 1.1), - (false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.5), - (false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.4), - (false, :allocs, nothing, Base.FastMath.expm1_fast, 5.4), - (false, :allocs, nothing, Base.FastMath.ge_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.ge_fast, 4.0, 5.0), - (false, :allocs, nothing, Base.FastMath.gt_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.gt_fast, 4.0, 5.0), - (false, :allocs, nothing, Base.FastMath.hypot_fast, 5.1, 3.2), - (false, :allocs, nothing, Base.FastMath.inv_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.isfinite_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.isinf_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.isnan_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.issubnormal_fast, 0.3), - (false, :allocs, nothing, Base.FastMath.le_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.log10_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.log1p_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.log2_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.log_fast, 0.5), - (false, :allocs, nothing, Base.FastMath.lt_fast, 0.5, 4.0), - (false, :allocs, nothing, Base.FastMath.lt_fast, 5.0, 0.4), - (false, :allocs, nothing, Base.FastMath.max_fast, 5.0, 4.0), - ( - false, - :none, - nothing, - Base.FastMath.maximum!_fast, - sin, - [0.0, 0.0], - [5.0 4.0; 3.0 2.0], - ), - (false, :allocs, nothing, Base.FastMath.maximum_fast, [5.0, 4.0, 3.0]), - (false, :allocs, nothing, Base.FastMath.min_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.min_fast, 4.0, 5.0), - ( - false, - :none, - nothing, - Base.FastMath.minimum!_fast, - sin, - [0.0, 0.0], - [5.0 4.0; 3.0 2.0], - ), - (false, :allocs, nothing, Base.FastMath.minimum_fast, [5.0, 3.0, 4.0]), - (false, :allocs, nothing, Base.FastMath.minmax_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.mul_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.ne_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.pow_fast, 5.0, 2.0), - # (:allocs, Base.FastMath.pow_fast, 5.0, 2), # errors -- ADD A RULE FOR ME! - # (:allocs, Base.FastMath.rem_fast, 5.0, 2.0), # error -- ADD A RULE FOR ME! - (false, :allocs, nothing, Base.FastMath.sign_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.sign_fast, -5.0), - (false, :allocs, nothing, Base.FastMath.sin_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.cos_fast, 4.0), - (false, :allocs, nothing, Base.FastMath.sincos_fast, 4.0), - (false, :allocs, nothing, Base.FastMath.sinh_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.sqrt_fast, 5.0), - (false, :allocs, nothing, Base.FastMath.sub_fast, 5.0, 4.0), - (false, :allocs, nothing, Base.FastMath.tan_fast, 4.0), - (false, :allocs, nothing, Base.FastMath.tanh_fast, 0.5), - ] + test_cases = reduce( + vcat, + map([Float64, Float32]) do P + C = P === Float64 ? ComplexF64 : ComplexF32 + return Any[ + (false, :allocs, nothing, Base.FastMath.abs2_fast, P(-5.0)), + (false, :allocs, nothing, Base.FastMath.abs_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.acos_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.acosh_fast, P(1.2)), + (false, :allocs, nothing, Base.FastMath.add_fast, P(1.0), P(2.0)), + (false, :allocs, nothing, Base.FastMath.angle_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.asin_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.asinh_fast, P(1.3)), + (false, :allocs, nothing, Base.FastMath.atan_fast, P(5.4)), + (false, :allocs, nothing, Base.FastMath.atanh_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.cbrt_fast, P(0.4)), + (false, :allocs, nothing, Base.FastMath.cis_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.cmp_fast, P(0.5), P(0.4)), + (false, :allocs, nothing, Base.FastMath.conj_fast, P(0.4)), + (false, :allocs, nothing, Base.FastMath.conj_fast, C(0.5, 0.4)), + (false, :allocs, nothing, Base.FastMath.cos_fast, P(0.4)), + (false, :allocs, nothing, Base.FastMath.cosh_fast, P(0.3)), + (false, :allocs, nothing, Base.FastMath.div_fast, P(5.0), P(1.1)), + (false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.5)), + (false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.4)), + (false, :allocs, nothing, Base.FastMath.expm1_fast, P(5.4)), + (false, :allocs, nothing, Base.FastMath.ge_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.ge_fast, P(4.0), P(5.0)), + (false, :allocs, nothing, Base.FastMath.gt_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.gt_fast, P(4.0), P(5.0)), + (false, :allocs, nothing, Base.FastMath.hypot_fast, P(5.1), P(3.2)), + (false, :allocs, nothing, Base.FastMath.inv_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.isfinite_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.isinf_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.isnan_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.issubnormal_fast, P(0.3)), + (false, :allocs, nothing, Base.FastMath.le_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.log10_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.log1p_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.log2_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.log_fast, P(0.5)), + (false, :allocs, nothing, Base.FastMath.lt_fast, P(0.5), P(4.0)), + (false, :allocs, nothing, Base.FastMath.lt_fast, P(5.0), P(0.4)), + (false, :allocs, nothing, Base.FastMath.max_fast, P(5.0), P(4.0)), + ( + false, + :none, + nothing, + Base.FastMath.maximum!_fast, + sin, + P.([0.0, 0.0]), + P.([5.0 4.0; 3.0 2.0]), + ), + (false, :allocs, nothing, Base.FastMath.maximum_fast, P.([5.0, 4.0, 3.0])), + (false, :allocs, nothing, Base.FastMath.min_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.min_fast, P(4.0), P(5.0)), + ( + false, + :none, + nothing, + Base.FastMath.minimum!_fast, + sin, + P.([0.0, 0.0]), + P.([5.0 4.0; 3.0 2.0]), + ), + (false, :allocs, nothing, Base.FastMath.minimum_fast, P.([5.0, 3.0, 4.0])), + (false, :allocs, nothing, Base.FastMath.minmax_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.mul_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.ne_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.pow_fast, P(5.0), P(2.0)), + # (:allocs, Base.FastMath.pow_fast, P(5.0), 2), # errors -- NEEDS RULE! + # (:allocs, Base.FastMath.rem_fast, P(5.0), P(2.0)), # error -- NEEDS RULE! + (false, :allocs, nothing, Base.FastMath.sign_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.sign_fast, P(-5.0)), + (false, :allocs, nothing, Base.FastMath.sin_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.cos_fast, P(4.0)), + (false, :allocs, nothing, Base.FastMath.sincos_fast, P(4.0)), + (false, :allocs, nothing, Base.FastMath.sinh_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.sqrt_fast, P(5.0)), + (false, :allocs, nothing, Base.FastMath.sub_fast, P(5.0), P(4.0)), + (false, :allocs, nothing, Base.FastMath.tan_fast, P(4.0)), + (false, :allocs, nothing, Base.FastMath.tanh_fast, P(0.5)), + ] + end, + ) memory = Any[] return test_cases, memory end diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 48571a8e3..490ad5587 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -1,620 +1,404 @@ -for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) - TInt = :(Ptr{BLAS.BlasInt}) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _M::CoDual{$TInt}, # Number of rows in matrix A. M >= 0 - _N::CoDual{$TInt}, # Number of cols in matrix A. N >= 0 - _A::CoDual{Ptr{$elty}}, # matrix of size (LDA, N) - _LDA::CoDual{$TInt}, # leading dimension of A - _IPIV::CoDual{$TInt}, # pivot indices - _INFO::CoDual{$TInt}, # some info of some kind - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Extract names. - M, N, LDA, IPIV, INFO = map(primal, (_M, _N, _LDA, _IPIV, _INFO)) - M_val = unsafe_load(M) - N_val = unsafe_load(N) - LDA_val = unsafe_load(LDA) - data_len = LDA_val * N_val - A, dA = primal(_A), tangent(_A) - - # This implementation is currently limited to square matrices, but should be - # extended when someone can find the time to do so. - @assert M_val === N_val - - # Store the initial state. - A_mat = wrap_ptr_as_view(A, LDA_val, M_val, N_val) - A_store = copy(A_mat) - - # Run the primal. - ccall( - $(blas_name(fname)), - Cvoid, - ($TInt, $TInt, Ptr{$elty}, $TInt, $TInt, $TInt), - M, - N, - A, - LDA, - IPIV, - INFO, - ) - - ipiv_vec = copy(unsafe_wrap(Array, IPIV, N_val)) - end +@is_primitive(MinimalCtx, Tuple{typeof(LAPACK.getrf!),AbstractMatrix{<:BlasRealFloat}}) +function rrule!!( + ::CoDual{typeof(LAPACK.getrf!)}, _A::CoDual{<:AbstractMatrix{P}} +) where {P<:BlasRealFloat} + A, dA = viewify(_A) + A_copy = copy(A) + + # Run the primal. + _, ipiv, code = LAPACK.getrf!(A) + + # Zero out the tangent. + dA .= zero(P) + + function getrf_pb!!(::NoRData) + _getrf_pb!(A, dA, ipiv, A_copy) + return NoRData(), NoRData() + end + dipiv = zero_tangent(ipiv) + return CoDual((_A.x, ipiv, code), (_A.dx, dipiv, NoFData())), getrf_pb!! +end - # Zero out the tangent. - foreach(n -> unsafe_store!(dA, zero($elty), n), 1:data_len) +@is_primitive( + MinimalCtx, + Tuple{ + typeof(Core.kwcall),NamedTuple,typeof(LAPACK.getrf!),AbstractMatrix{<:BlasRealFloat} + }, +) +function rrule!!( + ::CoDual{typeof(Core.kwcall)}, + _kwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof(getrf!)}, + _A::CoDual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} + check = _kwargs.x.check + A, dA = viewify(_A) + A_copy = copy(A) + + # Run the primal. + _, ipiv, code = LAPACK.getrf!(A; check) + + # Zero out the tangent. + dA .= zero(P) + + function getrf_pb!!(::NoRData) + _getrf_pb!(A, dA, ipiv, A_copy) + return NoRData(), NoRData(), NoRData(), NoRData() + end + dipiv = zero_tangent(ipiv) + return CoDual((_A.x, ipiv, code), (_A.dx, dipiv, NoFData())), getrf_pb!! +end - dA = tangent(_A) - function getrf_pb!!(::NoRData) - GC.@preserve args begin - # Run reverse-pass. - L, U = UnitLowerTriangular(A_mat), UpperTriangular(A_mat) - dA_mat = wrap_ptr_as_view(dA, LDA_val, M_val, N_val) - dL, dU = tril(dA_mat, -1), UpperTriangular(dA_mat) +function _getrf_pb!(A, dA, ipiv, A_copy) - # Figure out the pivot matrix used. - p = LinearAlgebra.ipiv2perm(ipiv_vec, N_val) + # Run reverse-pass. + L = UnitLowerTriangular(A) + U = UpperTriangular(A) + dL = tril(dA, -1) + dU = UpperTriangular(dA) - # Compute pullback using Seth's method. - __dF = tril(L'dL, -1) + UpperTriangular(dU * U') - dA_mat .= (inv(L') * __dF * inv(U'))[invperm(p), :] + # Figure out the pivot matrix used. + p = LinearAlgebra.ipiv2perm(ipiv, size(A, 2)) - # Restore initial state. - A_mat .= A_store - end + # Compute pullback using Seth's method. + _dF = tril(L'dL, -1) + UpperTriangular(dU * U') + dA .= (inv(L') * _dF * inv(U'))[invperm(p), :] - return tuple_fill(NoRData(), Val(12 + Nargs)) - end - return zero_fcodual(Cvoid()), getrf_pb!! - end + # Restore initial state. + # ipiv .= ipiv_copy + A .= A_copy + + return nothing end -for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) - TInt = :(Ptr{BLAS.BlasInt}) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _ul::CoDual{Ptr{UInt8}}, - _tA::CoDual{Ptr{UInt8}}, - _diag::CoDual{Ptr{UInt8}}, - _N::CoDual{Ptr{BLAS.BlasInt}}, - _Nrhs::CoDual{Ptr{BLAS.BlasInt}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BLAS.BlasInt}}, - _B::CoDual{Ptr{$elty}}, - _ldb::CoDual{Ptr{BLAS.BlasInt}}, - _info::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Load in data. - ul_p, tA_p, diag_p = map(primal, (_ul, _tA, _diag)) - N_p, Nrhs_p, lda_p, ldb_p, info_p = map(primal, (_N, _Nrhs, _lda, _ldb, _info)) - ul, tA, diag, N, Nrhs, lda, ldb, info = map( - unsafe_load, (ul_p, tA_p, diag_p, N_p, Nrhs_p, lda_p, ldb_p, info_p) - ) - - A = wrap_ptr_as_view(primal(_A), lda, N, N) - B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) - B_copy = copy(B) - - # Run the primal. - ccall( - $(blas_name(fname)), - Cvoid, - ( - Ptr{UInt8}, - Ptr{UInt8}, - Ptr{UInt8}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Clong, - Clong, - Clong, - ), - ul_p, - tA_p, - diag_p, - N_p, - Nrhs_p, - primal(_A), - lda_p, - primal(_B), - ldb_p, - info_p, - 1, - 1, - 1, - ) +@is_primitive( + MinimalCtx, + Tuple{ + typeof(trtrs!),Char,Char,Char,AbstractMatrix{P},AbstractVecOrMat{P} + } where {P<:BlasRealFloat}, +) +function rrule!!( + ::CoDual{typeof(trtrs!)}, + _uplo::CoDual{Char}, + _trans::CoDual{Char}, + _diag::CoDual{Char}, + _A::CoDual{<:AbstractMatrix{P}}, + _B::CoDual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + # Extract everything and make a copy of B for the reverse-pass. + uplo, trans, diag = primal(_uplo), primal(_trans), primal(_diag) + A, dA = viewify(_A) + B, dB = viewify(_B) + B_copy = copy(B) + + # Run primal. + trtrs!(uplo, trans, diag, A, B) + + function trtrs_pb!!(::NoRData) + + # Compute cotangent of B. + LAPACK.trtrs!(uplo, trans == 'N' ? 'T' : 'N', diag, A, dB) + + # Compute cotangent of A. + if trans == 'N' + dA .-= tri!(dB * B', uplo, diag) + else + dA .-= tri!(B * dB', uplo, diag) end - _dA = tangent(_A) - _dB = tangent(_B) - function trtrs_pb!!(::NoRData) - GC.@preserve args begin - # Compute cotangent of B. - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) - LAPACK.trtrs!(Char(ul), Char(tA) == 'N' ? 'T' : 'N', Char(diag), A, dB) - - # Compute cotangent of A. - dA = wrap_ptr_as_view(_dA, lda, N, N) - if Char(tA) == 'N' - dA .-= tri!(dB * B', Char(ul), Char(diag)) - else - dA .-= tri!(B * dB', Char(ul), Char(diag)) - end - - # Restore initial state. - B .= B_copy - end + # Restore initial state. + B .= B_copy - return tuple_fill(NoRData(), Val(16 + Nargs)) - end - return zero_fcodual(Cvoid()), trtrs_pb!! + return tuple_fill(NoRData(), Val(6)) end + return _B, trtrs_pb!! end -for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _tA::CoDual{Ptr{UInt8}}, - _N::CoDual{Ptr{BlasInt}}, - _Nrhs::CoDual{Ptr{BlasInt}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BlasInt}}, - _ipiv::CoDual{Ptr{BlasInt}}, - _B::CoDual{Ptr{$elty}}, - _ldb::CoDual{Ptr{BlasInt}}, - _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Load in values. - tA = Char(unsafe_load(primal(_tA))) - N, Nrhs, lda, ldb, info = map( - unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info) - ) - ipiv = unsafe_wrap(Vector{BlasInt}, primal(_ipiv), N) - A = wrap_ptr_as_view(primal(_A), lda, N, N) - B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) - B0 = copy(B) - - # Pivot B. - p = LinearAlgebra.ipiv2perm(ipiv, N) - - if tA == 'N' - # Apply permutation matrix. - B .= B[p, :] - - # Run inv(L) * B and write result to B. - LAPACK.trtrs!('L', 'N', 'U', A, B) - B1 = copy(B) # record intermediate state for use in pullback. - - # Run inv(U) * B and write result to B. - LAPACK.trtrs!('U', 'N', 'N', A, B) - B2 = B - else - # Run inv(U)^T * B and write result to B. - LAPACK.trtrs!('U', 'T', 'N', A, B) - B1 = copy(B) # record intermediate state for use in pullback. - - # Run inv(L)^T * B and write result to B. - LAPACK.trtrs!('L', 'T', 'U', A, B) - B2 = B - - # Apply permutation matrix. - B2 .= B2[invperm(p), :] - end - - # We need to write to `info`. - unsafe_store!(primal(_info), 0) - end - - _dA = tangent(_A) - _dB = tangent(_B) - function getrs_pb!!(::NoRData) - GC.@preserve args begin - dA = wrap_ptr_as_view(_dA, lda, N, N) - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) - - if tA == 'N' - - # Run pullback for inv(U) * B. - LAPACK.trtrs!('U', 'T', 'N', A, dB) - dA .-= tri!(dB * B2', 'U', 'N') +@is_primitive( + MinimalCtx, + Tuple{ + typeof(getrs!),Char,AbstractMatrix{P},AbstractVector{Int},AbstractVecOrMat{P} + } where {P<:BlasRealFloat} +) +function rrule!!( + ::CoDual{typeof(getrs!)}, + _trans::CoDual{Char}, + _A::CoDual{<:AbstractMatrix{P}}, + _ipiv::CoDual{<:AbstractVector{Int}}, + _B::CoDual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + + # Extract data. + trans = _trans.x + A, dA = viewify(_A) + ipiv = _ipiv.x + B, dB = viewify(_B) + B0 = copy(B) + + # Pivot B. + p = LinearAlgebra.ipiv2perm(ipiv, size(A, 1)) + + if trans == 'N' + # Apply permutation matrix. + B .= B[p, :] + + # Run inv(L) * B and write result to B. + LAPACK.trtrs!('L', 'N', 'U', A, B) + B1 = copy(B) # record intermediate state for use in pullback. + + # Run inv(U) * B and write result to B. + LAPACK.trtrs!('U', 'N', 'N', A, B) + B2 = B + else + # Run inv(U)^T * B and write result to B. + LAPACK.trtrs!('U', 'T', 'N', A, B) + B1 = copy(B) # record intermediate state for use in pullback. + + # Run inv(L)^T * B and write result to B. + LAPACK.trtrs!('L', 'T', 'U', A, B) + B2 = B + + # Apply permutation matrix. + B2 .= B2[invperm(p), :] + end - # Run pullback for inv(L) * B. - LAPACK.trtrs!('L', 'T', 'U', A, dB) - dA .-= tri!(dB * B1', 'L', 'U') + function trtrs_pb!!(::NoRData) + if trans == 'N' - # Undo permutation. - dB .= dB[invperm(p), :] - else + # Run pullback for inv(U) * B. + LAPACK.trtrs!('U', 'T', 'N', A, dB) + dA .-= tri!(dB * B2', 'U', 'N') - # Undo permutation. - dB .= dB[p, :] - B2 .= B2[p, :] + # Run pullback for inv(L) * B. + LAPACK.trtrs!('L', 'T', 'U', A, dB) + dA .-= tri!(dB * B1', 'L', 'U') - # Run pullback for inv(L^T) * B. - LAPACK.trtrs!('L', 'N', 'U', A, dB) - dA .-= tri!(B2 * dB', 'L', 'U') + # Undo permutation. + dB .= dB[invperm(p), :] + else - # Run pullback for inv(U^T) * B. - LAPACK.trtrs!('U', 'N', 'N', A, dB) - dA .-= tri!(B1 * dB', 'U', 'N') - end + # Undo permutation. + dB .= dB[p, :] + B2 .= B2[p, :] - # Restore initial state. - B .= B0 - end + # Run pullback for inv(L^T) * B. + LAPACK.trtrs!('L', 'N', 'U', A, dB) + dA .-= tri!(B2 * dB', 'L', 'U') - return tuple_fill(NoRData(), Val(15 + Nargs)) + # Run pullback for inv(U^T) * B. + LAPACK.trtrs!('U', 'N', 'N', A, dB) + dA .-= tri!(B1 * dB', 'U', 'N') end - return zero_fcodual(Cvoid()), getrs_pb!! + + # Restore initial state. + B .= B0 + return tuple_fill(NoRData(), Val(5)) end + return _B, trtrs_pb!! end -for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _N::CoDual{Ptr{BlasInt}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BlasInt}}, - _ipiv::CoDual{Ptr{BlasInt}}, - _work::CoDual{Ptr{$elty}}, - _lwork::CoDual{Ptr{BlasInt}}, - _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Pull out data. - N_p, lda_p, lwork_p, info_p = map(primal, (_N, _lda, _lwork, _info)) - N, lda, lwork, info = map(unsafe_load, (N_p, lda_p, lwork_p, info_p)) - A_p = primal(_A) - A = wrap_ptr_as_view(A_p, lda, N, N) - A_copy = copy(A) - - # Run forwards-pass. - ccall( - $(blas_name(fname)), - Cvoid, - ( - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - ), - N_p, - A_p, - lda_p, - primal(_ipiv), - primal(_work), - lwork_p, - info_p, - ) - - p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, primal(_ipiv), N), N) - end - - _dA = tangent(_A) - function getri_pb!!(::NoRData) - GC.@preserve args begin - if lwork != -1 - dA = wrap_ptr_as_view(_dA, lda, N, N) - A .= A[:, p] - dA .= dA[:, p] - - # Cotangent w.r.t. L. - dL = -(A' * dA) / UnitLowerTriangular(A_copy)' - dU = -(UpperTriangular(A_copy)' \ (dA * A')) - dA .= tri!(dL, 'L', 'U') .+ tri!(dU, 'U', 'N') - - # Restore initial state. - A .= A_copy - end - end - - return tuple_fill(NoRData(), Val(13 + Nargs)) - end - return zero_fcodual(Cvoid()), getri_pb!! +@is_primitive( + MinimalCtx, Tuple{typeof(getri!),AbstractMatrix{<:BlasRealFloat},AbstractVector{Int}}, +) +function rrule!!( + ::CoDual{typeof(getri!)}, + _A::CoDual{<:AbstractMatrix{<:BlasRealFloat}}, + _ipiv::CoDual{<:AbstractVector{Int}}, +) + # Extract args and copy A for reverse-pass. + A, dA = viewify(_A) + ipiv = _ipiv.x + A_copy = copy(A) + + # Run primal. + getri!(A, ipiv) + p = LinearAlgebra.ipiv2perm(ipiv, size(A, 1)) + + function getri_pb!!(::NoRData) + # Pivot. + A .= A[:, p] + dA .= dA[:, p] + + # Cotangent w.r.t. L. + dL = -(A' * dA) / UnitLowerTriangular(A_copy)' + dU = -(UpperTriangular(A_copy)' \ (dA * A')) + dA .= tri!(dL, 'L', 'U') .+ tri!(dU, 'U', 'N') + + # Restore initial state. + A .= A_copy + return NoRData(), NoRData(), NoRData() end + return _A, getri_pb!! end -__sym(X) = 0.5 * (X + X') - -for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _uplo::CoDual{Ptr{UInt8}}, - _N::CoDual{Ptr{BlasInt}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BlasInt}}, - _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Pull out the data. - uplo_p, N_p, A_p, lda_p, info_p = map(primal, (_uplo, _N, _A, _lda, _info)) - uplo, lda, N = map(unsafe_load, (uplo_p, lda_p, N_p)) - - # Make a copy of the initial state for later restoration. - A = wrap_ptr_as_view(A_p, lda, N, N) - A_copy = copy(A) - - # Run forwards-pass. - ccall( - $(blas_name(fname)), - Cvoid, - (Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), - uplo_p, - N_p, - A_p, - lda_p, - info_p, - ) +__sym(X) = (X + X') / 2 + +@is_primitive(MinimalCtx, Tuple{typeof(potrf!),Char,AbstractMatrix{<:BlasRealFloat}}) +function rrule!!( + ::CoDual{typeof(potrf!)}, + _uplo::CoDual{Char}, + _A::CoDual{<:AbstractMatrix{<:BlasRealFloat}}, +) + # Extract args and take a copy of A. + uplo = _uplo.x + A, dA = viewify(_A) + A_copy = copy(A) + + # Run primal. + _, info = potrf!(uplo, A) + + function potrf_pb!!(::NoRData) + dA2 = dA + + # Compute cotangents. + N = size(A, 1) + if Char(uplo) == 'L' + E = LowerTriangular(2 * ones(N, N)) - Diagonal(ones(N)) + L = LowerTriangular(A) + B = L' \ (E' .* (dA2'L)) / L + dA .= 0.5 * __sym(B) .* E .+ triu!(dA2, 1) + else + E = UpperTriangular(2 * ones(N, N) - Diagonal(ones(N))) + U = UpperTriangular(A) + B = U \ ((U * dA2') .* E') / U' + dA .= 0.5 * __sym(B) .* E .+ tril!(dA2, -1) end - _dA = tangent(_A) - function potrf_pb!!(::NoRData) - GC.@preserve args begin - dA = wrap_ptr_as_view(_dA, lda, N, N) - dA2 = dA - - # Compute cotangents. - if Char(uplo) == 'L' - E = LowerTriangular(2 * ones(N, N)) - Diagonal(ones(N)) - L = LowerTriangular(A) - B = L' \ (E' .* (dA2'L)) / L - dA .= 0.5 * __sym(B) .* E .+ triu!(dA2, 1) - else - E = UpperTriangular(2 * ones(N, N) - Diagonal(ones(N))) - U = UpperTriangular(A) - B = U \ ((U * dA2') .* E') / U' - dA .= 0.5 * __sym(B) .* E .+ tril!(dA2, -1) - end - - # Restore initial state. - A .= A_copy - end + # Restore initial state. + A .= A_copy - return tuple_fill(NoRData(), Val(11 + Nargs)) - end - return zero_fcodual(Cvoid()), potrf_pb!! + return NoRData(), NoRData(), NoRData() end + return CoDual((_A.x, info), (_A.dx, NoFData())), potrf_pb!! end -for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - _uplo::CoDual{Ptr{UInt8}}, - _N::CoDual{Ptr{BlasInt}}, - _Nrhs::CoDual{Ptr{BlasInt}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BlasInt}}, - _B::CoDual{Ptr{$elty}}, - _ldb::CoDual{Ptr{BlasInt}}, - _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - # Pull out the data. - uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p = map( - primal, (_uplo, _N, _Nrhs, _A, _lda, _B, _ldb, _info) - ) - uplo, lda, N, ldb, Nrhs = map(unsafe_load, (uplo_p, lda_p, N_p, ldb_p, Nrhs_p)) - - # Make a copy of the initial state for later restoration. - A = wrap_ptr_as_view(A_p, lda, N, N) - B = wrap_ptr_as_view(B_p, ldb, N, Nrhs) - B_copy = copy(B) - - # Run forwards-pass. - ccall( - $(blas_name(fname)), - Cvoid, - ( - Ptr{UInt8}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{$elty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - ), - uplo_p, - N_p, - Nrhs_p, - A_p, - lda_p, - B_p, - ldb_p, - info_p, - ) +@is_primitive( + MinimalCtx, + Tuple{ + typeof(potrs!),Char,AbstractMatrix{P},AbstractVecOrMat{P} + } where {P<:BlasRealFloat}, +) +function rrule!!( + ::CoDual{typeof(potrs!)}, + _uplo::CoDual{Char}, + _A::CoDual{<:AbstractMatrix{P}}, + _B::CoDual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + + # Extract args and take a copy of B. + uplo = _uplo.x + A, dA = viewify(_A) + B, dB = viewify(_B) + B_copy = copy(B) + + # Run the primal. + potrs!(uplo, A, B) + + function potrs_pb!!(::NoRData) + + # Compute cotangents. + if uplo == 'L' + tmp = __sym(B_copy * dB') / LowerTriangular(A)' + dA .-= 2 .* tril!(LinearAlgebra.LAPACK.potrs!('L', A, tmp)) + LinearAlgebra.LAPACK.potrs!('L', A, dB) + else + tmp = UpperTriangular(A)' \ __sym(B_copy * dB') + dA .-= 2 .* triu!((tmp / UpperTriangular(A)) / UpperTriangular(A)') + LinearAlgebra.LAPACK.potrs!('U', A, dB) end - _dA = tangent(_A) - _dB = tangent(_B) - function potrs_pb!!(::NoRData) - GC.@preserve args begin - dA = wrap_ptr_as_view(_dA, lda, N, N) - dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) - - # Compute cotangents. - if Char(uplo) == 'L' - tmp = __sym(B_copy * dB') / LowerTriangular(A)' - dA .-= 2 .* tril!(LinearAlgebra.LAPACK.potrs!('L', A, tmp)) - LinearAlgebra.LAPACK.potrs!('L', A, dB) - else - tmp = UpperTriangular(A)' \ __sym(B_copy * dB') - dA .-= 2 .* triu!((tmp / UpperTriangular(A)) / UpperTriangular(A)') - LinearAlgebra.LAPACK.potrs!('U', A, dB) - end - - # Restore initial state. - B .= B_copy - end + # Restore initial state. + B .= B_copy - return tuple_fill(NoRData(), Val(14 + Nargs)) - end - return zero_fcodual(Cvoid()), potrs_pb!! + return tuple_fill(NoRData(), Val(4)) end + return _B, potrs_pb!! end -generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) = Any[], Any[] - -function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) - getrf_wrapper!(x, check) = getrf!(x; check) +function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) + rng = rng_ctor(123) + Ps = [Float64, Float32] + bools = [false, true] test_cases = vcat( # getrf! - [ - (false, :none, nothing, getrf_wrapper!, randn(5, 5), false), - (false, :none, nothing, getrf_wrapper!, randn(5, 5), true), - (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), false), - (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), true), - (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 2:7, 3:8), false), - (false, :none, nothing, getrf_wrapper!, view(randn(10, 10), 3:8, 2:7), true), - ], + map_prod(bools, Ps) do (check, P) + As = blas_matrices(rng, P, 5, 5) + ipiv = Vector{Int}(undef, 5) + return map(As) do A + (false, :none, nothing, getrf!, A) + end + end..., # trtrs - vec( - reduce( - vcat, - map( - product(['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2]) - ) do (ul, tA, diag, N, Nrhs) - As = [ - randn(N, N) + 10I, view(randn(15, 15) + 10I, 2:(N + 1), 2:(N + 1)) - ] - Bs = [randn(N, Nrhs), view(randn(15, 15), 4:(N + 3), 3:(N + 2))] - return map(product(As, Bs)) do (A, B) - (false, :none, nothing, trtrs!, ul, tA, diag, A, B) - end - end, - ), - ), + map_prod( + ['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps + ) do (ul, tA, diag, N, Nrhs, P) + As = blas_matrices(rng, P, N, N) + Bs = blas_matrices(rng, P, N, Nrhs) + return map(As, Bs) do A, B + (false, :none, nothing, trtrs!, ul, tA, diag, A, B) + end + end..., # getrs - vec( - reduce( - vcat, - map(product(['N', 'T'], [1, 9], [1, 2])) do (trans, N, Nrhs) - As = - getrf!.([ - randn(N, N) + 5I, view(randn(15, 15) + 5I, 2:(N + 1), 2:(N + 1)) - ]) - Bs = [randn(N, Nrhs), view(randn(15, 15), 4:(N + 3), 3:(Nrhs + 2))] - return map(product(As, Bs)) do ((A, ipiv), B) - (false, :none, nothing, getrs!, trans, A, ipiv, B) - end - end, - ), - ), + map_prod(['N', 'T'], [1, 9], [1, 2], Ps) do (trans, N, Nrhs, P) + As = map(blas_matrices(rng, P, N, N)) do A + A[diagind(A)] .+= 5 + return getrf!(A) + end + Bs = blas_matrices(rng, P, N, Nrhs) + return map(As, Bs) do (A, ipiv), B + (false, :none, nothing, getrs!, trans, A, ipiv, B) + end + end..., # getri - vec( - reduce( - vcat, - map([1, 9]) do N - As = - getrf!.([ - randn(N, N) + 5I, view(randn(15, 15) + I, 2:(N + 1), 2:(N + 1)) - ]) - As = getrf!.([randn(N, N) + 5I]) - return map(As) do (A, ipiv) - (false, :none, nothing, getri!, A, ipiv) - end - end, - ), - ), + map_prod([1, 9], Ps) do (N, P) + As = map(blas_matrices(rng, P, N, N)) do A + A[diagind(A)] .+= 5 + return getrf!(A) + end + return map(As) do (A, ipiv) + (false, :none, nothing, getri!, A, ipiv) + end + end..., # potrf - vec( - reduce( - vcat, - map([1, 3, 9]) do N - X = randn(N, N) - A = X * X' + I - return Any[ - (false, :none, nothing, potrf!, 'L', A), - (false, :none, nothing, potrf!, 'U', A), - ] - end, - ), - ), + map_prod([1, 3, 9], Ps) do (N, P) + As = map(blas_matrices(rng, P, N, N)) do A + A .= A * A' + I + return A + end + return map(['L', 'U'], As) do uplo, A + return (false, :none, nothing, potrf!, uplo, A) + end + end..., # potrs - vec( - reduce( - vcat, - map(product([1, 3, 9], [1, 2])) do (N, Nrhs) - X = randn(N, N) - A = X * X' + I - B = randn(N, Nrhs) - return Any[ - ( - false, - :none, - nothing, - potrs!, - 'L', - potrf!('L', copy(A))[1], - copy(B), - ), - ( - false, - :none, - nothing, - potrs!, - 'U', - potrf!('U', copy(A))[1], - copy(B), - ), - ] - end, - ), - ), + map_prod([1, 3, 9], [1, 2], Ps) do (N, Nrhs, P) + X = randn(rng, P, N, N) + A = X * X' + I + Bs = blas_matrices(rng, P, N, Nrhs) + return map(['L', 'U'], Bs) do uplo, B + (false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B)) + end + end..., ) memory = Any[] return test_cases, memory end + +function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) + rng = rng_ctor(123) + getrf_wrapper!(x, check) = getrf!(x; check) + test_cases = vcat(map_prod([false, true], [Float64, Float32]) do (check, P) + As = blas_matrices(rng, P, 5, 5) + # ipiv = Vector{Int}(undef, 5) + return map(As) do A + (false, :none, nothing, getrf_wrapper!, A, check) + end + end...) + memory = Any[] + return test_cases, memory +end diff --git a/src/rrules/linear_algebra.jl b/src/rrules/linear_algebra.jl index ce5e65ae1..a503fbb98 100644 --- a/src/rrules/linear_algebra.jl +++ b/src/rrules/linear_algebra.jl @@ -19,9 +19,13 @@ function rrule!!(::CoDual{typeof(exp)}, X::CoDual{Matrix{P}}) where {P<:IEEEFloa end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:linear_algebra}) - test_cases = Any[ - (false, :none, nothing, exp, randn(3, 3)), (false, :none, nothing, exp, randn(7, 7)) - ] + rng = rng_ctor(123) + Ps = [Float64, Float32] + test_cases = vcat( + map_prod([3, 7], Ps) do (N, P) + return (false, :none, nothing, exp, randn(rng, P, N, N)) + end, + ) memory = Any[] return test_cases, memory end diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 2c297ff83..edbcb1520 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -92,17 +92,17 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_mat ) push!( test_cases, - (true, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...), + (false, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...), ) end # test cases for additional rules written in this file. push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1)) - push!(test_cases, (true, :stability_and_allocs, nothing, sin, Float32(1.1))) + push!(test_cases, (false, :stability_and_allocs, nothing, sin, Float32(1.1))) push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1)) - push!(test_cases, (true, :stability_and_allocs, nothing, cos, Float32(1.1))) + push!(test_cases, (false, :stability_and_allocs, nothing, cos, Float32(1.1))) push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1)) - push!(test_cases, (true, :stability_and_allocs, nothing, exp, Float32(1.1))) + push!(test_cases, (false, :stability_and_allocs, nothing, exp, Float32(1.1))) memory = Any[] return test_cases, memory end diff --git a/src/test_utils.jl b/src/test_utils.jl index ef50e00d4..412753a82 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -166,8 +166,8 @@ function has_equal_data_internal( return x == y end function has_equal_data_internal( - x::Float64, y::Float64, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} -) + x::P, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {P<:Base.IEEEFloat} return (isapprox(x, y) && !isnan(x)) || (isnan(x) && isnan(y)) end function has_equal_data_internal( @@ -363,13 +363,19 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) - # Use finite differences to estimate vjps + # Use finite differences to estimate vjps. Compute the estimate at a range of different + # step sizes. We'll just require that one of them ends up being close to what AD gives. ẋ = map(_x -> randn_tangent(rng, _x), x) - ε = 1e-7 - x′ = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) - y′ = x′[1](x′[2:end]...) - ẏ = _scale(1 / ε, _diff(y′, y_primal)) - ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) + fd_results = map([1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]) do ε + x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) + y′_l = x′_l[1](x′_l[2:end]...) + x′_r = _add_to_primal(x, _scale(-ε, ẋ), unsafe_perturb) + y′_r = x′_r[1](x′_r[2:end]...) + return ( + ẏ=_scale(1 / 2ε, _diff(y′_l, y′_r)), + ẋ_post=map((_x′, _x_p) -> _scale(1 / 2ε, _diff(_x′, _x_p)), x′_l, x′_r), + ) + end # Run rule on copies of `f` and `x`. We use randomly generated tangents so that we # can later verify that non-zero values do not get propagated by the rule. @@ -407,9 +413,19 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: # Check that inputs have been returned to their original value. @test all(map(has_equal_data_up_to_undefs, x, map(primal, x_x̄_rule))) - # pullbacks increment, so have to compare to the incremented quantity. - @test _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post) ≈ _dot(x̄, ẋ) rtol = 1e-3 atol = - 1e-3 + # Pullbacks increment, so have to compare to the incremented quantity. Require only one + # precision to be close to the answer AD gives. i.e. prove that there exists a step size + # such that AD and central differences agree on the answer. + isapprox_results = map(fd_results) do result + ẏ, ẋ_post = result + return isapprox( + _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post), + _dot(x̄, ẋ); + rtol=1e-3, + atol=1e-3, + ) + end + @test any(isapprox_results) end get_address(x) = ismutable(x) ? pointer_from_objref(x) : nothing diff --git a/src/utils.jl b/src/utils.jl index 6374a2ae9..2a651605b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -228,3 +228,17 @@ One-liner which calls the `:new` instruction with type `T` with arguments `x`. @inline @generated function _new_(::Type{T}, x::Vararg{Any,N}) where {T,N} return Expr(:new, :T, map(n -> :(x[$n]), 1:N)...) end + +""" + flat_product(xs...) + +Equivalent to `vec(collect(Iterators.product(xs...)))`. +""" +flat_product(xs...) = vec(collect(Iterators.product(xs...))) + +""" + map_prod(f, xs...) + +Equivalent to `map(f, flat_product(xs...))`. +""" +map_prod(f, xs...) = map(f, flat_product(xs...)) diff --git a/test/ext/special_functions/special_functions.jl b/test/ext/special_functions/special_functions.jl index c8515deef..b5a44b43e 100644 --- a/test/ext/special_functions/special_functions.jl +++ b/test/ext/special_functions/special_functions.jl @@ -7,61 +7,71 @@ using Mooncake.TestUtils: test_rule # Rules in this file are only lightly tester, because they are all just @from_rrule rules. @testset "special_functions" begin - @testset for (perf_flag, f, x...) in [ - (:stability, airyai, 0.1), - (:stability, airyaix, 0.1), - (:stability, airyaiprime, 0.1), - (:stability, airybi, 0.1), - (:stability, airybiprime, 0.1), - (:stability_and_allocs, besselj0, 0.1), - (:stability_and_allocs, besselj1, 0.1), - (:stability_and_allocs, bessely0, 0.1), - (:stability_and_allocs, bessely1, 0.1), - (:stability_and_allocs, dawson, 0.1), - (:stability_and_allocs, digamma, 0.1), - (:stability_and_allocs, erf, 0.1), - (:stability_and_allocs, erf, 0.1, 0.5), - (:stability_and_allocs, erfc, 0.1), - (:stability_and_allocs, logerfc, 0.1), - (:stability_and_allocs, erfcinv, 0.1), - (:stability_and_allocs, erfcx, 0.1), - (:stability_and_allocs, logerfcx, 0.1), - (:stability_and_allocs, erfi, 0.1), - (:stability_and_allocs, erfinv, 0.1), - (:stability_and_allocs, gamma, 0.1), - (:stability_and_allocs, invdigamma, 0.1), - (:stability_and_allocs, trigamma, 0.1), - (:stability_and_allocs, polygamma, 3, 0.1), - (:stability_and_allocs, beta, 0.3, 0.1), - (:stability_and_allocs, logbeta, 0.3, 0.1), - (:stability_and_allocs, logabsgamma, 0.3), - (:stability_and_allocs, loggamma, 0.3), - (:stability_and_allocs, expint, 0.3), - (:stability_and_allocs, expintx, 0.3), - (:stability_and_allocs, expinti, 0.3), - (:stability_and_allocs, sinint, 0.3), - (:stability_and_allocs, cosint, 0.3), - (:stability_and_allocs, ellipk, 0.3), - (:stability_and_allocs, ellipe, 0.3), + @testset for (perf_flag, f, x...) in vcat( + map([Float64, Float32]) do P + return Any[ + (:stability, airyai, P(0.1)), + (:stability, airyaix, P(0.1)), + (:stability, airyaiprime, P(0.1)), + (:stability, airybi, P(0.1)), + (:stability, airybiprime, P(0.1)), + (:stability_and_allocs, besselj0, P(0.1)), + (:stability_and_allocs, besselj1, P(0.1)), + (:stability_and_allocs, bessely0, P(0.1)), + (VERSION >= v"1.11" ? :stability_and_allocs : :none, bessely1, P(0.1)), + (:stability_and_allocs, dawson, P(0.1)), + (:stability_and_allocs, digamma, P(0.1)), + (:stability_and_allocs, erf, P(0.1)), + (:stability_and_allocs, erf, P(0.1), P(0.5)), + (:stability_and_allocs, erfc, P(0.1)), + (:stability_and_allocs, logerfc, P(0.1)), + (:stability_and_allocs, erfcinv, P(0.1)), + (:stability_and_allocs, erfcx, P(0.1)), + (:stability_and_allocs, logerfcx, P(0.1)), + (:stability_and_allocs, erfi, P(0.1)), + (:stability_and_allocs, erfinv, P(0.1)), + (:stability_and_allocs, gamma, P(0.1)), + (:stability_and_allocs, invdigamma, P(0.1)), + (:stability_and_allocs, trigamma, P(0.1)), + (:stability_and_allocs, polygamma, 3, P(0.1)), + (:stability_and_allocs, beta, P(0.3), P(0.1)), + (:stability_and_allocs, logbeta, P(0.3), P(0.1)), + (:stability_and_allocs, logabsgamma, P(0.3)), + (:stability_and_allocs, loggamma, P(0.3)), + (:stability_and_allocs, expint, P(0.3)), + (:stability_and_allocs, expintx, P(0.3)), + (:stability_and_allocs, expinti, P(0.3)), + (:stability_and_allocs, sinint, P(0.3)), + (:stability_and_allocs, cosint, P(0.3)), + (:stability_and_allocs, ellipk, P(0.3)), + (:stability_and_allocs, ellipe, P(0.3)), + ] + end..., (:stability_and_allocs, logfactorial, 3), - ] + ) test_rule(StableRNG(123456), f, x...; perf_flag) end - @testset for (perf_flag, f, x...) in [ - (:none, logerf, 0.3, 0.5), # first branch - (:none, logerf, 1.1, 1.2), # second branch - (:none, logerf, -1.2, -1.1), # third branch - (:none, logerf, 0.3, 1.1), # fourth branch - (:allocs, SpecialFunctions.loggammadiv, 1.0, 9.0), + @testset for (perf_flag, f, x...) in vcat( + map([Float64, Float32]) do P + return Any[ + (:none, logerf, P(0.3), P(0.5)), # first branch + (:none, logerf, P(1.1), P(1.2)), # second branch + (:none, logerf, P(-1.2), P(-1.1)), # third branch + (:none, logerf, P(0.3), P(1.1)), # fourth branch + (:allocs, SpecialFunctions.loggammadiv, P(1.0), P(9.0)), + (:allocs, logabsbeta, P(0.3), P(0.1)), + ] + end..., + + # Functions which only support Float64. (:allocs, SpecialFunctions.gammax, 1.0), (:allocs, SpecialFunctions.rgammax, 3.0, 6.0), (:allocs, SpecialFunctions.rgamma1pm1, 0.1), (:allocs, SpecialFunctions.auxgam, 0.1), - (:allocs, logabsbeta, 0.3, 0.1), (:allocs, SpecialFunctions.loggamma1p, 0.3), (:allocs, SpecialFunctions.loggamma1p, -0.3), (:none, SpecialFunctions.lambdaeta, 5.0), - ] + ) test_rule(StableRNG(123456), f, x...; perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/array/array.jl b/test/integration_testing/array/array.jl index cc12f4a62..baa26f6cc 100644 --- a/test/integration_testing/array/array.jl +++ b/test/integration_testing/array/array.jl @@ -424,7 +424,7 @@ _getter() = 5.0 (false, :allocs, fill!, randn(sr(0), 3, 2), randn(sr(9))), (false, :none, x -> filter(>(0), x), [0.5, -0.1, -0.4]), (false, :none, x -> filter(<(0), x), randn(sr(1), 2, 2)), - (false, :none, x -> findall(<(0), x), [0.5, 0.0, -0.3]), + (false, :none, x -> findall(<(0), x), [0.5, 0.1, -0.3]), (false, :allocs, x -> findfirst(<(0), x), [0.5, -0.1, -0.4]), (false, :allocs, x -> findlast(<(0), x), [0.5, -0.1, -0.4]), (false, :none, findmax, randn(sr(1), 2, 2)), diff --git a/test/integration_testing/distributions/distributions.jl b/test/integration_testing/distributions/distributions.jl index ebf96fe80..efbee197f 100644 --- a/test/integration_testing/distributions/distributions.jl +++ b/test/integration_testing/distributions/distributions.jl @@ -207,7 +207,7 @@ sr(n::Int) = StableRNG(n) ), randn(sr(2), 2, 3), ), - (:none, MatrixBeta(5, 6.0, 7.0), rand(sr(123456), MatrixBeta(5, 6.0, 6.0))), + (:none, MatrixBeta(5, 9.0, 10.0), rand(sr(123456), MatrixBeta(5, 9.0, 10.0))), ( :none, MatrixFDist(6.0, 7.0, _pdmat(randn(sr(1234), 5, 5))), diff --git a/test/integration_testing/logexpfunctions/logexpfunctions.jl b/test/integration_testing/logexpfunctions/logexpfunctions.jl index df3cec0a8..d9cf73512 100644 --- a/test/integration_testing/logexpfunctions/logexpfunctions.jl +++ b/test/integration_testing/logexpfunctions/logexpfunctions.jl @@ -8,39 +8,43 @@ using Mooncake.TestUtils: test_rule sr(n::Int) = StableRNG(n) @testset "logexpfunctions" begin - @testset for (perf_flag, f, x...) in [ - (:allocs, xlogx, 1.1), - (:allocs, xlogy, 0.3, 1.2), - (:allocs, xlog1py, 0.3, -0.5), - (:allocs, xexpx, -0.5), - (:allocs, xexpy, 1.0, -0.7), - (:allocs, logistic, 0.5), - (:allocs, logit, 0.3), - (:allocs, logcosh, 1.5), - (:allocs, logabssinh, 0.3), - (:allocs, log1psq, 0.3), - (:allocs, log1pexp, 0.1), - (:allocs, log1mexp, -0.5), - (:allocs, log2mexp, 0.1), - (:allocs, logexpm1, 0.1), - (:allocs, log1pmx, -0.95), - (:allocs, logmxp1, 0.02), - (:allocs, logaddexp, -0.5, 0.4), - (:allocs, logsubexp, -0.5, -5.0), - (:allocs, logsumexp, randn(sr(1), 5)), - (:allocs, logsumexp, randn(sr(2), 5, 4)), - (:allocs, logsumexp, randn(sr(3), 5, 4, 3)), - (:none, x -> logsumexp(x; dims=1), randn(sr(4), 5, 4)), - (:none, x -> logsumexp(x; dims=2), randn(sr(5), 5, 4)), - (:none, logsumexp!, rand(sr(6), 5), randn(sr(7), 5, 4)), - (:none, softmax, randn(sr(7), 10)), - (:allocs, cloglog, 0.5), - (:allocs, cexpexp, -0.3), - (:allocs, loglogistic, 0.5), - (:allocs, logitexp, -0.3), - (:allocs, log1mlogistic, -0.9), - (:allocs, logit1mexp, -0.6), - ] + @testset for (perf_flag, f, x...) in vcat( + map([Float64, Float32]) do P + return Any[ + (:allocs, xlogx, P(1.1)), + (:allocs, xlogy, P(0.3), P(1.2)), + (:allocs, xlog1py, P(0.3), -P(0.5)), + (:allocs, xexpx, -P(0.5)), + (:allocs, xexpy, P(1.0), -P(0.7)), + (:allocs, logistic, P(0.5)), + (:allocs, logit, P(0.3)), + (:allocs, logcosh, P(1.5)), + (:allocs, logabssinh, P(0.3)), + (:allocs, log1psq, P(0.3)), + (:allocs, log1pexp, P(0.1)), + (:allocs, log1mexp, -P(0.5)), + (:allocs, log2mexp, P(0.1)), + (:allocs, logexpm1, P(0.1)), + (:allocs, log1pmx, -P(0.95)), + (:allocs, logmxp1, P(0.02)), + (:allocs, logaddexp, -P(0.5), P(0.4)), + (:allocs, logsubexp, -P(0.5), -P(5.0)), + (:allocs, logsumexp, randn(sr(1), P, 5)), + (:allocs, logsumexp, randn(sr(2), P, 5, 4)), + (:allocs, logsumexp, randn(sr(3), P, 5, 4, 3)), + (:none, x -> logsumexp(x; dims=1), randn(sr(4), P, 5, 4)), + (:none, x -> logsumexp(x; dims=2), randn(sr(5), P, 5, 4)), + (:none, logsumexp!, rand(sr(6), 5), randn(sr(7), P, 5, 4)), + (:none, softmax, randn(sr(7), P, 10)), + (:allocs, cloglog, P(0.5)), + (:allocs, cexpexp, -P(0.3)), + (:allocs, loglogistic, P(0.5)), + (:allocs, logitexp, -P(0.3)), + (:allocs, log1mlogistic, -P(0.9)), + (:allocs, logit1mexp, -P(0.6)), + ] + end..., + ) test_rule(sr(123456), f, x...; perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/lux/lux.jl b/test/integration_testing/lux/lux.jl index 74078f4a7..0e1a9af15 100644 --- a/test/integration_testing/lux/lux.jl +++ b/test/integration_testing/lux/lux.jl @@ -5,82 +5,115 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using Mooncake, Lux, StableRNGs, Test using Mooncake.TestUtils: test_rule +sr(x) = StableRNG(x) + @testset "lux" begin - @testset "$(typeof(f))" for (f, x_f32) in Any[ - (Dense(2, 4), randn(Float32, 2, 3)), - (Dense(2, 4, gelu), randn(Float32, 2, 3)), - (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), - (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), - (Scale(2), randn(Float32, 2, 3)), - (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), - (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + P = Float32 + @testset "$(typeof(f))" for (interface_only, f, x_f32) in Any[ + (false, Dense(2, 4), randn(sr(1), P, 2, 3)), + (false, Dense(2, 4, gelu), randn(sr(2), P, 2, 3)), + (false, Dense(2, 4, gelu; use_bias=false), randn(sr(3), P, 2, 3)), + (false, Chain(Dense(2, 4, relu), Dense(4, 3)), randn(sr(4), P, 2, 3)), + (false, Scale(2), randn(sr(5), P, 2, 3)), + (false, Conv((3, 3), 2 => 3), randn(sr(6), P, 3, 3, 2, 2)), + (false, Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(sr(7), P, 3, 3, 2, 2)), ( + false, Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), - randn(Float32, 3, 3, 2, 2), + randn(sr(8), P, 3, 3, 2, 2), ), ( + false, Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), - rand(Float32, 5, 5, 2, 2), + rand(sr(9), P, 5, 5, 2, 2), ), ( + false, Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), - rand(Float32, 5, 5, 2, 2), + rand(sr(10), P, 5, 5, 2, 2), ), ( + false, Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), - rand(Float32, 5, 5, 2, 2), + rand(sr(11), P, 5, 5, 2, 2), + ), + (false, Maxout(() -> Dense(5 => 4, tanh), 3), randn(sr(12), P, 5, 2)), + (false, Bilinear((2, 2) => 3), randn(sr(13), P, 2, 3)), + (false, SkipConnection(Dense(2 => 2), vcat), randn(sr(14), P, 2, 3)), + (false, ConvTranspose((3, 3), 3 => 2; stride=2), rand(sr(15), P, 5, 5, 3, 1)), + (false, StatefulRecurrentCell(RNNCell(3 => 5)), rand(sr(16), P, 3, 2)), + (false, StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(sr(17), P, 3, 2)), + ( + false, + StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), + rand(sr(18), P, 3, 2), ), - (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), - (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), - (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), - (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), - (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), - (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), - (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), ( + false, Chain( StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3)), ), - rand(Float32, 3, 2), + rand(sr(19), P, 3, 2), ), - (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), + (false, StatefulRecurrentCell(LSTMCell(3 => 5)), rand(sr(20), P, 3, 2)), ( + false, Chain( StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3)), ), - rand(Float32, 3, 2), + rand(sr(21), P, 3, 2), ), - (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), + (false, StatefulRecurrentCell(GRUCell(3 => 5)), rand(sr(22), P, 3, 10)), ( + false, Chain( StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3)), ), - rand(Float32, 3, 10), + rand(sr(23), P, 3, 10), ), - (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), - (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), - (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), - (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), - (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), - (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (true, Chain(Dense(2, 4), BatchNorm(4)), randn(sr(24), P, 2, 3)), + (true, Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(sr(25), P, 2, 3)), ( + true, + Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), + randn(sr(26), P, 2, 3), + ), + (true, Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(sr(27), P, 6, 6, 2, 2)), + ( + true, + Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), + randn(sr(28), P, 6, 6, 2, 2), + ), + (false, Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(sr(29), P, 2, 3)), + (false, Chain(Dense(2, 4), GroupNorm(4, 2)), randn(sr(30), P, 2, 3)), + (false, Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(sr(31), P, 6, 6, 2, 2)), + ( + false, + Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), + randn(sr(32), P, 6, 6, 2, 2), + ), + ( + false, Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), - randn(Float32, 4, 4, 2, 2), + randn(sr(33), P, 4, 4, 2, 2), + ), + (false, InstanceNorm(6), randn(sr(34), P, 6, 6, 2, 2)), + (false, Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(sr(35), P, 6, 6, 2, 2)), + ( + false, + Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), + randn(sr(36), P, 6, 6, 2, 2), ), - (InstanceNorm(6), randn(Float32, 6, 6, 2, 2)), - (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), - (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] @info "$(typeof((f, x_f32...)))" - rng = StableRNG(123456) + rng = sr(123546) ps, st = f32(Lux.setup(rng, f)) x = f32(x_f32) - test_rule(rng, f, x, ps, st; is_primitive=false, interface_only=true) + test_rule( + rng, f, x, ps, st; is_primitive=false, interface_only, unsafe_perturb=true + ) end end