-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add slicing functionality #165
Changes from 22 commits
1a52dea
97c1fb3
d56060f
0c9529c
743e456
f61f2a6
090d928
f546fe1
4e87153
6582b3e
04b86fc
a205d15
b9bdfe8
896d59c
c6be0bc
3e0f423
3fbd708
84a95bb
46865b8
8f17ed0
c5c849a
31bfd43
d378b04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
const Indexer = AbstractVector{<:Integer} | ||
|
||
Base.IndexStyle(::LinearMap) = IndexCartesian() | ||
# required in Base.to_indices for [:]-indexing | ||
Base.eachindex(::IndexLinear, A::LinearMap) = Base.OneTo(length(A)) | ||
Base.lastindex(A::LinearMap) = last(eachindex(IndexLinear(), A)) | ||
Base.firstindex(A::LinearMap) = first(eachindex(IndexLinear(), A)) | ||
|
||
function Base.checkbounds(A::LinearMap, i, j) | ||
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j))) | ||
nothing | ||
end | ||
function Base.checkbounds(A::LinearMap, i) | ||
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i)) | ||
nothing | ||
end | ||
# checkbounds in indexing via CartesianIndex | ||
Base.checkbounds(A::LinearMap, i::Union{CartesianIndex{2}, AbstractArray{CartesianIndex{2}}}) = | ||
Base.checkbounds_indices(Bool, axes(A), (i,)) || throw(BoundsError(A, i)) | ||
Base.checkbounds(A::LinearMap, I::AbstractMatrix{Bool}) = | ||
axes(A) == axes(I) || throw(BoundsError(A, I)) | ||
|
||
# main entry point | ||
function Base.getindex(A::LinearMap, I...) | ||
@boundscheck checkbounds(A, I...) | ||
_getindex(A, Base.to_indices(A, I)...) | ||
end | ||
# quick pass forward | ||
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ * A.lmap[I...] | ||
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...] | ||
|
||
######################## | ||
# linear indexing | ||
######################## | ||
_getindex(A::LinearMap, _) = error("linear indexing of LinearMaps is not supported") | ||
|
||
######################## | ||
# Cartesian indexing | ||
######################## | ||
_getindex(A::LinearMap, i::Integer, j::Integer) = | ||
error("scalar indexing of LinearMaps is not supported, consider using A[:,j][i] instead") | ||
_getindex(A::LinearMap, I::Indexer, j::Integer) = | ||
error("partial vertical slicing of LinearMaps is not supported, consider using A[:,j][I] instead") | ||
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*unitvec(A, 2, j) | ||
dkarrasch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_getindex(A::LinearMap, i::Integer, J::Indexer) = | ||
error("partial horizontal slicing of LinearMaps is not supported, consider using A[i,:][J] instead") | ||
function _getindex(A::LinearMap, i::Integer, J::Base.Slice) | ||
try | ||
# requires adjoint action to be defined | ||
return vec(unitvec(A, 1, i)'A) | ||
catch | ||
error("horizontal slicing A[$i,:] requires the adjoint of $(typeof(A)) to be defined") | ||
end | ||
end | ||
_getindex(A::LinearMap, I::Indexer, J::Indexer) = | ||
error("partial two-dimensional slicing of LinearMaps is not supported, consider using A[:,J][I] or A[I,:][J] instead") | ||
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = convert(AbstractMatrix, A) | ||
_getindex(A::LinearMap, I::Base.Slice, J::Indexer) = __getindex(A, I, J) | ||
_getindex(A::LinearMap, I::Indexer, J::Base.Slice) = __getindex(A, I, J) | ||
function __getindex(A, I, J) | ||
dest = zeros(eltype(A), Base.index_shape(I, J)) | ||
# choose whichever requires less map applications | ||
if length(I) <= length(J) | ||
try | ||
# requires adjoint action to be defined | ||
_fillbyrows!(dest, A, I, J) | ||
catch | ||
error("wide slicing A[I,J] with length(I) <= length(J) requires the adjoint of $(typeof(A)) to be defined") | ||
end | ||
else | ||
_fillbycols!(dest, A, I, J) | ||
end | ||
return dest | ||
end | ||
|
||
# helpers | ||
function unitvec(A, dim, i) | ||
x = zeros(eltype(A), size(A, dim)) | ||
@inbounds x[i] = one(eltype(A)) | ||
return x | ||
end | ||
|
||
function _fillbyrows!(dest, A, I, J) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should these be commented out since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests were not variable enough. This method is needed in general. |
||
x = zeros(eltype(A), size(A, 1)) | ||
temp = similar(x, eltype(A), size(A, 2)) | ||
@views @inbounds for (di, i) in zip(eachrow(dest), I) | ||
x[i] = one(eltype(A)) | ||
_unsafe_mul!(temp, A', x) | ||
di .= adjoint.(temp[J]) | ||
x[i] = zero(eltype(A)) | ||
end | ||
return dest | ||
end | ||
function _fillbycols!(dest, A, I::Indexer, J) | ||
x = zeros(eltype(A), size(A, 2)) | ||
temp = similar(x, eltype(A), size(A, 1)) | ||
@inbounds for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(temp, A, x) | ||
dest[:,ind] .= temp[I] | ||
x[j] = zero(eltype(A)) | ||
end | ||
return dest | ||
end | ||
function _fillbycols!(dest, A, ::Base.Slice, J) | ||
x = zeros(eltype(A), size(A, 2)) | ||
@inbounds for (ind, j) in enumerate(J) | ||
x[j] = one(eltype(A)) | ||
_unsafe_mul!(selectdim(dest, 2, ind), A, x) | ||
x[j] = zero(eltype(A)) | ||
end | ||
return dest | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
using LinearAlgebra, LinearMaps, Test | ||
using LinearMaps: VecOrMatMap, ScaledMap | ||
# using BenchmarkTools | ||
|
||
function test_getindex(A::LinearMap, M::AbstractMatrix) | ||
@assert size(A) == size(M) | ||
mask = rand(Bool, size(A)) | ||
imask = rand(Bool, size(A, 1)) | ||
jmask = rand(Bool, size(A, 2)) | ||
@test A[1,:] == M[1,:] | ||
@test A[:,1] == M[:,1] | ||
@test A[1:4,:] == M[1:4,:] | ||
@test A[:,1:4] == M[:,1:4] | ||
@test A[[2,1],:] == M[[2,1],:] | ||
@test A[:,[2,1]] == M[:,[2,1]] | ||
@test A[:,:] == M | ||
@test (lastindex(A, 1), lastindex(A, 2)) == size(A) | ||
if A isa VecOrMatMap || A isa ScaledMap{<:Any,<:Any,<:VecOrMatMap} | ||
@test A[:] == M[:] | ||
@test A[1,1] == M[1,1] | ||
@test A[1,1:3] == M[1,1:3] | ||
@test A[1:3,1] == M[1:3,1] | ||
@test A[2:end,1] == M[2:end,1] | ||
@test A[1:2,1:3] == M[1:2,1:3] | ||
@test A[[2,1],1:3] == M[[2,1],1:3] | ||
@test A[7] == M[7] | ||
@test A[3:7] == M[3:7] | ||
@test A[mask] == M[mask] | ||
@test A[findall(mask)] == M[findall(mask)] | ||
@test A[CartesianIndex(1,1)] == M[CartesianIndex(1,1)] | ||
@test A[imask, 1] == M[imask, 1] | ||
@test A[1, jmask] == M[1, jmask] | ||
@test A[imask, jmask] == M[imask, jmask] | ||
else | ||
@test_throws ErrorException A[:] | ||
@test_throws ErrorException A[1,1] | ||
@test_throws ErrorException A[1,1:3] | ||
@test_throws ErrorException A[1:3,1] | ||
@test_throws ErrorException A[2:end,1] | ||
@test_throws ErrorException A[1:2,1:3] | ||
@test_throws ErrorException A[[2,1],1:3] | ||
@test_throws ErrorException A[7] | ||
@test_throws ErrorException A[3:7] | ||
@test_throws ErrorException A[mask] | ||
@test_throws ErrorException A[findall(mask)] | ||
@test_throws ErrorException A[CartesianIndex(1,1)] | ||
@test_throws ErrorException A[imask, 1] | ||
@test_throws ErrorException A[1, jmask] | ||
@test_throws ErrorException A[imask, jmask] | ||
end | ||
@test_throws BoundsError A[lastindex(A,1)+1,1] | ||
@test_throws BoundsError A[1,lastindex(A,2)+1] | ||
@test_throws BoundsError A[2,1:lastindex(A,2)+1] | ||
@test_throws BoundsError A[1:lastindex(A,1)+1,2] | ||
@test_throws BoundsError A[ones(Bool, 2, 2)] | ||
@test_throws BoundsError A[[true, true], 1] | ||
@test_throws BoundsError A[1, [true, true]] | ||
return nothing | ||
end | ||
|
||
@testset "getindex" begin | ||
M = rand(4,6) | ||
A = LinearMap(M) | ||
test_getindex(A, M) | ||
test_getindex(2A, 2M) | ||
# @btime getindex($M, i) setup=(i = rand(1:24)); | ||
# @btime getindex($A, i) setup=(i = rand(1:24)); | ||
# @btime (getindex($M, i, j)) setup=(i = rand(1:4); j = rand(1:6)); | ||
# @btime (getindex($A, i, j)) setup=(i = rand(1:4); j = rand(1:6)); | ||
|
||
struct TwoMap <: LinearMaps.LinearMap{Float64} end | ||
Base.size(::TwoMap) = (5,5) | ||
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x)) | ||
T = TwoMap() | ||
@test_throws ErrorException T[1,:] | ||
|
||
Base.transpose(A::TwoMap) = A | ||
test_getindex(TwoMap(), fill(2.0, size(T))) | ||
|
||
MA = rand(ComplexF64, 5, 5) | ||
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5) | ||
F = LinearMap{ComplexF64}(x -> MA*x, y -> MA'y, 5, 5) | ||
test_getindex(FA, MA) | ||
test_getindex([FA FA], [MA MA]) | ||
test_getindex([FA; FA], [MA; MA]) | ||
test_getindex(F, MA) | ||
test_getindex(3FA, 3MA) | ||
test_getindex(FA + FA, 2MA) | ||
test_getindex(transpose(FA), transpose(MA)) | ||
test_getindex(transpose(3FA), transpose(3MA)) | ||
test_getindex(3transpose(FA), transpose(3MA)) | ||
test_getindex(adjoint(FA), adjoint(MA)) | ||
test_getindex(adjoint(3FA), adjoint(3MA)) | ||
test_getindex(3adjoint(FA), adjoint(3MA)) | ||
|
||
test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5))) | ||
test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5)) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,3 +35,5 @@ include("fillmap.jl") | |
include("nontradaxes.jl") | ||
|
||
include("embeddedmap.jl") | ||
|
||
include("getindex.jl") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that codecov complains about these two lines. I wonder if they should be commented out because we are not supporting
A[1]
andA[end]
at this point in time, right?