Skip to content

Commit

Permalink
add logical and diagonal indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed May 23, 2022
1 parent 319b918 commit d946e3a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
31 changes: 31 additions & 0 deletions src/getindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

const Indexer = AbstractVector{<:Integer}

Base.IndexStyle(::LinearMap) = IndexCartesian()
# required in Base.to_indices for [:]-indexing
Base.eachindex(::IndexLinear, A::LinearMap) = Base.OneTo(length(A))
Base.lastindex(A::LinearMap) = last(eachindex(IndexLinear(), A))
Expand All @@ -19,6 +20,10 @@ function Base.checkbounds(A::LinearMap, i)
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i))
nothing
end
# checkbounds in indexing via CartesianIndex
Base.checkbounds(A::LinearMap, i::Union{CartesianIndex{2}, AbstractVecOrMat{CartesianIndex{2}}}) =
Base.checkbounds_indices(Bool, axes(A), (i,))
Base.checkbounds(A::LinearMap, I::AbstractArray{Bool,2}) = axes(A) == axes(I)

# main entry point
function Base.getindex(A::LinearMap, I...)
Expand All @@ -41,6 +46,7 @@ function _getindex(A::LinearMap, i::Integer)
end
_getindex(A::LinearMap, I::Indexer) = [_getindex(A, i) for i in I]
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A))
_getindex(A::LinearMap, I::Vector{CartesianIndex{2}}) = [(@inbounds A[i]) for i in I]

########################
# Cartesian indexing
Expand Down Expand Up @@ -130,6 +136,31 @@ end
@inline _copycol!(dest, ind, temp, I::Indexer) =
(@views @inbounds dest[:,ind] .= temp[I])

# diagonal indexing
function LinearAlgebra.diagind(A::LinearMap, k::Integer=0)
require_one_based_indexing(A)
diagind(size(A,1), size(A,2), k)
end

LinearAlgebra.diag(A::LinearMap, k::Integer=0) = A[diagind(A,k)]

# logical indexing
Base.getindex(A::LinearMap, mask::AbstractVecOrMat{Bool}) = A[findall(mask)]
Base.getindex(A::LinearMap, i, mask::AbstractVector{Bool}) = A[i, findall(mask)]
Base.getindex(A::LinearMap, mask::AbstractVector{Bool}, j) = A[findall(mask), j]
Base.getindex(A::LinearMap, im::AbstractVector{Bool}, jm::AbstractVector{Bool}) =
A[findall(im), findall(jm)]
# disambiguation
for typ in (:WrappedMap, :ScaledMap)
@eval begin
Base.getindex(A::$typ, mask::AbstractVecOrMat{Bool}) = A[findall(mask)]
Base.getindex(A::$typ, i, mask::AbstractVector{Bool}) = A[i, findall(mask)]
Base.getindex(A::$typ, mask::AbstractVector{Bool}, j) = A[findall(mask), j]
Base.getindex(A::$typ, im::AbstractVector{Bool}, jm::AbstractVector{Bool}) =
A[findall(im), findall(jm)]
end
end

# nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`")

# end # module
11 changes: 8 additions & 3 deletions test/getindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using LinearAlgebra, LinearMaps, Test

function test_getindex(A::LinearMap, M::AbstractMatrix)
@assert size(A) == size(M)
mask = rand(Bool, size(A))
@test all((A[i,j] == M[i,j] for i in axes(A, 1), j in axes(A, 2)))
@test all((A[i] == M[i] for i in 1:length(A)))
@test A[1,1] == M[1,1]
Expand All @@ -20,6 +21,10 @@ function test_getindex(A::LinearMap, M::AbstractMatrix)
@test A[:,:] == M
@test A[7] == M[7]
@test A[3:7] == M[3:7]
@test (lastindex(A, 1), lastindex(A, 2)) == size(A)
@test diagind(A) == diagind(M)
for k in -1:1; @test diag(A, k) == diag(M, k) end
@test A[mask] == M[mask]
@test_throws BoundsError A[firstindex(A)-1]
@test_throws BoundsError A[lastindex(A)+1]
@test_throws BoundsError A[6,1]
Expand All @@ -30,9 +35,9 @@ function test_getindex(A::LinearMap, M::AbstractMatrix)
end

@testset "getindex" begin
A = rand(4,6)
L = LinearMap(A)
@test test_getindex(L, A)
M = rand(4,6)
A = LinearMap(M)
@test test_getindex(A, M)
# @btime getindex($A, i) setup=(i = rand(1:24));
# @btime getindex($L, i) setup=(i = rand(1:24));
# @btime (getindex($A, i, j)) setup=(i = rand(1:4); j = rand(1:6));
Expand Down

0 comments on commit d946e3a

Please sign in to comment.