-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Jeff Fessler <[email protected]>
- Loading branch information
1 parent
df9ccd3
commit 98843dd
Showing
6 changed files
with
186 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
struct EmbeddedMap{T, As <: LinearMap, Rs <: AbstractVector{Int}, | ||
Cs <: AbstractVector{Int}} <: LinearMap{T} | ||
lmap::As | ||
dims::Dims{2} | ||
rows::Rs # typically i1:i2 with 1 <= i1 <= i2 <= size(map,1) | ||
cols::Cs # typically j1:j2 with 1 <= j1 <= j2 <= size(map,2) | ||
|
||
function EmbeddedMap{T}(map::As, dims::Dims{2}, rows::Rs, cols::Cs) where {T, | ||
As <: LinearMap, Rs <: AbstractVector{Int}, Cs <: AbstractVector{Int}} | ||
check_index(rows, size(map, 1), dims[1]) | ||
check_index(cols, size(map, 2), dims[2]) | ||
return new{T,As,Rs,Cs}(map, dims, rows, cols) | ||
end | ||
end | ||
|
||
EmbeddedMap(map::LinearMap{T}, dims::Dims{2}; offset::Dims{2}) where {T} = | ||
EmbeddedMap{T}(map, dims, offset[1] .+ (1:size(map, 1)), offset[2] .+ (1:size(map, 2))) | ||
EmbeddedMap(map::LinearMap, dims::Dims{2}, rows::AbstractVector{Int}, cols::AbstractVector{Int}) = | ||
EmbeddedMap{eltype(map)}(map, dims, rows, cols) | ||
|
||
@static if VERSION >= v"1.8-" | ||
Base.reverse(A::LinearMap; dims=:) = _reverse(A, dims) | ||
function _reverse(A, dims::Integer) | ||
if dims == 1 | ||
return EmbeddedMap(A, size(A), reverse(axes(A, 1)), axes(A, 2)) | ||
elseif dims == 2 | ||
return EmbeddedMap(A, size(A), axes(A, 1), reverse(axes(A, 2))) | ||
else | ||
throw(ArgumentError("invalid dims argument to reverse, should be 1 or 2, got $dims")) | ||
end | ||
end | ||
_reverse(A, ::Colon) = EmbeddedMap(A, size(A), map(reverse, axes(A))...) | ||
_reverse(A, dims::NTuple{1,Integer}) = _reverse(A, first(dims)) | ||
function _reverse(A, dims::NTuple{M,Integer}) where {M} | ||
dimrev = ntuple(k -> k in dims, 2) | ||
if 2 < M || M != sum(dimrev) | ||
throw(ArgumentError("invalid dimensions $dims in reverse!")) | ||
end | ||
ax = ntuple(k -> dimrev[k] ? reverse(axes(A, k)) : axes(A, k), 2) | ||
return EmbeddedMap(A, size(A), ax...) | ||
end | ||
end | ||
|
||
function check_index(index::AbstractVector{Int}, dimA::Int, dimB::Int) | ||
length(index) != dimA && throw(ArgumentError("invalid length of index vector")) | ||
minimum(index) <= 0 && throw(ArgumentError("minimal index is below 1")) | ||
maximum(index) > dimB && throw(ArgumentError( | ||
"maximal index $(maximum(index)) exceeds dimension $dimB" | ||
)) | ||
# _isvalidstep(index) || throw(ArgumentError("non-monotone index set")) | ||
nothing | ||
end | ||
|
||
# _isvalidstep(index::AbstractRange) = step(index) > 0 | ||
# _isvalidstep(index::AbstractVector) = all(diff(index) .> 0) | ||
|
||
Base.size(A::EmbeddedMap) = A.dims | ||
|
||
# sufficient but not necessary conditions | ||
LinearAlgebra.issymmetric(A::EmbeddedMap) = | ||
issymmetric(A.lmap) && (A.dims[1] == A.dims[2]) && (A.rows == A.cols) | ||
LinearAlgebra.ishermitian(A::EmbeddedMap) = | ||
ishermitian(A.lmap) && (A.dims[1] == A.dims[2]) && (A.rows == A.cols) | ||
|
||
Base.:(==)(A::EmbeddedMap, B::EmbeddedMap) = | ||
(eltype(A) == eltype(B)) && (A.lmap == B.lmap) && | ||
(A.dims == B.dims) && (A.rows == B.rows) && (A.cols == B.cols) | ||
|
||
LinearAlgebra.adjoint(A::EmbeddedMap) = EmbeddedMap(adjoint(A.lmap), reverse(A.dims), A.cols, A.rows) | ||
LinearAlgebra.transpose(A::EmbeddedMap) = EmbeddedMap(transpose(A.lmap), reverse(A.dims), A.cols, A.rows) | ||
|
||
for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix)) | ||
@eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In) | ||
fill!(y, zero(eltype(y))) | ||
_unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols)) | ||
return y | ||
end | ||
@eval function _unsafe_mul!(y::$Out, A::EmbeddedMap, x::$In, alpha::Number, beta::Number) | ||
LinearAlgebra._rmul_or_fill!(y, beta) | ||
_unsafe_mul!(selectdim(y, 1, A.rows), A.lmap, selectdim(x, 1, A.cols), alpha, !iszero(beta)) | ||
return y | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
using Test, LinearMaps, LinearAlgebra, SparseArrays | ||
|
||
@testset "embeddedmap" begin | ||
m = 6; n = 5 | ||
M = 10(1:m) .+ (1:n)'; L = LinearMap(M) | ||
offset = (3,4) | ||
|
||
BM = [zeros(offset...) zeros(offset[1], size(M,2)); | ||
zeros(size(M,1), offset[2]) M] | ||
BL = @inferred LinearMap(L, size(BM); offset=offset) | ||
s1, s2 = size(BM) | ||
@test (@inferred Matrix(BL)) == BM | ||
@test (@inferred Matrix(BL')) == BM' | ||
@test (@inferred Matrix(transpose(BL))) == transpose(BM) | ||
|
||
@test_throws UndefKeywordError LinearMap(M, (10, 10)) | ||
@test_throws ArgumentError LinearMap(M, (m, n), (0:m, 1:n)) | ||
@test_throws ArgumentError LinearMap(M, (m, n), (0:m-1, 1:n)) | ||
@test_throws ArgumentError LinearMap(M, (m, n), (1:m, 1:n+1)) | ||
@test_throws ArgumentError LinearMap(M, (m, n), (1:m, 2:n+1)) | ||
@test_throws ArgumentError LinearMap(M, (m, n), offset=(3,3)) | ||
# @test_throws ArgumentError LinearMap(M, (m, n), (m:-1:1, 1:n)) | ||
# @test_throws ArgumentError LinearMap(M, (m, n), (collect(m:-1:1), 1:n)) | ||
@test size(@inferred LinearMap(M, (2m, 2n), (1:2:2m, 1:2:2n))) == (2m, 2n) | ||
@test @inferred !ishermitian(BL) | ||
@test @inferred !issymmetric(BL) | ||
@test @inferred LinearMap(L, size(BM), (offset[1] .+ (1:m), offset[2] .+ (1:n))) == BL | ||
Wc = @inferred LinearMap([2 im; -im 0]; ishermitian=true) | ||
Bc = @inferred LinearMap(Wc, (4,4); offset=(2,2)) | ||
@test (@inferred ishermitian(Bc)) | ||
|
||
x = randn(s2); X = rand(s2, 3) | ||
y = BM * x; Y = zeros(s1, 3) | ||
|
||
@test @inferred BL * x ≈ BM * x | ||
@test @inferred BL' * y ≈ BM' * y | ||
|
||
for α in (true, false, rand()), | ||
β in (true, false, rand()), | ||
t in (identity, adjoint, transpose) | ||
|
||
@test t(BL) * x ≈ mul!(copy(y), t(BL), x) ≈ t(BM) * x | ||
@test Matrix(t(BL) * X) ≈ mul!(copy(Y), t(BL), X) ≈ t(BM) * X | ||
y = randn(s1); Y = randn(s1, 3) | ||
@test (@inferred mul!(copy(y), t(BL), x, α, β)) ≈ mul!(copy(y), t(BM), x, α, β) | ||
@test (@inferred mul!(copy(Y), t(BL), X, α, β)) ≈ mul!(copy(Y), t(BM), X, α, β) | ||
end | ||
|
||
if VERSION >= v"1.8" | ||
M = rand(3,4) | ||
L = LinearMap(M) | ||
@test Matrix(reverse(L)) == reverse(M) | ||
for dims in (1, 2, (1,), (2,), (1, 2), (2, 1), :) | ||
@test Matrix(reverse(L, dims=dims)) == reverse(M, dims=dims) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,5 @@ include("left.jl") | |
include("fillmap.jl") | ||
|
||
include("nontradaxes.jl") | ||
|
||
include("embeddedmap.jl") |