diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 36ab00aae..7674a5b55 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -738,45 +738,44 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) test_cases = vcat( # gemv! - map(flat_product(t_flags, [1, 3], [1, 2], Ps)) do (tA, M, N, P) + map_prod(t_flags, [1, 3], [1, 2], Ps) do (tA, M, N, P) As = blas_matrices(rng, P, tA == 'N' ? M : N, tA == 'N' ? N : M) xs = blas_vectors(rng, P, N) ys = blas_vectors(rng, P, M) flags = (false, :stability, (lb=1e-3, ub=10.0)) - return map(flat_product(As, xs, ys)) do (A, x, y) + return map_prod(As, xs, ys) do (A, x, y) return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y) end end..., # symv! - map(flat_product(['L', 'U'], alphas, betas, Ps)) do (uplo, α, β, P) + map_prod(['L', 'U'], alphas, betas, Ps) do (uplo, α, β, P) As = blas_matrices(rng, P, 5, 5) ys = blas_vectors(rng, P, 5) xs = blas_vectors(rng, P, 5) - return map(flat_product(As, xs, ys)) do (A, x, y) + return map_prod(As, xs, ys) do (A, x, y) (false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y) end end..., # gemm! - map(flat_product(t_flags, t_flags, alphas, betas, Ps)) do (tA, tB, a, b, P) + map_prod(t_flags, t_flags, alphas, betas, Ps) do (tA, tB, a, b, P) As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) Cs = blas_matrices(rng, P, 3, 5) - return map(flat_product(As, Bs, Cs)) do (A, B, C) + return map_prod(As, Bs, Cs) do (A, B, C) (false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C) end end..., # symm! - map(flat_product(['L', 'R'], ['L', 'U'], alphas, betas, Ps)) do (side, ul, α, β, P) + map_prod(['L', 'R'], ['L', 'U'], alphas, betas, Ps) do (side, ul, α, β, P) nA = side == 'L' ? 5 : 7 As = blas_matrices(rng, P, nA, nA) Bs = blas_matrices(rng, P, 5, 7) Cs = blas_matrices(rng, P, 5, 7) - flags = (false, :stability, nothing) - return map(flat_product(As, Bs, Cs)) do (A, B, C) - (flags..., BLAS.symm!, side, ul, P(α), A, B, P(β), C) + return map_prod(As, Bs, Cs) do (A, B, C) + (false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C) end end..., ) @@ -820,10 +819,10 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # # trmv! - map(flat_product(uplos, t_flags, dAs, [1, 3], Ps)) do (ul, tA, dA, N, P) + map_prod(uplos, t_flags, dAs, [1, 3], Ps) do (ul, tA, dA, N, P) As = blas_matrices(rng, P, N, N) bs = blas_vectors(rng, P, N) - return map(flat_product(As, bs)) do (A, b) + return map_prod(As, bs) do (A, b) (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) end end..., @@ -833,17 +832,16 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # # aliased gemm! - map(flat_product(t_flags, t_flags, Ps)) do (tA, tB, P) + map_prod(t_flags, t_flags, Ps) do (tA, tB, P) As = blas_matrices(rng, P, 5, 5) Bs = blas_matrices(rng, P, 5, 5) - flags = false, :none, nothing - return map(flat_product(As, Bs)) do (A, B) - (flags..., aliased_gemm!, tA, tB, randn(P), randn(P), A, B) + return map_prod(As, Bs) do (A, B) + (false, :none, nothing, aliased_gemm!, tA, tB, randn(P), randn(P), A, B) end end..., # syrk! - map(flat_product(uplos, t_flags, Ps)) do (uplo, t, P) + map_prod(uplos, t_flags, Ps) do (uplo, t, P) As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3) C = randn(P, 3, 3) flags = (false, :none, nothing) @@ -853,30 +851,28 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) end..., # trmm! - map( - flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps) + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps ) do (side, ul, tA, dA, M, N, P) t = tA == 'N' R = side == 'L' ? M : N As = blas_matrices(rng, P, R, R) Bs = blas_matrices(rng, P, M, N) - flags = (false, :none, nothing) - return map(flat_product(As, Bs)) do (A, B) - (flags..., BLAS.trmm!, side, ul, tA, dA, randn(P), A, B) + return map_prod(As, Bs) do (A, B) + (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B) end end..., # trsm! - map( - flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps) + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps ) do (side, ul, tA, dA, M, N, P) t = tA == 'N' R = side == 'L' ? M : N As = blas_matrices(rng, P, R, R) Bs = blas_matrices(rng, P, M, N) - flags = (false, :none, nothing) - return map(flat_product(As, Bs)) do (A, B) - (flags..., BLAS.trsm!, side, ul, tA, dA, randn(P), A, B) + return map_prod(As, Bs) do (A, B) + (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B) end end..., ) diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index d03170533..d4e9eeeaa 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -510,7 +510,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) test_cases = vcat( # getrf! - map(flat_product(bools, Ps)) do (check, P) + map_prod(bools, Ps) do (check, P) As = blas_matrices(rng, P, 5, 5) return map(As) do A (false, :none, nothing, getrf_wrapper!, A, check) @@ -518,30 +518,30 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) end..., # trtrs - map( - flat_product(['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps) + map_prod( + ['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps ) do (ul, tA, diag, N, Nrhs, P) As = blas_matrices(rng, P, N, N) Bs = blas_matrices(rng, P, N, Nrhs) - return map(flat_product(As, Bs)) do (A, B) + return map_prod(As, Bs) do (A, B) (false, :none, nothing, trtrs!, ul, tA, diag, A, B) end end..., # getrs - map(flat_product(['N', 'T'], [1, 9], [1, 2], Ps)) do (trans, N, Nrhs, P) + map_prod(['N', 'T'], [1, 9], [1, 2], Ps) do (trans, N, Nrhs, P) As = map(blas_matrices(rng, P, N, N)) do A A[diagind(A)] .+= 5 return getrf!(A) end Bs = blas_matrices(rng, P, N, Nrhs) - return map(flat_product(As, Bs)) do ((A, ipiv), B) + return map_prod(As, Bs) do ((A, ipiv), B) (false, :none, nothing, getrs!, trans, A, ipiv, B) end end..., # getri - map(flat_product([1, 9], Ps)) do (N, P) + map_prod([1, 9], Ps) do (N, P) As = map(blas_matrices(rng, P, N, N)) do A A[diagind(A)] .+= 5 return getrf!(A) @@ -552,22 +552,22 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) end..., # potrf - map(flat_product([1, 3, 9], Ps)) do (N, P) + map_prod([1, 3, 9], Ps) do (N, P) As = map(blas_matrices(rng, P, N, N)) do A A .= A * A' + I return A end - return map(flat_product(['L', 'U'], As)) do (uplo, A) + return map_prod(['L', 'U'], As) do (uplo, A) return (false, :none, nothing, potrf!, uplo, A) end end..., # potrs - map(flat_product([1, 3, 9], [1, 2], Ps)) do (N, Nrhs, P) + map_prod([1, 3, 9], [1, 2], Ps) do (N, Nrhs, P) X = randn(rng, P, N, N) A = X * X' + I Bs = blas_matrices(rng, P, N, Nrhs) - return map(flat_product(['L', 'U'], Bs)) do (uplo, B) + return map_prod(['L', 'U'], Bs) do (uplo, B) (false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B)) end end...,