Skip to content

Commit

Permalink
Tidy up blas testing
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 10, 2024
1 parent 0cc086a commit c661c7d
Showing 1 changed file with 91 additions and 140 deletions.
231 changes: 91 additions & 140 deletions src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,68 +738,47 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
test_cases = vcat(

# gemv!
vec(
reduce(
vcat,
map(product(t_flags, [1, 3], [1, 2], Ps)) do (tA, M, N, P)
t = tA == 'N'
As = blas_matrices(rng, P, t ? M : N, t ? N : M)
xs = blas_vectors(rng, P, N)
ys = blas_vectors(rng, P, M)
flags = (false, :stability, (lb=1e-3, ub=10.0))
return map(product(As, xs, ys)) do (A, x, y)
return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y)
end
end,
),
),
map(flat_product(t_flags, [1, 3], [1, 2], Ps)) do (tA, M, N, P)
As = blas_matrices(rng, P, tA == 'N' ? M : N, tA == 'N' ? N : M)
xs = blas_vectors(rng, P, N)
ys = blas_vectors(rng, P, M)
flags = (false, :stability, (lb=1e-3, ub=10.0))
return map(flat_product(As, xs, ys)) do (A, x, y)
return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y)
end
end...,

# symv!
vec(
reduce(
vcat,
map(product(['L', 'U'], alphas, betas, Ps)) do (uplo, α, β, P)
As = blas_matrices(rng, P, 5, 5)
ys = blas_vectors(rng, P, 5)
xs = blas_vectors(rng, P, 5)
return map(product(As, xs, ys)) do (A, x, y)
(false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y)
end
end,
),
),
map(flat_product(['L', 'U'], alphas, betas, Ps)) do (uplo, α, β, P)
As = blas_matrices(rng, P, 5, 5)
ys = blas_vectors(rng, P, 5)
xs = blas_vectors(rng, P, 5)
return map(flat_product(As, xs, ys)) do (A, x, y)
(false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y)
end
end...,

# gemm!
vec(
reduce(
vcat,
map(product(t_flags, t_flags, alphas, betas, Ps)) do (tA, tB, a, b, P)
As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3)
Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4)
Cs = blas_matrices(rng, P, 3, 5)
return map(product(As, Bs, Cs)) do (A, B, C)
(false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C)
end
end,
),
),
map(flat_product(t_flags, t_flags, alphas, betas, Ps)) do (tA, tB, a, b, P)
As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3)
Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4)
Cs = blas_matrices(rng, P, 3, 5)
return map(flat_product(As, Bs, Cs)) do (A, B, C)
(false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C)
end
end...,

# symm!
vec(
reduce(
vcat,
map(product(['L', 'R'], ['L', 'U'], alphas, betas, Ps)) do (side, uplo, α, β, P)
nA = side == 'L' ? 5 : 7
As = blas_matrices(rng, P, nA, nA)
Bs = blas_matrices(rng, P, 5, 7)
Cs = blas_matrices(rng, P, 5, 7)
flags = (false, :stability, nothing)
return map(product(As, Bs, Cs)) do (A, B, C)
(flags..., BLAS.symm!, side, uplo, P(α), A, B, P(β), C)
end
end,
),
),
map(flat_product(['L', 'R'], ['L', 'U'], alphas, betas, Ps)) do (side, ul, α, β, P)
nA = side == 'L' ? 5 : 7
As = blas_matrices(rng, P, nA, nA)
Bs = blas_matrices(rng, P, 5, 7)
Cs = blas_matrices(rng, P, 5, 7)
flags = (false, :stability, nothing)
return map(flat_product(As, Bs, Cs)) do (A, B, C)
(flags..., BLAS.symm!, side, ul, P(α), A, B, P(β), C)
end
end...,
)

memory = Any[]
Expand All @@ -826,108 +805,80 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
# BLAS LEVEL 1
#

reduce(
vcat,
map(Ps) do P
Any[
(false, :none, nothing, BLAS.dot, 3, randn(P, 5), 1, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 2, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 1, randn(P, 9), 3),
(false, :none, nothing, BLAS.dot, 3, randn(P, 12), 3, randn(P, 9), 2),
(false, :none, nothing, BLAS.scal!, 10, P(2.4), randn(P, 30), 2),
]
end,
),
map(Ps) do P
Any[
(false, :none, nothing, BLAS.dot, 3, randn(P, 5), 1, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 2, randn(P, 4), 1),
(false, :none, nothing, BLAS.dot, 3, randn(P, 6), 1, randn(P, 9), 3),
(false, :none, nothing, BLAS.dot, 3, randn(P, 12), 3, randn(P, 9), 2),
(false, :none, nothing, BLAS.scal!, 10, P(2.4), randn(P, 30), 2),
]
end...,

#
# BLAS LEVEL 2
#

# trmv!
vec(
reduce(
vcat,
map(product(uplos, t_flags, dAs, [1, 3], Ps)) do (ul, tA, dA, N, P)
As = blas_matrices(rng, P, N, N)
bs = blas_vectors(rng, P, N)
return map(product(As, bs)) do (A, b)
(false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b)
end
end,
),
),
map(flat_product(uplos, t_flags, dAs, [1, 3], Ps)) do (ul, tA, dA, N, P)
As = blas_matrices(rng, P, N, N)
bs = blas_vectors(rng, P, N)
return map(flat_product(As, bs)) do (A, b)
(false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b)
end
end...,

#
# BLAS LEVEL 3
#

# aliased gemm!
vec(
reduce(
vcat,
map(product(t_flags, t_flags, Ps)) do (tA, tB, P)
As = blas_matrices(rng, P, 5, 5)
Bs = blas_matrices(rng, P, 5, 5)
flags = false, :none, nothing
return map(product(As, Bs)) do (A, B)
(flags..., aliased_gemm!, tA, tB, randn(P), randn(P), A, B)
end
end,
),
),
map(flat_product(t_flags, t_flags, Ps)) do (tA, tB, P)
As = blas_matrices(rng, P, 5, 5)
Bs = blas_matrices(rng, P, 5, 5)
flags = false, :none, nothing
return map(flat_product(As, Bs)) do (A, B)
(flags..., aliased_gemm!, tA, tB, randn(P), randn(P), A, B)
end
end...,

# syrk!
vec(
reduce(
vcat,
map(product(uplos, t_flags, Ps)) do (uplo, t, P)
As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3)
C = randn(P, 3, 3)
flags = (false, :none, nothing)
return map(As) do A
return (flags..., BLAS.syrk!, uplo, t, randn(P), A, randn(P), C)
end
end,
),
),
map(flat_product(uplos, t_flags, Ps)) do (uplo, t, P)
As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3)
C = randn(P, 3, 3)
flags = (false, :none, nothing)
return map(As) do A
return (flags..., BLAS.syrk!, uplo, t, randn(P), A, randn(P), C)
end
end...,

# trmm!
vec(
reduce(
vcat,
map(
product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(product(As, Bs)) do (A, B)
(flags..., BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
end
end,
),
),
map(
flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(flat_product(As, Bs)) do (A, B)
(flags..., BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
end
end...,

# trsm!
vec(
reduce(
vcat,
map(
product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(product(As, Bs)) do (A, B)
(flags..., BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
end
end,
),
),
map(
flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(flat_product(As, Bs)) do (A, B)
(flags..., BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
end
end...,
)
memory = Any[]
return test_cases, memory
Expand Down

0 comments on commit c661c7d

Please sign in to comment.