-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
234 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# module GetIndex | ||
|
||
# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination, | ||
# ScaledMap, UniformScalingMap, WrappedMap | ||
|
||
# required in Base.to_indices for [:]-indexing | ||
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.OneTo(length(A))) | ||
Base.lastindex(A::LinearMap) = (Base.@_inline_meta; last(eachindex(IndexLinear(), A))) | ||
Base.firstindex(A::LinearMap) = (Base.@_inline_meta; first(eachindex(IndexLinear(), A))) | ||
|
||
function Base.checkbounds(A::LinearMap, i, j) | ||
Base.@_inline_meta | ||
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j))) | ||
nothing | ||
end | ||
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index | ||
function Base.checkbounds(A::LinearMap, i) | ||
Base.@_inline_meta | ||
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i)) | ||
nothing | ||
end | ||
|
||
# main entry point | ||
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...) | ||
# TODO: introduce some sort of switch? | ||
Base.@_inline_meta | ||
@boundscheck checkbounds(A, I...) | ||
@inbounds _getindex(A, Base.to_indices(A, I)...) | ||
end | ||
# quick pass forward | ||
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...) | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j] | ||
|
||
######################## | ||
# linear indexing | ||
######################## | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer) | ||
Base.@_inline_meta | ||
i1, i2 = Base._ind2sub(axes(A), i) | ||
return _getindex(A, i1, i2) | ||
end | ||
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}) = | ||
[_getindex(A, i) for i in I] | ||
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A)) | ||
|
||
######################## | ||
# Cartesian indexing | ||
######################## | ||
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) = | ||
_getindex(A, Base.Slice(axes(A)[1]), j)[i] | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer}) | ||
try | ||
return (basevec(A, i)'A)[J] | ||
catch | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), size(A, 1)) | ||
r = similar(x, eltype(A), length(J)) | ||
for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(y, A, x) | ||
r[ind] = y[i] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return r | ||
end | ||
end | ||
function _getindex(A::LinearMap, i::Integer, J::Base.Slice) | ||
try | ||
return vec(basevec(A, i)'A) | ||
catch | ||
return vec(_getindex(A, i:i, J)) | ||
end | ||
end | ||
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) = | ||
_getindex(A, Base.Slice(axes(A)[1]), j)[I] # = A[:,j][I] w/o bounds check | ||
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j) | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2}) | ||
shape = Base.index_shape(Is...) | ||
dest = zeros(eltype(A), shape) | ||
I, J = Is | ||
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J)) | ||
i, j = ij | ||
dest[ind] = _getindex(A, i, j) | ||
end | ||
return dest | ||
end | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice) | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), size(A, 1)) | ||
r = similar(x, eltype(A), (length(I), size(A, 2))) | ||
@views for j in axes(A)[2] | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(y, A, x) | ||
r[:,j] .= y[I] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return r | ||
end | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer}) | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), (size(A, 1), length(J))) | ||
for (i, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(selectdim(y, 2, i), A, x) | ||
x[j] = zero(eltype(A)) | ||
end | ||
return y | ||
end | ||
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A) | ||
|
||
# specialized methods | ||
_getindex(A::FillMap, ::Integer, ::Integer) = A.λ | ||
Base.@propagate_inbounds _getindex(A::LinearCombination, i::Integer, j::Integer) = | ||
sum(a -> A.maps[a][i, j], eachindex(A.maps)) | ||
Base.@propagate_inbounds _getindex(A::AdjointMap, i::Integer, j::Integer) = | ||
adjoint(A.lmap[j, i]) | ||
Base.@propagate_inbounds _getindex(A::TransposeMap, i::Integer, j::Integer) = | ||
transpose(A.lmap[j, i]) | ||
_getindex(A::UniformScalingMap, i::Integer, j::Integer) = ifelse(i == j, A.λ, zero(eltype(A))) | ||
|
||
# helpers | ||
function basevec(A, i::Integer) | ||
x = zeros(eltype(A), size(A, 2)) | ||
@inbounds x[i] = one(eltype(A)) | ||
return x | ||
end | ||
|
||
nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`") | ||
|
||
# end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
using BenchmarkTools, LinearAlgebra, LinearMaps, Test | ||
# using LinearMaps.GetIndex | ||
|
||
function test_getindex(A::LinearMap, M::AbstractMatrix) | ||
@assert size(A) == size(M) | ||
@test all((A[i,j] == M[i,j] for i in axes(A, 1), j in axes(A, 2))) | ||
@test all((A[i] == M[i] for i in 1:length(A))) | ||
@test A[1,1] == M[1,1] | ||
@test A[:] == M[:] | ||
@test A[1,:] == M[1,:] | ||
@test A[:,1] == M[:,1] | ||
@test A[1:4,:] == M[1:4,:] | ||
@test A[:,1:4] == M[:,1:4] | ||
@test A[1,1:3] == M[1,1:3] | ||
@test A[1:3,1] == M[1:3,1] | ||
@test A[2:end,1] == M[2:end,1] | ||
@test A[1:2,1:3] == M[1:2,1:3] | ||
@test A[[2,1],1:3] == M[[2,1],1:3] | ||
@test A[:,:] == M | ||
@test A[7] == M[7] | ||
@test_throws BoundsError A[firstindex(A)-1] | ||
@test_throws BoundsError A[lastindex(A)+1] | ||
@test_throws BoundsError A[6,1] | ||
@test_throws BoundsError A[1,6] | ||
@test_throws BoundsError A[2,1:6] | ||
@test_throws BoundsError A[1:6,2] | ||
return true | ||
end | ||
|
||
@testset "getindex" begin | ||
A = rand(5,5) | ||
L = LinearMap(A) | ||
# @btime getindex($A, i) setup=(i = rand(1:9)); | ||
# @btime getindex($L, i) setup=(i = rand(1:9)); | ||
# @btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3)); | ||
# @btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3)); | ||
|
||
struct TwoMap <: LinearMaps.LinearMap{Float64} end | ||
Base.size(::TwoMap) = (5,5) | ||
LinearMaps._getindex(::TwoMap, i::Integer, j::Integer) = 2.0 | ||
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x)) | ||
|
||
@test test_getindex(TwoMap(), fill(2.0, 5, 5)) | ||
Base.adjoint(A::TwoMap) = A | ||
@test test_getindex(TwoMap(), fill(2.0, 5, 5)) | ||
|
||
MA = rand(ComplexF64, 5, 5) | ||
for FA in ( | ||
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5), | ||
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), 5, 5), | ||
) | ||
@test test_getindex(FA, MA) | ||
@test test_getindex(3FA, 3MA) | ||
@test test_getindex(FA + FA, 2MA) | ||
if !isnothing(FA.fc) | ||
@test test_getindex(transpose(FA), transpose(MA)) | ||
@test test_getindex(transpose(3FA), transpose(3MA)) | ||
@test test_getindex(3transpose(FA), transpose(3MA)) | ||
@test test_getindex(adjoint(FA), adjoint(MA)) | ||
@test test_getindex(adjoint(3FA), adjoint(3MA)) | ||
@test test_getindex(3adjoint(FA), adjoint(3MA)) | ||
end | ||
end | ||
|
||
@test test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5))) | ||
@test test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,5 @@ include("left.jl") | |
include("fillmap.jl") | ||
|
||
include("nontradaxes.jl") | ||
|
||
include("getindex.jl") |