Skip to content

Commit

Permalink
Add getindex functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Dec 2, 2021
1 parent e05375f commit 7e4d6e2
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ include("kronecker.jl") # Kronecker product of linear maps
include("fillmap.jl") # linear maps representing constantly filled matrices
include("conversion.jl") # conversion of linear maps to matrices
include("show.jl") # show methods for LinearMap objects
include("getindex.jl") # getindex functionality

"""
LinearMap(A::LinearMap; kwargs...)::WrappedMap
Expand Down
161 changes: 161 additions & 0 deletions src/getindex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# module GetIndex

# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination,
# ScaledMap, UniformScalingMap, WrappedMap

# required in Base.to_indices for [:]-indexing
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.oneto(length(A)))
# Base.IndexStyle(::LinearMap) = IndexCartesian()
# Base.IndexStyle(A::Union{WrappedMap,AdjointMap,TransposeMap,ScaledMap}) = IndexStyle(A.lmap)

function Base.checkbounds(A::LinearMap, i, j)
Base.@_inline_meta
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j)))
nothing
end
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index
function Base.checkbounds(A::LinearMap, i)
Base.@_inline_meta
Base.checkindex(Bool, Base.oneto(length(A)), i) || throw(BoundsError(A, i))
nothing
end

# dispatch hierarchy
# Base.getindex (includes bounds checking)
# -> Base._getindex (conversion of linear indices to cartesian indices)
# -> _unsafe_getindex
# main entry point
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...)
# TODO: introduce some sort of switch?
Base.@_inline_meta
@boundscheck checkbounds(A, I...)
_getindex(A, Base.to_indices(A, I)...)
end
# quick pass forward
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...)
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer) =
adjoint(A.lmap[i-1+first(axes(A.lmap)[1])])
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer, j::Integer) =
adjoint(A.lmap[j, i])
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer) =
transpose(A.lmap[i-1+first(axes(A.lmap)[1])])
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer, j::Integer) =
transpose(A.lmap[j, i])
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...]
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i]
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j]

# Base._getindex, IndexLinear
# Base.@propagate_inbounds Base._getindex(::IndexLinear, A::LinearMap, i::Integer) = _unsafe_getindex(A, i)
# Base.@propagate_inbounds function Base._getindex(::IndexLinear, A::LinearMap, i::Integer, j::Integer)
# Base.@_inline_meta
# # @boundscheck checkbounds(A, i, j)
# return _unsafe_getindex(A, Base._sub2ind(axes(A), i, j))
# end
# Base._getindex, IndexCartesian
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer)
Base.@_inline_meta
@boundscheck checkbounds(A, i)
i1, i2 = Base._ind2sub(axes(A), i)
@inbounds r = _unsafe_getindex(A, i1, i2)
return r
end
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) =
_unsafe_getindex(A, i, j)

########################
# scalar indexing
########################
# fallback via colon-based method
Base.@propagate_inbounds _unsafe_getindex(A::LinearMap, i::Integer, j::Integer) =
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[i])
# specialized methods
_unsafe_getindex(A::FillMap, ::Integer, ::Integer) = A.λ
Base.@propagate_inbounds _unsafe_getindex(A::LinearCombination, i::Integer, j::Integer) =
sum(a -> getindex(A.maps[a], i, j), eachindex(A.maps))
_unsafe_getindex(A::UniformScalingMap, i::Integer, j::Integer) =
ifelse(i == j, A.λ, zero(eltype(A)))

########################
# multidimensional slicing
########################
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer})
try
return (basevec(A, i)'A)[J]
catch
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), size(A, 1))
r = similar(x, eltype(A), length(J))
@inbounds for (ind, j) in enumerate(J)
x[j] = one(eltype(A))
_unsafe_mul!(y, A, x)
r[ind] = y[i]
x[j] = zero(eltype(A))
end
return r
end
end
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) =
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[I])
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2})
shape = Base.index_shape(Is...)
dest = zeros(eltype(A), shape)
I, J = Is
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J))
i, j = ij
dest[ind] = _unsafe_getindex(A, i, j)
end
return dest
end
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer})
dest = Vector{eltype(A)}(undef, length(I))
for i in eachindex(dest, I)
dest[i] = _getindex(A, I[i])
end
return dest
end
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A)
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A))
function _getindex(A::LinearMap, i::Integer, J::Base.Slice)
try
return vec(basevec(A, i)'A)
catch
return vec(_getindex(A, i:i, J))
end
end
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j)
# Needs to be defined for custom LinearMap subtypes
# Base.@propagate_inbounds function _unsafe_getindex(A::CustomMap, i::Union{Integer,AbstractVector{<:Integer}})
function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice)
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), size(A, 1))
r = similar(x, eltype(A), (length(I), size(A, 2)))
@views @inbounds for j in axes(A)[2]
x[j] = one(eltype(A))
_unsafe_mul!(y, A, x)
r[:,j] .= y[I]
x[j] = zero(eltype(A))
end
return r
end
function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer})
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), (size(A, 1), length(J)))
@inbounds for (i, j) in enumerate(J)
x[j] = one(eltype(A))
_unsafe_mul!(selectdim(y, 2, i), A, x)
x[j] = zero(eltype(A))
end
return y
end

# helpers
function basevec(A, i::Integer)
x = zeros(eltype(A), size(A, 2))
@inbounds x[i] = one(eltype(A))
return x
end

nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`")

# end # module
73 changes: 73 additions & 0 deletions test/getindex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using BenchmarkTools, LinearAlgebra, LinearMaps, Test
# using LinearMaps.GetIndex

struct TwoMap <: LinearMaps.LinearMap{Float64} end
Base.size(::TwoMap) = (5,5)
Base.IndexStyle(::TwoMap) = IndexLinear()
LinearMaps._unsafe_getindex(::TwoMap, i::Integer) = 2.0
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x))

@testset "getindex" begin
A = rand(3,3)
L = LinearMap(A)
@test all((L[i,j] == A[i,j] for i in 1:3, j in 1:3))
@test all((L[i] == A[i] for i in 1:9))
@test L[1,:] == A[1,:]
@btime getindex($A, i) setup=(i = rand(1:9))
@btime getindex($L, i) setup=(i = rand(1:9))
@btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3))
@btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3))

@testset "minifillmap" begin
T = TwoMap()
@test T[1,1] == 2.0
@test T[:,1] == fill(2.0, 5)
@test T[1,:] == fill(2.0, 5)
@test T[2:3,:] == fill(2.0, 2, 5)
@test T[:,2:3] == fill(2.0, 5, 2)
@test T[2:3,3] == fill(2.0, 2)
@test T[2,2:3] == fill(2.0, 2)
@test_throws BoundsError T[6,1]
@test_throws BoundsError T[1,6]
@test_throws BoundsError T[2,1:6]
@test_throws BoundsError T[1:6,2]
@test_throws BoundsError T[0]
@test_throws BoundsError T[26]

Base.adjoint(A::TwoMap) = A
@test T[1,1] == 2.0
@test T[:,1] == fill(2.0, 5)
@test T[1,:] == fill(2.0, 5)
@test T[2:3,:] == fill(2.0, 2, 5)
@test T[:,2:3] == fill(2.0, 5, 2)
@test T[2:3,3] == fill(2.0, 2)
@test T[2,2:3] == fill(2.0, 2)
@test_throws BoundsError T[6,1]
@test_throws BoundsError T[1,6]
@test_throws BoundsError T[2,1:6]
@test_throws BoundsError T[1:6,2]
@test_throws BoundsError T[0]
@test_throws BoundsError T[26]
end

@testset "function wrap around matrix" begin
MA = rand(ComplexF64, 5, 5)
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5)
for transform in (identity, transpose, adjoint), (A, F) in ((MA, FA), (3MA, 3FA))
@test transform(F)[1,1] transform(A)[1,1]
@test transform(F)[:] transform(A)[:]
@test transform(F)[1,:] transform(A)[1,:]
@test transform(F)[:,1] transform(A)[:,1]
@test transform(F)[1:4,:] transform(A)[1:4,:]
@test transform(F)[:,1:4] transform(A)[:,1:4]
@test transform(F)[1,1:3] transform(A)[1,1:3]
@test transform(F)[1:3,1] transform(A)[1:3,1]
@test transform(F)[1:2,1:3] transform(A)[1:2,1:3]
@test transform(F)[[2,1],1:3] transform(A)[[2,1],1:3]
@test transform(F)[:,:] transform(A)
@test transform(F)[7] transform(A)[7]
@test_throws BoundsError transform(F)[0]
@test_throws BoundsError transform(F)[26]
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ include("fillmap.jl")
if VERSION v"1.1"
include("nontradaxes.jl")
end

include("getindex.jl")

0 comments on commit 7e4d6e2

Please sign in to comment.