Skip to content

Commit

Permalink
Generalize istriu/istril to accept a band index (#590)
Browse files Browse the repository at this point in the history
Currently, only `istriu(S)` and `istril(S)` are specialized for sparse matrices, and this PR generalizes these to accept the band index `k`. This improves performance.
  • Loading branch information
jishnub authored Dec 5, 2024
1 parent 780c4de commit 4fd3aad
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4141,7 +4141,7 @@ function is_hermsym(A::AbstractSparseMatrixCSC, check::Function)
return true
end

function istriu(A::AbstractSparseMatrixCSC)
function istriu(A::AbstractSparseMatrixCSC, k::Integer=0)
m, n = size(A)
colptr = getcolptr(A)
rowval = rowvals(A)
Expand All @@ -4150,7 +4150,8 @@ function istriu(A::AbstractSparseMatrixCSC)
for col = 1:min(n, m-1)
l1 = colptr[col+1]-1
for i = 0 : (l1 - colptr[col])
if rowval[l1-i] <= col
if rowval[l1-i] <= col - k
# rows preceeding the index would also lie above the band
break
end
if _isnotzero(nzval[l1-i])
Expand All @@ -4161,15 +4162,16 @@ function istriu(A::AbstractSparseMatrixCSC)
return true
end

function istril(A::AbstractSparseMatrixCSC)
function istril(A::AbstractSparseMatrixCSC, k::Integer=0)
m, n = size(A)
colptr = getcolptr(A)
rowval = rowvals(A)
nzval = nonzeros(A)

for col = 2:n
for i = colptr[col] : (colptr[col+1]-1)
if rowval[i] >= col
if rowval[i] >= col - k
# subsequent rows would also lie below the band
break
end
if _isnotzero(nzval[i])
Expand Down
13 changes: 13 additions & 0 deletions test/sparsematrix_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,4 +626,17 @@ end
@test_throws ArgumentError copytrito!(M, S, 'M')
end

@testset "istriu/istril" begin
for T in Any[Tridiagonal(1:3, 1:4, 1:3),
Bidiagonal(1:4, 1:3, :U), Bidiagonal(1:4, 1:3, :L),
Diagonal(1:4),
diagm(-2=>1:2, 2=>1:2)]
S = sparse(T)
for k in -5:5
@test istriu(S, k) == istriu(T, k)
@test istril(S, k) == istril(T, k)
end
end
end

end # module

0 comments on commit 4fd3aad

Please sign in to comment.