-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
237 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# module GetIndex | ||
|
||
# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination, | ||
# ScaledMap, UniformScalingMap, WrappedMap | ||
|
||
# required in Base.to_indices for [:]-indexing | ||
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.oneto(length(A))) | ||
# Base.IndexStyle(::LinearMap) = IndexCartesian() | ||
# Base.IndexStyle(A::Union{WrappedMap,AdjointMap,TransposeMap,ScaledMap}) = IndexStyle(A.lmap) | ||
|
||
function Base.checkbounds(A::LinearMap, i, j) | ||
Base.@_inline_meta | ||
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j))) | ||
nothing | ||
end | ||
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index | ||
function Base.checkbounds(A::LinearMap, i) | ||
Base.@_inline_meta | ||
Base.checkindex(Bool, Base.oneto(length(A)), i) || throw(BoundsError(A, i)) | ||
nothing | ||
end | ||
|
||
# dispatch hierarchy | ||
# Base.getindex (includes bounds checking) | ||
# -> Base._getindex (conversion of linear indices to cartesian indices) | ||
# -> _unsafe_getindex | ||
# main entry point | ||
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...) | ||
# TODO: introduce some sort of switch? | ||
Base.@_inline_meta | ||
@boundscheck checkbounds(A, I...) | ||
_getindex(A, Base.to_indices(A, I)...) | ||
end | ||
# quick pass forward | ||
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...) | ||
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer) = | ||
adjoint(A.lmap[i-1+first(axes(A.lmap)[1])]) | ||
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer, j::Integer) = | ||
adjoint(A.lmap[j, i]) | ||
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer) = | ||
transpose(A.lmap[i-1+first(axes(A.lmap)[1])]) | ||
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer, j::Integer) = | ||
transpose(A.lmap[j, i]) | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j] | ||
|
||
# Base._getindex, IndexLinear | ||
# Base.@propagate_inbounds Base._getindex(::IndexLinear, A::LinearMap, i::Integer) = _unsafe_getindex(A, i) | ||
# Base.@propagate_inbounds function Base._getindex(::IndexLinear, A::LinearMap, i::Integer, j::Integer) | ||
# Base.@_inline_meta | ||
# # @boundscheck checkbounds(A, i, j) | ||
# return _unsafe_getindex(A, Base._sub2ind(axes(A), i, j)) | ||
# end | ||
# Base._getindex, IndexCartesian | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer) | ||
Base.@_inline_meta | ||
@boundscheck checkbounds(A, i) | ||
i1, i2 = Base._ind2sub(axes(A), i) | ||
@inbounds r = _unsafe_getindex(A, i1, i2) | ||
return r | ||
end | ||
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) = | ||
_unsafe_getindex(A, i, j) | ||
|
||
######################## | ||
# scalar indexing | ||
######################## | ||
# fallback via colon-based method | ||
Base.@propagate_inbounds _unsafe_getindex(A::LinearMap, i::Integer, j::Integer) = | ||
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[i]) | ||
# specialized methods | ||
_unsafe_getindex(A::FillMap, ::Integer, ::Integer) = A.λ | ||
Base.@propagate_inbounds _unsafe_getindex(A::LinearCombination, i::Integer, j::Integer) = | ||
sum(a -> getindex(A.maps[a], i, j), eachindex(A.maps)) | ||
_unsafe_getindex(A::UniformScalingMap, i::Integer, j::Integer) = | ||
ifelse(i == j, A.λ, zero(eltype(A))) | ||
|
||
######################## | ||
# multidimensional slicing | ||
######################## | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer}) | ||
try | ||
return (basevec(A, i)'A)[J] | ||
catch | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), size(A, 1)) | ||
r = similar(x, eltype(A), length(J)) | ||
@inbounds for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(y, A, x) | ||
r[ind] = y[i] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return r | ||
end | ||
end | ||
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) = | ||
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[I]) | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2}) | ||
shape = Base.index_shape(Is...) | ||
dest = zeros(eltype(A), shape) | ||
I, J = Is | ||
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J)) | ||
i, j = ij | ||
dest[ind] = _unsafe_getindex(A, i, j) | ||
end | ||
return dest | ||
end | ||
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer}) | ||
dest = Vector{eltype(A)}(undef, length(I)) | ||
for i in eachindex(dest, I) | ||
dest[i] = _getindex(A, I[i]) | ||
end | ||
return dest | ||
end | ||
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A) | ||
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A)) | ||
function _getindex(A::LinearMap, i::Integer, J::Base.Slice) | ||
try | ||
return vec(basevec(A, i)'A) | ||
catch | ||
return vec(_getindex(A, i:i, J)) | ||
end | ||
end | ||
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j) | ||
# Needs to be defined for custom LinearMap subtypes | ||
# Base.@propagate_inbounds function _unsafe_getindex(A::CustomMap, i::Union{Integer,AbstractVector{<:Integer}}) | ||
function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice) | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), size(A, 1)) | ||
r = similar(x, eltype(A), (length(I), size(A, 2))) | ||
@views @inbounds for j in axes(A)[2] | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(y, A, x) | ||
r[:,j] .= y[I] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return r | ||
end | ||
function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer}) | ||
x = zeros(eltype(A), size(A, 2)) | ||
y = similar(x, eltype(A), (size(A, 1), length(J))) | ||
@inbounds for (i, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(selectdim(y, 2, i), A, x) | ||
x[j] = zero(eltype(A)) | ||
end | ||
return y | ||
end | ||
|
||
# helpers | ||
function basevec(A, i::Integer) | ||
x = zeros(eltype(A), size(A, 2)) | ||
@inbounds x[i] = one(eltype(A)) | ||
return x | ||
end | ||
|
||
nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`") | ||
|
||
# end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
using BenchmarkTools, LinearAlgebra, LinearMaps, Test | ||
# using LinearMaps.GetIndex | ||
|
||
struct TwoMap <: LinearMaps.LinearMap{Float64} end | ||
Base.size(::TwoMap) = (5,5) | ||
Base.IndexStyle(::TwoMap) = IndexLinear() | ||
LinearMaps._unsafe_getindex(::TwoMap, i::Integer) = 2.0 | ||
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x)) | ||
|
||
@testset "getindex" begin | ||
A = rand(3,3) | ||
L = LinearMap(A) | ||
@test all((L[i,j] == A[i,j] for i in 1:3, j in 1:3)) | ||
@test all((L[i] == A[i] for i in 1:9)) | ||
@test L[1,:] == A[1,:] | ||
@btime getindex($A, i) setup=(i = rand(1:9)) | ||
@btime getindex($L, i) setup=(i = rand(1:9)) | ||
@btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3)) | ||
@btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3)) | ||
|
||
@testset "minifillmap" begin | ||
T = TwoMap() | ||
@test T[1,1] == 2.0 | ||
@test T[:,1] == fill(2.0, 5) | ||
@test T[1,:] == fill(2.0, 5) | ||
@test T[2:3,:] == fill(2.0, 2, 5) | ||
@test T[:,2:3] == fill(2.0, 5, 2) | ||
@test T[2:3,3] == fill(2.0, 2) | ||
@test T[2,2:3] == fill(2.0, 2) | ||
@test_throws BoundsError T[6,1] | ||
@test_throws BoundsError T[1,6] | ||
@test_throws BoundsError T[2,1:6] | ||
@test_throws BoundsError T[1:6,2] | ||
@test_throws BoundsError T[0] | ||
@test_throws BoundsError T[26] | ||
|
||
Base.adjoint(A::TwoMap) = A | ||
@test T[1,1] == 2.0 | ||
@test T[:,1] == fill(2.0, 5) | ||
@test T[1,:] == fill(2.0, 5) | ||
@test T[2:3,:] == fill(2.0, 2, 5) | ||
@test T[:,2:3] == fill(2.0, 5, 2) | ||
@test T[2:3,3] == fill(2.0, 2) | ||
@test T[2,2:3] == fill(2.0, 2) | ||
@test_throws BoundsError T[6,1] | ||
@test_throws BoundsError T[1,6] | ||
@test_throws BoundsError T[2,1:6] | ||
@test_throws BoundsError T[1:6,2] | ||
@test_throws BoundsError T[0] | ||
@test_throws BoundsError T[26] | ||
end | ||
|
||
@testset "function wrap around matrix" begin | ||
MA = rand(ComplexF64, 5, 5) | ||
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5) | ||
for transform in (identity, transpose, adjoint), (A, F) in ((MA, FA), (3MA, 3FA)) | ||
@test transform(F)[1,1] ≈ transform(A)[1,1] | ||
@test transform(F)[:] ≈ transform(A)[:] | ||
@test transform(F)[1,:] ≈ transform(A)[1,:] | ||
@test transform(F)[:,1] ≈ transform(A)[:,1] | ||
@test transform(F)[1:4,:] ≈ transform(A)[1:4,:] | ||
@test transform(F)[:,1:4] ≈ transform(A)[:,1:4] | ||
@test transform(F)[1,1:3] ≈ transform(A)[1,1:3] | ||
@test transform(F)[1:3,1] ≈ transform(A)[1:3,1] | ||
@test transform(F)[1:2,1:3] ≈ transform(A)[1:2,1:3] | ||
@test transform(F)[[2,1],1:3] ≈ transform(A)[[2,1],1:3] | ||
@test transform(F)[:,:] ≈ transform(A) | ||
@test transform(F)[7] ≈ transform(A)[7] | ||
@test_throws BoundsError transform(F)[0] | ||
@test_throws BoundsError transform(F)[26] | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,3 +40,5 @@ include("fillmap.jl") | |
if VERSION ≥ v"1.1" | ||
include("nontradaxes.jl") | ||
end | ||
|
||
include("getindex.jl") |