Skip to content

Commit

Permalink
cleanup and add inverse benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
joshday committed Sep 10, 2024
1 parent c7f7424 commit 65375a5
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
matrix:
version:
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'lts'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -26,7 +27,7 @@ jobs:
arch: x86
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
Expand Down
24 changes: 17 additions & 7 deletions src/SweepOperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import LinearAlgebra: BlasFloat, checksquare
const AMat = AbstractMatrix
const AVec = AbstractVector

#-----------------------------------------------------------------------------# sweep!
"""
sweep!(A, k ; inv=false)
sweep!(A, ks; inv=false)
Expand All @@ -22,17 +23,19 @@ Perform the sweep operation (or inverse sweep if `inv=true`) on matrix `A` on el
sweep!(xtx, 1, true)
"""
function sweep!(A::AMat, k::Integer, inv::Bool = false)
sweep_with_buffer!(Vector{eltype(A)}(undef, size(A, 2)), A, k, inv)
akk = Vector{eltype(A)}(undef, size(A, 2))
sweep_with_buffer!(akk, A, k, inv)
end

function sweep!(A::AMat{T}, ks::AVec{I}, inv::Bool = false) where {T<:BlasFloat, I<:Integer}
function sweep!(A::AMat{T}, ks::AVec{I}, inv::Bool = false) where {T <: BlasFloat, I <: Integer}
akk = Vector{T}(undef, size(A,1))
for k in ks
sweep_with_buffer!(akk, A, k, inv)
end
A
end

#-----------------------------------------------------------------------------# sweep_with_buffer!
function sweep_with_buffer!(akk::AVec{T}, A::AMat{T}, k::Integer, inv::Bool = false) where {T}
# ensure @inbounds is safe
p = checksquare(A)
Expand All @@ -49,6 +52,14 @@ function sweep_with_buffer!(akk::AVec{T}, A::AMat{T}, k::Integer, inv::Bool = fa
return A
end

# This is slower than unbuffered version??
function sweep_with_buffer!(akk::AVec{T}, A::AMat{T}, ks::AVec{I}, inv::Bool = false) where {T, I <: Integer}
for k in ks
sweep_with_buffer!(akk, A, k, inv)
end
A
end

#-----------------------------------------------------------------------------# setrowcol!
# Set upper triangle of: (A[k, :] = x; A[:, k] = x)
function setrowcol!(A::StridedArray, k, x)
Expand All @@ -62,24 +73,23 @@ setrowcol!(A::Union{Hermitian,Symmetric,UpperTriangular}, k, x) = setrowcol!(A.d
const BlasNumber = Union{LinearAlgebra.BlasFloat, LinearAlgebra.BlasComplex}

# In-place update of (the upper triangle of) A + α * x * x'
function syrk!(A::StridedMatrix{T}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber}
function syrk!(A::StridedMatrix{T}, α::T, x::AbstractArray{<:T}) where {T <: BlasNumber}
BLAS.syrk!('U', 'N', α, x, one(T), A)
end

function syrk!(A::Hermitian{T, S}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber, S<:StridedMatrix{T}}
function syrk!(A::Hermitian{T, S}, α::T, x::AbstractArray{<:T}) where {T <: BlasNumber, S <: StridedMatrix{T}}
Hermitian(BLAS.syrk!('U', 'N', α, x, one(T), A.data))
end

function syrk!(A::Symmetric{T, S}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber, S<:StridedMatrix{T}}
function syrk!(A::Symmetric{T, S}, α::T, x::AbstractArray{<:T}) where {T <: BlasNumber, S <: StridedMatrix{T}}
Symmetric(BLAS.syrk!('U', 'N', α, x, one(T), A.data))
end

function syrk!(A, α, x) # fallback
p = checksquare(A)
for i in 1:p, j in i:p
@inbounds A[i,j] += α * x[i] * x[j]
@inbounds A[i, j] += α * x[i] * x[j]
end
end


end # module
83 changes: 83 additions & 0 deletions test/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using Test, LinearAlgebra, Random, BenchmarkTools, SweepOperator

versioninfo()


#-----------------------------------------------------------------------------# 1) Matrix Inverse
# Code from https://github.com/joshday/SweepOperator.jl/issues/9


"""
inv_by_chol!(A)
Invert a pd matrix `A` in-place by Cholesky decomposition.
"""
function inv_by_chol!(A::Matrix{T}) where T <: LinearAlgebra.BlasReal
LAPACK.potrf!('U', A)
LAPACK.potri!('U', A)
A
end

"""
sweep_block_kernel!(A, krange, invsw)
Perform the block form of sweep on a contiguous range of indices `krange`.
[ Ak̲k̲ Ak̲k Ak̲k̅ ] [ Ak̲k̲-Ak̲k*Akk⁻¹*Akk̲ Ak̲k*Akk⁻¹ Ak̲k̅-Ak̲k*Akk⁻¹*Akk̅ ]
[ - Akk Akk̅ ] -sweep-> [ - -Akk⁻¹ Akk⁻¹*Akk̅ ]
[ - - Ak̅k̅ ] [ - - Ak̅k̅-Ak̅k*Akk⁻¹*Akk̅ ]
"""
function sweep_block_kernel!(
A :: AbstractMatrix{T},
krange :: AbstractUnitRange{<:Integer},
invsw :: Bool = false
) where {T <: LinearAlgebra.BlasReal}
k̲range = 1:(krange[1] - 1)
k̅range = (krange[end]+1):size(A, 2)
Akk = view(A, krange, krange)
Ak̲k = view(A, k̲range, krange)
Ak̲k̲ = view(A, k̲range, k̲range)
Akk̅ = view(A, krange, k̅range)
Ak̅k̅ = view(A, k̅range, k̅range)
Ak̲k̅ = view(A, k̲range, k̅range)
# U = cholesky(Akk).U
U, _ = LAPACK.potrf!('U', Akk)
# Ak̲k = Ak̲k * U⁻¹
BLAS.trsm!('R', 'U', 'N', 'N', one(T), U, Ak̲k)
# Ak̲k̲ = Ak̲k̲ - Ak̲k * U⁻¹ * U⁻ᵀ * Akk̲
BLAS.syrk!('U', 'N', -one(T), Ak̲k, one(T), Ak̲k̲)
# Akk̅ = U⁻ᵀ * Akk̅
BLAS.trsm!('L', 'U', 'T', 'N', one(T), U, Akk̅)
# Ak̲k̅ = Ak̲k̅ - Ak̲k * U⁻¹ * U⁻ᵀ * Akk̅
BLAS.gemm!('N', 'N', -one(T), Ak̲k, Akk̅, one(T), Ak̲k̅)
# Ak̅k̅ = Ak̅k̅ - Ak̅k * U⁻¹ * U⁻ᵀ * Akk̅
BLAS.syrk!('U', 'T', -one(T), Akk̅, one(T), Ak̅k̅)
# Ak̲k = Ak̲k * Akk⁻¹ = Ak̲k * U⁻¹ * U⁻ᵀ
s = ifelse(invsw, -one(T), one(T))
BLAS.trsm!('R', 'U', 'T', 'N', s, Akk, Ak̲k)
# Akk̅ = Akk⁻¹ * Akk̅ = U⁻¹ * U⁻ᵀ * Akk̅
BLAS.trsm!('L', 'U', 'N', 'N', s, Akk, Akk̅)
# Akk = Akk⁻¹ = U⁻ᵀ U⁻¹
LAPACK.potri!('U', U)
UpperTriangular(Akk) .*= -1
A
end

# create an nxn pos-def test matrix
function run_benchmark(n::Int, seed::Int = 123)
Random.seed!(seed)
A = randn(n, n)
A = A'A + I
Ainv = UpperTriangular(inv(A))
@test UpperTriangular(inv_by_chol!(copy(A))) Ainv
@test -UpperTriangular(sweep!(copy(A), 1:n)) Ainv
@test -UpperTriangular(sweep_block!(copy(A), 1:n)) Ainv

out = Dict()

out["Cholesky"] = @benchmark inv_by_chol!(copy(A))
out["Sweep"] = @benchmark sweep!(copy(A), 1:n)
out["Block Sweep"] = @benchmark sweep_block!(copy(A), 1:n)

return out
end

0 comments on commit 65375a5

Please sign in to comment.