-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Jeff Fessler <[email protected]>
- Loading branch information
1 parent
12c13fb
commit 0a42082
Showing
5 changed files
with
241 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
const Indexer = AbstractVector{<:Integer} | ||
|
||
Base.IndexStyle(::LinearMap) = IndexCartesian() | ||
# required in Base.to_indices for [:]-indexing (only size check) | ||
Base.eachindex(::IndexLinear, A::LinearMap) = Base.OneTo(length(A)) | ||
# Base.lastindex(A::LinearMap) = last(eachindex(IndexLinear(), A)) | ||
# Base.firstindex(A::LinearMap) = first(eachindex(IndexLinear(), A)) | ||
|
||
function Base.checkbounds(A::LinearMap, i, j) | ||
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j))) | ||
nothing | ||
end | ||
function Base.checkbounds(A::LinearMap, i) | ||
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i)) | ||
nothing | ||
end | ||
# checkbounds in indexing via CartesianIndex | ||
Base.checkbounds(A::LinearMap, i::Union{CartesianIndex{2}, AbstractArray{CartesianIndex{2}}}) = | ||
Base.checkbounds_indices(Bool, axes(A), (i,)) || throw(BoundsError(A, i)) | ||
Base.checkbounds(A::LinearMap, I::AbstractMatrix{Bool}) = | ||
axes(A) == axes(I) || throw(BoundsError(A, I)) | ||
|
||
# main entry point | ||
function Base.getindex(A::LinearMap, I...) | ||
@boundscheck checkbounds(A, I...) | ||
_getindex(A, Base.to_indices(A, I)...) | ||
end | ||
# quick pass forward | ||
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ * A.lmap[I...] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...] | ||
|
||
######################## | ||
# linear indexing | ||
######################## | ||
_getindex(A::LinearMap, _) = error("linear indexing of LinearMaps is not supported") | ||
|
||
######################## | ||
# Cartesian indexing (partial slicing is not supported) | ||
######################## | ||
_getindex(A::LinearMap, i::Integer, j::Integer) = | ||
error("scalar indexing of LinearMaps is not supported, consider using A[:,j][i] instead") | ||
_getindex(A::LinearMap, I::Indexer, j::Integer) = | ||
error("partial vertical slicing of LinearMaps is not supported, consider using A[:,j][I] instead") | ||
_getindex(A::LinearMap, i::Integer, J::Indexer) = | ||
error("partial horizontal slicing of LinearMaps is not supported, consider using A[i,:][J] instead") | ||
_getindex(A::LinearMap, I::Indexer, J::Indexer) = | ||
error("partial two-dimensional slicing of LinearMaps is not supported, consider using A[:,J][I] or A[I,:][J] instead") | ||
|
||
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*unitvec(A, 2, j) | ||
function _getindex(A::LinearMap, i::Integer, J::Base.Slice) | ||
try | ||
# requires adjoint action to be defined | ||
return vec(unitvec(A, 1, i)'A) | ||
catch | ||
error("horizontal slicing A[$i,:] requires the adjoint of $(typeof(A)) to be defined") | ||
end | ||
end | ||
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = convert(AbstractMatrix, A) | ||
_getindex(A::LinearMap, I::Base.Slice, J::Indexer) = __getindex(A, I, J) | ||
_getindex(A::LinearMap, I::Indexer, J::Base.Slice) = __getindex(A, I, J) | ||
function __getindex(A, I, J) | ||
dest = zeros(eltype(A), Base.index_shape(I, J)) | ||
# choose whichever requires less map applications | ||
if length(I) <= length(J) | ||
try | ||
# requires adjoint action to be defined | ||
_fillbyrows!(dest, A, I, J) | ||
catch | ||
error("wide slicing A[I,J] with length(I) <= length(J) requires the adjoint of $(typeof(A)) to be defined") | ||
end | ||
else | ||
_fillbycols!(dest, A, I, J) | ||
end | ||
return dest | ||
end | ||
|
||
# helpers | ||
function unitvec(A, dim, i) | ||
x = zeros(eltype(A), size(A, dim)) | ||
@inbounds x[i] = one(eltype(A)) | ||
return x | ||
end | ||
|
||
function _fillbyrows!(dest, A, I, J) | ||
x = zeros(eltype(A), size(A, 1)) | ||
temp = similar(x, eltype(A), size(A, 2)) | ||
@views @inbounds for (di, i) in zip(eachrow(dest), I) | ||
x[i] = one(eltype(A)) | ||
_unsafe_mul!(temp, A', x) | ||
di .= adjoint.(temp[J]) | ||
x[i] = zero(eltype(A)) | ||
end | ||
return dest | ||
end | ||
function _fillbycols!(dest, A, I::Indexer, J) | ||
x = zeros(eltype(A), size(A, 2)) | ||
temp = similar(x, eltype(A), size(A, 1)) | ||
@inbounds for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(temp, A, x) | ||
dest[:,ind] .= temp[I] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return dest | ||
end | ||
function _fillbycols!(dest, A, ::Base.Slice, J) | ||
x = zeros(eltype(A), size(A, 2)) | ||
@inbounds for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(selectdim(dest, 2, ind), A, x) | ||
x[j] = zero(eltype(A)) | ||
end | ||
return dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
using LinearAlgebra, LinearMaps, Test | ||
using LinearMaps: VecOrMatMap, ScaledMap | ||
# using BenchmarkTools | ||
|
||
function test_getindex(A::LinearMap, M::AbstractMatrix) | ||
@assert size(A) == size(M) | ||
mask = rand(Bool, size(A)) | ||
imask = rand(Bool, size(A, 1)) | ||
jmask = rand(Bool, size(A, 2)) | ||
@test A[1,:] == M[1,:] | ||
@test A[:,1] == M[:,1] | ||
@test A[1:lastindex(A,1)-2,:] == M[1:lastindex(A,1)-2,:] | ||
@test A[:,1:4] == M[:,1:4] | ||
@test A[[2,1],:] == M[[2,1],:] | ||
@test A[:,[2,1]] == M[:,[2,1]] | ||
@test A[:,:] == M | ||
@test (lastindex(A, 1), lastindex(A, 2)) == size(A) | ||
if A isa VecOrMatMap || A isa ScaledMap{<:Any,<:Any,<:VecOrMatMap} | ||
@test A[:] == M[:] | ||
@test A[1,1] == M[1,1] | ||
@test A[1,1:3] == M[1,1:3] | ||
@test A[1:3,1] == M[1:3,1] | ||
@test A[2:end,1] == M[2:end,1] | ||
@test A[1:2,1:3] == M[1:2,1:3] | ||
@test A[[2,1],1:3] == M[[2,1],1:3] | ||
@test A[7] == M[7] | ||
@test A[3:7] == M[3:7] | ||
@test A[mask] == M[mask] | ||
@test A[findall(mask)] == M[findall(mask)] | ||
@test A[CartesianIndex(1,1)] == M[CartesianIndex(1,1)] | ||
@test A[imask, 1] == M[imask, 1] | ||
@test A[1, jmask] == M[1, jmask] | ||
@test A[imask, jmask] == M[imask, jmask] | ||
else | ||
@test_throws ErrorException A[:] | ||
@test_throws ErrorException A[1,1] | ||
@test_throws ErrorException A[1,1:3] | ||
@test_throws ErrorException A[1:3,1] | ||
@test_throws ErrorException A[2:end,1] | ||
@test_throws ErrorException A[1:2,1:3] | ||
@test_throws ErrorException A[[2,1],1:3] | ||
@test_throws ErrorException A[7] | ||
@test_throws ErrorException A[3:7] | ||
@test_throws ErrorException A[mask] | ||
@test_throws ErrorException A[findall(mask)] | ||
@test_throws ErrorException A[CartesianIndex(1,1)] | ||
@test_throws ErrorException A[imask, 1] | ||
@test_throws ErrorException A[1, jmask] | ||
@test_throws ErrorException A[imask, jmask] | ||
end | ||
@test_throws BoundsError A[lastindex(A,1)+1,1] | ||
@test_throws BoundsError A[1,lastindex(A,2)+1] | ||
@test_throws BoundsError A[2,1:lastindex(A,2)+1] | ||
@test_throws BoundsError A[1:lastindex(A,1)+1,2] | ||
@test_throws BoundsError A[ones(Bool, 2, 2)] | ||
@test_throws BoundsError A[[true, true], 1] | ||
@test_throws BoundsError A[1, [true, true]] | ||
return nothing | ||
end | ||
|
||
@testset "getindex" begin | ||
M = rand(4,6) | ||
A = LinearMap(M) | ||
test_getindex(A, M) | ||
test_getindex(2A, 2M) | ||
# @btime getindex($M, i) setup=(i = rand(1:24)); | ||
# @btime getindex($A, i) setup=(i = rand(1:24)); | ||
# @btime (getindex($M, i, j)) setup=(i = rand(1:4); j = rand(1:6)); | ||
# @btime (getindex($A, i, j)) setup=(i = rand(1:4); j = rand(1:6)); | ||
|
||
struct TwoMap <: LinearMaps.LinearMap{Float64} end | ||
Base.size(::TwoMap) = (5,5) | ||
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x)) | ||
T = TwoMap() | ||
@test_throws ErrorException T[1,:] | ||
|
||
Base.transpose(A::TwoMap) = A | ||
test_getindex(TwoMap(), fill(2.0, size(T))) | ||
|
||
MA = rand(ComplexF64, 5, 5) | ||
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5) | ||
F = LinearMap{ComplexF64}(x -> MA*x, y -> MA'y, 5, 5) | ||
test_getindex(FA, MA) | ||
test_getindex([FA FA], [MA MA]) | ||
test_getindex([FA; FA], [MA; MA]) | ||
test_getindex(F, MA) | ||
test_getindex(3FA, 3MA) | ||
test_getindex(FA + FA, 2MA) | ||
test_getindex(transpose(FA), transpose(MA)) | ||
test_getindex(transpose(3FA), transpose(3MA)) | ||
test_getindex(3transpose(FA), transpose(3MA)) | ||
test_getindex(adjoint(FA), adjoint(MA)) | ||
test_getindex(adjoint(3FA), adjoint(3MA)) | ||
test_getindex(3adjoint(FA), adjoint(3MA)) | ||
|
||
test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5))) | ||
test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,3 +35,5 @@ include("fillmap.jl") | |
include("nontradaxes.jl") | ||
|
||
include("embeddedmap.jl") | ||
|
||
include("getindex.jl") |