Skip to content

Commit

Permalink
Simplify BLAS and LAPACK testing
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 11, 2024
1 parent dd087fb commit e22f061
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
15 changes: 7 additions & 8 deletions src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,6 @@ function blas_vectors(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int)
xs = Any[
randn(rng, P, p),
view(randn(rng, P, p + 5), 3:(p + 2)),
view(randn(rng, P, 3p), 1:2:(2p)),
view(randn(rng, P, 3p, 3), 1:2:(2p), 2),
]
@assert all(x -> length(x) == p, xs)
Expand All @@ -743,7 +742,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
xs = blas_vectors(rng, P, N)
ys = blas_vectors(rng, P, M)
flags = (false, :stability, (lb=1e-3, ub=10.0))
return map_prod(As, xs, ys) do (A, x, y)
return map(As, xs, ys) do A, x, y
return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y)
end
end...,
Expand All @@ -753,7 +752,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
As = blas_matrices(rng, P, 5, 5)
ys = blas_vectors(rng, P, 5)
xs = blas_vectors(rng, P, 5)
return map_prod(As, xs, ys) do (A, x, y)
return map(As, xs, ys) do A, x, y
(false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y)
end
end...,
Expand All @@ -763,7 +762,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3)
Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4)
Cs = blas_matrices(rng, P, 3, 5)
return map_prod(As, Bs, Cs) do (A, B, C)
return map(As, Bs, Cs) do A, B, C
(false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C)
end
end...,
Expand All @@ -774,7 +773,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
As = blas_matrices(rng, P, nA, nA)
Bs = blas_matrices(rng, P, 5, 7)
Cs = blas_matrices(rng, P, 5, 7)
return map_prod(As, Bs, Cs) do (A, B, C)
return map(As, Bs, Cs) do A, B, C
(false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C)
end
end...,
Expand Down Expand Up @@ -822,7 +821,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
map_prod(uplos, t_flags, dAs, [1, 3], Ps) do (ul, tA, dA, N, P)
As = blas_matrices(rng, P, N, N)
bs = blas_vectors(rng, P, N)
return map_prod(As, bs) do (A, b)
return map(As, bs) do A, b
(false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b)
end
end...,
Expand Down Expand Up @@ -858,7 +857,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
return map_prod(As, Bs) do (A, B)
return map(As, Bs) do A, B
(false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
end
end...,
Expand All @@ -874,7 +873,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
return A
end
Bs = blas_matrices(rng, P, M, N)
return map_prod(As, Bs) do (A, B)
return map(As, Bs) do A, B
(false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
end
end...,
Expand Down
8 changes: 4 additions & 4 deletions src/rrules/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
) do (ul, tA, diag, N, Nrhs, P)
As = blas_matrices(rng, P, N, N)
Bs = blas_matrices(rng, P, N, Nrhs)
return map_prod(As, Bs) do (A, B)
return map(As, Bs) do A, B
(false, :none, nothing, trtrs!, ul, tA, diag, A, B)
end
end...,
Expand All @@ -348,7 +348,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
return getrf!(A)
end
Bs = blas_matrices(rng, P, N, Nrhs)
return map_prod(As, Bs) do ((A, ipiv), B)
return map(As, Bs) do (A, ipiv), B
(false, :none, nothing, getrs!, trans, A, ipiv, B)
end
end...,
Expand All @@ -370,7 +370,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
A .= A * A' + I
return A
end
return map_prod(['L', 'U'], As) do (uplo, A)
return map(['L', 'U'], As) do uplo, A
return (false, :none, nothing, potrf!, uplo, A)
end
end...,
Expand All @@ -380,7 +380,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
X = randn(rng, P, N, N)
A = X * X' + I
Bs = blas_matrices(rng, P, N, Nrhs)
return map_prod(['L', 'U'], Bs) do (uplo, B)
return map(['L', 'U'], Bs) do uplo, B
(false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B))
end
end...,
Expand Down

0 comments on commit e22f061

Please sign in to comment.