Skip to content

Commit

Permalink
Tidy up blas and lapack rules
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 10, 2024
1 parent 5b52498 commit b9c17c1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 38 deletions.
50 changes: 23 additions & 27 deletions src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,45 +738,44 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
test_cases = vcat(

# gemv!
map(flat_product(t_flags, [1, 3], [1, 2], Ps)) do (tA, M, N, P)
map_prod(t_flags, [1, 3], [1, 2], Ps) do (tA, M, N, P)
As = blas_matrices(rng, P, tA == 'N' ? M : N, tA == 'N' ? N : M)
xs = blas_vectors(rng, P, N)
ys = blas_vectors(rng, P, M)
flags = (false, :stability, (lb=1e-3, ub=10.0))
return map(flat_product(As, xs, ys)) do (A, x, y)
return map_prod(As, xs, ys) do (A, x, y)
return (flags..., BLAS.gemv!, tA, randn(P), A, x, randn(P), y)
end
end...,

# symv!
map(flat_product(['L', 'U'], alphas, betas, Ps)) do (uplo, α, β, P)
map_prod(['L', 'U'], alphas, betas, Ps) do (uplo, α, β, P)
As = blas_matrices(rng, P, 5, 5)
ys = blas_vectors(rng, P, 5)
xs = blas_vectors(rng, P, 5)
return map(flat_product(As, xs, ys)) do (A, x, y)
return map_prod(As, xs, ys) do (A, x, y)
(false, :stability, nothing, BLAS.symv!, uplo, P(α), A, x, P(β), y)
end
end...,

# gemm!
map(flat_product(t_flags, t_flags, alphas, betas, Ps)) do (tA, tB, a, b, P)
map_prod(t_flags, t_flags, alphas, betas, Ps) do (tA, tB, a, b, P)
As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3)
Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4)
Cs = blas_matrices(rng, P, 3, 5)
return map(flat_product(As, Bs, Cs)) do (A, B, C)
return map_prod(As, Bs, Cs) do (A, B, C)
(false, :none, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C)
end
end...,

# symm!
map(flat_product(['L', 'R'], ['L', 'U'], alphas, betas, Ps)) do (side, ul, α, β, P)
map_prod(['L', 'R'], ['L', 'U'], alphas, betas, Ps) do (side, ul, α, β, P)
nA = side == 'L' ? 5 : 7
As = blas_matrices(rng, P, nA, nA)
Bs = blas_matrices(rng, P, 5, 7)
Cs = blas_matrices(rng, P, 5, 7)
flags = (false, :stability, nothing)
return map(flat_product(As, Bs, Cs)) do (A, B, C)
(flags..., BLAS.symm!, side, ul, P(α), A, B, P(β), C)
return map_prod(As, Bs, Cs) do (A, B, C)
(false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C)
end
end...,
)
Expand Down Expand Up @@ -820,10 +819,10 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
#

# trmv!
map(flat_product(uplos, t_flags, dAs, [1, 3], Ps)) do (ul, tA, dA, N, P)
map_prod(uplos, t_flags, dAs, [1, 3], Ps) do (ul, tA, dA, N, P)
As = blas_matrices(rng, P, N, N)
bs = blas_vectors(rng, P, N)
return map(flat_product(As, bs)) do (A, b)
return map_prod(As, bs) do (A, b)
(false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b)
end
end...,
Expand All @@ -833,17 +832,16 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
#

# aliased gemm!
map(flat_product(t_flags, t_flags, Ps)) do (tA, tB, P)
map_prod(t_flags, t_flags, Ps) do (tA, tB, P)
As = blas_matrices(rng, P, 5, 5)
Bs = blas_matrices(rng, P, 5, 5)
flags = false, :none, nothing
return map(flat_product(As, Bs)) do (A, B)
(flags..., aliased_gemm!, tA, tB, randn(P), randn(P), A, B)
return map_prod(As, Bs) do (A, B)
(false, :none, nothing, aliased_gemm!, tA, tB, randn(P), randn(P), A, B)
end
end...,

# syrk!
map(flat_product(uplos, t_flags, Ps)) do (uplo, t, P)
map_prod(uplos, t_flags, Ps) do (uplo, t, P)
As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3)
C = randn(P, 3, 3)
flags = (false, :none, nothing)
Expand All @@ -853,30 +851,28 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
end...,

# trmm!
map(
flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
map_prod(
['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(flat_product(As, Bs)) do (A, B)
(flags..., BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
return map_prod(As, Bs) do (A, B)
(false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, randn(P), A, B)
end
end...,

# trsm!
map(
flat_product(['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps)
map_prod(
['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps
) do (side, ul, tA, dA, M, N, P)
t = tA == 'N'
R = side == 'L' ? M : N
As = blas_matrices(rng, P, R, R)
Bs = blas_matrices(rng, P, M, N)
flags = (false, :none, nothing)
return map(flat_product(As, Bs)) do (A, B)
(flags..., BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
return map_prod(As, Bs) do (A, B)
(false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, randn(P), A, B)
end
end...,
)
Expand Down
22 changes: 11 additions & 11 deletions src/rrules/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,38 +510,38 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
test_cases = vcat(

# getrf!
map(flat_product(bools, Ps)) do (check, P)
map_prod(bools, Ps) do (check, P)
As = blas_matrices(rng, P, 5, 5)
return map(As) do A
(false, :none, nothing, getrf_wrapper!, A, check)

Check warning on line 516 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L514-L516

Added lines #L514 - L516 were not covered by tests
end
end...,

# trtrs
map(
flat_product(['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps)
map_prod(
['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps
) do (ul, tA, diag, N, Nrhs, P)
As = blas_matrices(rng, P, N, N)
Bs = blas_matrices(rng, P, N, Nrhs)
return map(flat_product(As, Bs)) do (A, B)
return map_prod(As, Bs) do (A, B)
(false, :none, nothing, trtrs!, ul, tA, diag, A, B)

Check warning on line 527 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L524-L527

Added lines #L524 - L527 were not covered by tests
end
end...,

# getrs
map(flat_product(['N', 'T'], [1, 9], [1, 2], Ps)) do (trans, N, Nrhs, P)
map_prod(['N', 'T'], [1, 9], [1, 2], Ps) do (trans, N, Nrhs, P)
As = map(blas_matrices(rng, P, N, N)) do A
A[diagind(A)] .+= 5
return getrf!(A)

Check warning on line 535 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L533-L535

Added lines #L533 - L535 were not covered by tests
end
Bs = blas_matrices(rng, P, N, Nrhs)
return map(flat_product(As, Bs)) do ((A, ipiv), B)
return map_prod(As, Bs) do ((A, ipiv), B)
(false, :none, nothing, getrs!, trans, A, ipiv, B)

Check warning on line 539 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L537-L539

Added lines #L537 - L539 were not covered by tests
end
end...,

# getri
map(flat_product([1, 9], Ps)) do (N, P)
map_prod([1, 9], Ps) do (N, P)
As = map(blas_matrices(rng, P, N, N)) do A
A[diagind(A)] .+= 5
return getrf!(A)

Check warning on line 547 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L545-L547

Added lines #L545 - L547 were not covered by tests
Expand All @@ -552,22 +552,22 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
end...,

# potrf
map(flat_product([1, 3, 9], Ps)) do (N, P)
map_prod([1, 3, 9], Ps) do (N, P)
As = map(blas_matrices(rng, P, N, N)) do A
A .= A * A' + I
return A

Check warning on line 558 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L556-L558

Added lines #L556 - L558 were not covered by tests
end
return map(flat_product(['L', 'U'], As)) do (uplo, A)
return map_prod(['L', 'U'], As) do (uplo, A)
return (false, :none, nothing, potrf!, uplo, A)

Check warning on line 561 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L560-L561

Added lines #L560 - L561 were not covered by tests
end
end...,

# potrs
map(flat_product([1, 3, 9], [1, 2], Ps)) do (N, Nrhs, P)
map_prod([1, 3, 9], [1, 2], Ps) do (N, Nrhs, P)
X = randn(rng, P, N, N)
A = X * X' + I
Bs = blas_matrices(rng, P, N, Nrhs)
return map(flat_product(['L', 'U'], Bs)) do (uplo, B)
return map_prod(['L', 'U'], Bs) do (uplo, B)
(false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B))

Check warning on line 571 in src/rrules/lapack.jl

View check run for this annotation

Codecov / codecov/patch

src/rrules/lapack.jl#L567-L571

Added lines #L567 - L571 were not covered by tests
end
end...,
Expand Down

0 comments on commit b9c17c1

Please sign in to comment.