From e22f06187d423d8f7ec84a1ec831ba6f5e69b427 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Wed, 11 Dec 2024 09:55:27 +0000 Subject: [PATCH] Simplify BLAS and LAPACK testing --- src/rrules/blas.jl | 15 +++++++-------- src/rrules/lapack.jl | 8 ++++---- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index d6576e4d3..f74cf02c1 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -720,7 +720,6 @@ function blas_vectors(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int) xs = Any[ randn(rng, P, p), view(randn(rng, P, p + 5), 3:(p + 2)), - view(randn(rng, P, 3p), 1:2:(2p)), view(randn(rng, P, 3p, 3), 1:2:(2p), 2), ] @assert all(x -> length(x) == p, xs) @@ -743,7 +742,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) xs = blas_vectors(rng, P, N) ys = blas_vectors(rng, P, M) flags = (false, :stability, (lb=1e-3, ub=10.0)) - return map_prod(As, xs, ys) do (A, x, y) + return map(As, xs, ys) do A, x, y return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y) end end..., @@ -753,7 +752,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) As = blas_matrices(rng, P, 5, 5) ys = blas_vectors(rng, P, 5) xs = blas_vectors(rng, P, 5) - return map_prod(As, xs, ys) do (A, x, y) + return map(As, xs, ys) do A, x, y (false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y) end end..., @@ -763,7 +762,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) Cs = blas_matrices(rng, P, 3, 5) - return map_prod(As, Bs, Cs) do (A, B, C) + return map(As, Bs, Cs) do A, B, C (false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C) end end..., @@ -774,7 +773,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) As = blas_matrices(rng, P, nA, nA) Bs = blas_matrices(rng, P, 5, 7) Cs = blas_matrices(rng, P, 5, 7) - return map_prod(As, Bs, Cs) do (A, B, C) + return map(As, Bs, Cs) do A, B, C (false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C) end end..., @@ -822,7 +821,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) map_prod(uplos, t_flags, dAs, [1, 3], Ps) do (ul, tA, dA, N, P) As = blas_matrices(rng, P, N, N) bs = blas_vectors(rng, P, N) - return map_prod(As, bs) do (A, b) + return map(As, bs) do A, b (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) end end..., @@ -858,7 +857,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) R = side == 'L' ? M : N As = blas_matrices(rng, P, R, R) Bs = blas_matrices(rng, P, M, N) - return map_prod(As, Bs) do (A, B) + return map(As, Bs) do A, B (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B) end end..., @@ -874,7 +873,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) return A end Bs = blas_matrices(rng, P, M, N) - return map_prod(As, Bs) do (A, B) + return map(As, Bs) do A, B (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B) end end..., diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 905d90489..490ad5587 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -336,7 +336,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) ) do (ul, tA, diag, N, Nrhs, P) As = blas_matrices(rng, P, N, N) Bs = blas_matrices(rng, P, N, Nrhs) - return map_prod(As, Bs) do (A, B) + return map(As, Bs) do A, B (false, :none, nothing, trtrs!, ul, tA, diag, A, B) end end..., @@ -348,7 +348,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) return getrf!(A) end Bs = blas_matrices(rng, P, N, Nrhs) - return map_prod(As, Bs) do ((A, ipiv), B) + return map(As, Bs) do (A, ipiv), B (false, :none, nothing, getrs!, trans, A, ipiv, B) end end..., @@ -370,7 +370,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) A .= A * A' + I return A end - return map_prod(['L', 'U'], As) do (uplo, A) + return map(['L', 'U'], As) do uplo, A return (false, :none, nothing, potrf!, uplo, A) end end..., @@ -380,7 +380,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) X = randn(rng, P, N, N) A = X * X' + I Bs = blas_matrices(rng, P, N, Nrhs) - return map_prod(['L', 'U'], Bs) do (uplo, B) + return map(['L', 'U'], Bs) do uplo, B (false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B)) end end...,