Skip to content

Commit

Permalink
Add getindex functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Dec 3, 2021
1 parent 7363170 commit 7b70fc8
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ include("kronecker.jl") # Kronecker product of linear maps
include("fillmap.jl") # linear maps representing constantly filled matrices
include("conversion.jl") # conversion of linear maps to matrices
include("show.jl") # show methods for LinearMap objects
include("getindex.jl") # getindex functionality

"""
LinearMap(A::LinearMap; kwargs...)::WrappedMap
Expand Down
132 changes: 132 additions & 0 deletions src/getindex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# module GetIndex

# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination,
# ScaledMap, UniformScalingMap, WrappedMap

# required in Base.to_indices for [:]-indexing
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.OneTo(length(A)))
Base.lastindex(A::LinearMap) = (Base.@_inline_meta; last(eachindex(IndexLinear(), A)))
Base.firstindex(A::LinearMap) = (Base.@_inline_meta; first(eachindex(IndexLinear(), A)))

function Base.checkbounds(A::LinearMap, i, j)
Base.@_inline_meta
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j)))
nothing
end
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index
function Base.checkbounds(A::LinearMap, i)
Base.@_inline_meta
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i))
nothing
end

# main entry point
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...)
# TODO: introduce some sort of switch?
Base.@_inline_meta
@boundscheck checkbounds(A, I...)
@inbounds _getindex(A, Base.to_indices(A, I)...)
end
# quick pass forward
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...)
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...]
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i]
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j]

########################
# linear indexing
########################
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer)
Base.@_inline_meta
i1, i2 = Base._ind2sub(axes(A), i)
return _getindex(A, i1, i2)
end
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}) =
[_getindex(A, i) for i in I]
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A))

########################
# Cartesian indexing
########################
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) =
_getindex(A, Base.Slice(axes(A)[1]), j)[i]
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer})
try
return (basevec(A, i)'A)[J]
catch
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), size(A, 1))
r = similar(x, eltype(A), length(J))
for (ind, j) in enumerate(J)
x[j] = one(eltype(A))
_unsafe_mul!(y, A, x)
r[ind] = y[i]
x[j] = zero(eltype(A))
end
return r
end
end
function _getindex(A::LinearMap, i::Integer, J::Base.Slice)
try
return vec(basevec(A, i)'A)
catch
return vec(_getindex(A, i:i, J))
end
end
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) =
_getindex(A, Base.Slice(axes(A)[1]), j)[I] # = A[:,j][I] w/o bounds check
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j)
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2})
shape = Base.index_shape(Is...)
dest = zeros(eltype(A), shape)
I, J = Is
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J))
i, j = ij
dest[ind] = _getindex(A, i, j)
end
return dest
end
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice)
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), size(A, 1))
r = similar(x, eltype(A), (length(I), size(A, 2)))
@views for j in axes(A)[2]
x[j] = one(eltype(A))
_unsafe_mul!(y, A, x)
r[:,j] .= y[I]
x[j] = zero(eltype(A))
end
return r
end
Base.@propagate_inbounds function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer})
x = zeros(eltype(A), size(A, 2))
y = similar(x, eltype(A), (size(A, 1), length(J)))
for (i, j) in enumerate(J)
x[j] = one(eltype(A))
_unsafe_mul!(selectdim(y, 2, i), A, x)
x[j] = zero(eltype(A))
end
return y
end
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A)

# specialized methods
_getindex(A::FillMap, ::Integer, ::Integer) = A.λ
Base.@propagate_inbounds _getindex(A::LinearCombination, i::Integer, j::Integer) =
sum(a -> A.maps[a][i, j], eachindex(A.maps))
Base.@propagate_inbounds _getindex(A::AdjointMap, i::Integer, j::Integer) =
adjoint(A.lmap[j, i])
Base.@propagate_inbounds _getindex(A::TransposeMap, i::Integer, j::Integer) =
transpose(A.lmap[j, i])
_getindex(A::UniformScalingMap, i::Integer, j::Integer) = ifelse(i == j, A.λ, zero(eltype(A)))

# helpers
function basevec(A, i::Integer)
x = zeros(eltype(A), size(A, 2))
@inbounds x[i] = one(eltype(A))
return x
end

nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`")

# end # module
67 changes: 67 additions & 0 deletions test/getindex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using BenchmarkTools, LinearAlgebra, LinearMaps, Test
# using LinearMaps.GetIndex

function test_getindex(A::LinearMap, M::AbstractMatrix)
@assert size(A) == size(M)
@test all((A[i,j] == M[i,j] for i in axes(A, 1), j in axes(A, 2)))
@test all((A[i] == M[i] for i in 1:length(A)))
@test A[1,1] == M[1,1]
@test A[:] == M[:]
@test A[1,:] == M[1,:]
@test A[:,1] == M[:,1]
@test A[1:4,:] == M[1:4,:]
@test A[:,1:4] == M[:,1:4]
@test A[1,1:3] == M[1,1:3]
@test A[1:3,1] == M[1:3,1]
@test A[2:end,1] == M[2:end,1]
@test A[1:2,1:3] == M[1:2,1:3]
@test A[[2,1],1:3] == M[[2,1],1:3]
@test A[:,:] == M
@test A[7] == M[7]
@test_throws BoundsError A[firstindex(A)-1]
@test_throws BoundsError A[lastindex(A)+1]
@test_throws BoundsError A[6,1]
@test_throws BoundsError A[1,6]
@test_throws BoundsError A[2,1:6]
@test_throws BoundsError A[1:6,2]
return true
end

@testset "getindex" begin
A = rand(5,5)
L = LinearMap(A)
# @btime getindex($A, i) setup=(i = rand(1:9));
# @btime getindex($L, i) setup=(i = rand(1:9));
# @btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3));
# @btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3));

struct TwoMap <: LinearMaps.LinearMap{Float64} end
Base.size(::TwoMap) = (5,5)
LinearMaps._getindex(::TwoMap, i::Integer, j::Integer) = 2.0
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x))

@test test_getindex(TwoMap(), fill(2.0, 5, 5))
Base.adjoint(A::TwoMap) = A
@test test_getindex(TwoMap(), fill(2.0, 5, 5))

MA = rand(ComplexF64, 5, 5)
for FA in (
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5),
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), 5, 5),
)
@test test_getindex(FA, MA)
@test test_getindex(3FA, 3MA)
@test test_getindex(FA + FA, 2MA)
if !isnothing(FA.fc)
@test test_getindex(transpose(FA), transpose(MA))
@test test_getindex(transpose(3FA), transpose(3MA))
@test test_getindex(3transpose(FA), transpose(3MA))
@test test_getindex(adjoint(FA), adjoint(MA))
@test test_getindex(adjoint(3FA), adjoint(3MA))
@test test_getindex(3adjoint(FA), adjoint(3MA))
end
end

@test test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5)))
@test test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5))
end
64 changes: 32 additions & 32 deletions test/linearmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
end
end

# new type
struct SimpleFunctionMap <: LinearMap{Float64}
f::Function
N::Int
end
struct SimpleComplexFunctionMap <: LinearMap{Complex{Float64}}
f::Function
N::Int
end
Base.size(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}) = (A.N, A.N)
Base.:(*)(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, v::AbstractVector) = A.f(v)
LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::AbstractVector) = copyto!(y, *(A, x))

@testset "new LinearMap type" begin
# new type
struct SimpleFunctionMap <: LinearMap{Float64}
f::Function
N::Int
end
struct SimpleComplexFunctionMap <: LinearMap{Complex{Float64}}
f::Function
N::Int
end
Base.size(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}) = (A.N, A.N)
Base.:(*)(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, v::AbstractVector) = A.f(v)
LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::AbstractVector) = copyto!(y, *(A, x))

F = SimpleFunctionMap(cumsum, 10)
FC = SimpleComplexFunctionMap(cumsum, 10)
@test @inferred ndims(F) == 2
Expand Down Expand Up @@ -83,27 +83,27 @@ LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFu
@test Fs isa SparseMatrixCSC
end

struct MyFillMap{T} <: LinearMaps.LinearMap{T}
λ::T
size::Dims{2}
function MyFillMap::T, dims::Dims{2}) where {T}
all(d -> d >= 0, dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
promote_type(T, typeof(λ)) == T || throw(InexactError())
return new{T}(λ, dims)
@testset "transpose of new LinearMap type" begin
struct MyFillMap{T} <: LinearMaps.LinearMap{T}
λ::T
size::Dims{2}
function MyFillMap::T, dims::Dims{2}) where {T}
all(d -> d >= 0, dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
promote_type(T, typeof(λ)) == T || throw(InexactError())
return new{T}(λ, dims)
end
end
Base.size(A::MyFillMap) = A.size
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
LinearMaps.check_dim_mul(y, A, x)
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
end
function LinearAlgebra.mul!(y::AbstractVecOrMat, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector)
LinearMaps.check_dim_mul(y, transA, x)
λ = transA.lmap.λ
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
end
end
Base.size(A::MyFillMap) = A.size
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
LinearMaps.check_dim_mul(y, A, x)
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
end
function LinearAlgebra.mul!(y::AbstractVecOrMat, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector)
LinearMaps.check_dim_mul(y, transA, x)
λ = transA.lmap.λ
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
end

@testset "transpose of new LinearMap type" begin
A = MyFillMap(5.0, (3, 3))
x = ones(3)
@test A * x == fill(15.0, 3)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ include("left.jl")
include("fillmap.jl")

include("nontradaxes.jl")

include("getindex.jl")

0 comments on commit 7b70fc8

Please sign in to comment.