Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for qr of strided inputs (non-contiguous views) #1764

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b5e53e6
Adding qr tests for view support
evelyne-ringoot Feb 1, 2023
1673bd9
Adding QR for views on CUDA
evelyne-ringoot Feb 9, 2023
f37ed22
Resolving typos in tests, commenting a TO DO
evelyne-ringoot Feb 9, 2023
952e1ed
Resolving TODOs in test
evelyne-ringoot Feb 9, 2023
3c0642c
Changes to manifest
evelyne-ringoot Feb 9, 2023
81d489b
use Manifest.toml from master
vchuravy Feb 9, 2023
875dddb
Update lib/cusolver/linalg.jl
evelyne-ringoot Feb 9, 2023
aa88c71
Update Manifest and Project
vchuravy Feb 15, 2023
3e1561b
Merge branch 'master' into qr_views
evelyne-ringoot Feb 15, 2023
928cbf8
Repairs to tests - dimensions of test for views
evelyne-ringoot Feb 16, 2023
34b71d5
Resolving dimension mismatches in tests
evelyne-ringoot Feb 22, 2023
040cf8a
Commenting out one buggy test - TO DO
evelyne-ringoot Feb 22, 2023
7ebff72
Resolving todo by adapting dimension issues in test
evelyne-ringoot Feb 24, 2023
0b16b3a
Adding support for in-place qr of views
evelyne-ringoot Mar 2, 2023
8941d0a
Restoring order of function definitions
evelyne-ringoot Mar 2, 2023
762f3e0
Adding tests for inplace qr of views
evelyne-ringoot Mar 3, 2023
31fb908
Updating dependency on QR_views branch of GPUArrays
evelyne-ringoot Mar 10, 2023
455c49b
Adding support for in place QR of views
evelyne-ringoot Mar 10, 2023
d9ca07a
Adding support for in place QR of views in LinearAlgebra
evelyne-ringoot Mar 10, 2023
05ba598
Merge branch 'master' into qr_views
evelyne-ringoot Mar 10, 2023
2bd209c
Merge branch 'master' into qr_views
evelyne-ringoot Sep 5, 2023
7d7601f
Merge branch 'master' into qr_views
evelyne-ringoot Sep 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
@eval begin
function ormqr!(side::Char,
trans::Char,
A::CuMatrix{$elty},
tau::CuVector{$elty},
C::CuVecOrMat{$elty})
A::StridedCuMatrix{$elty},
tau::StridedCuVector{$elty},
C::StridedCuVecOrMat{$elty})

# Support transa = 'C' for real matrices
trans = $elty <: Real && trans == 'C' ? 'T' : trans
Expand Down
187 changes: 162 additions & 25 deletions lib/cusolver/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)

# matrix division

const CuMatOrAdj{T} = Union{CuMatrix{T},
LinearAlgebra.Adjoint{T, <:CuMatrix{T}},
LinearAlgebra.Transpose{T, <:CuMatrix{T}}}
const CuOrAdj{T} = Union{CuVecOrMat{T},
LinearAlgebra.Adjoint{T, <:CuVecOrMat{T}},
LinearAlgebra.Transpose{T, <:CuVecOrMat{T}}}
# A strided CUDA matrix with element type `T`, either plain or wrapped in a lazy
# `Adjoint`/`Transpose`. Every member carries `{T}` so that `CuMatOrAdj{T}` really
# constrains the element type; the unparameterized `CuMatOrAdj` still matches any
# element type.
const CuMatOrAdj{T} = Union{StridedCuMatrix{T},
                            LinearAlgebra.Adjoint{T, <:StridedCuMatrix{T}},
                            LinearAlgebra.Transpose{T, <:StridedCuMatrix{T}}}
# A strided CUDA vector or matrix with element type `T`, either plain or wrapped in
# a lazy `Adjoint`/`Transpose`. Every member carries `{T}` so that `CuOrAdj{T}`
# really constrains the element type; the unparameterized `CuOrAdj` still matches
# any element type.
const CuOrAdj{T} = Union{StridedCuVector{T},
                         LinearAlgebra.Adjoint{T, <:StridedCuVector{T}},
                         LinearAlgebra.Transpose{T, <:StridedCuVector{T}},
                         StridedCuMatrix{T},
                         LinearAlgebra.Adjoint{T, <:StridedCuMatrix{T}},
                         LinearAlgebra.Transpose{T, <:StridedCuMatrix{T}}}

function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
A, B = copy_cublasfloat(_A, _B)
Expand Down Expand Up @@ -129,31 +132,34 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack

## QR

LinearAlgebra.qr!(A::CuMatrix{T}) where T = QR(geqrf!(A::CuMatrix{T})...)


# In-place QR of a strided CUDA matrix: the values returned by `geqrf!(A)` are
# splatted into the `QR` constructor.
function LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T
    return QR(geqrf!(A)...)
end


# conversions
# Materialize a (pivoted) QR factorization object as a device array by first
# converting it with `AbstractArray(F)`. `CuArray(F)` simply forwards to `CuMatrix(F)`.
function CuMatrix(F::Union{QR,QRCompactWY})
    return CuArray(AbstractArray(F))
end
function CuArray(F::Union{QR,QRCompactWY})
    return CuMatrix(F)
end
function CuMatrix(F::QRPivoted)
    return CuArray(AbstractArray(F))
end
function CuArray(F::QRPivoted)
    return CuMatrix(F)
end

function LinearAlgebra.ldiv!(_qr::QR, b::CuVector)
function LinearAlgebra.ldiv!(_qr::QR, b::StridedCuVector)
m,n = size(_qr)
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * b)[1:n])
b[1:n] .= _x
unsafe_free!(_x)
return b[1:n]
end

function LinearAlgebra.ldiv!(_qr::QR, B::CuMatrix)
function LinearAlgebra.ldiv!(_qr::QR, B::StridedCuMatrix)
m,n = size(_qr)
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * B)[1:n, 1:size(B, 2)])
B[1:n, 1:size(B, 2)] .= _x
unsafe_free!(_x)
return B[1:n, 1:size(B, 2)]
end

function LinearAlgebra.ldiv!(x::CuArray, _qr::QR, b::CuArray)
function LinearAlgebra.ldiv!(x::StridedCuArray, _qr::QR, b::StridedCuArray)
_x = ldiv!(_qr, b)
x .= vec(_x)
unsafe_free!(_x)
Expand All @@ -174,57 +180,188 @@ CuMatrix{T}(Q::QRCompactWYQ) where {T} = error("QRCompactWY format is not suppor
# Host-side `Matrix{T}` of a Q factor backed by CuArrays: build the device matrix
# with `CuMatrix{T}(Q)` and copy it to the host with `Array`.
function Matrix{T}(Q::QRPackedQ{S,<:CuArray,<:CuArray}) where {T,S}
    return Array(CuMatrix{T}(Q))
end
function Matrix{T}(Q::QRCompactWYQ{S,<:CuArray,<:CuArray}) where {T,S}
    return Array(CuMatrix{T}(Q))
end



if VERSION < v"1.10-"
# extracting the full matrix can be done with `collect` (which defaults to `Array`)
function Base.collect(src::Union{QRPackedQ{<:Any,<:CuArray,<:CuArray},
QRCompactWYQ{<:Any,<:CuArray,<:CuArray}})
function Base.collect(src::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}})
dest = similar(src)
copyto!(dest, I)
lmul!(src, dest)
collect(dest)
end

# avoid the generic similar fallback that returns a CPU array
Base.similar(Q::Union{QRPackedQ{<:Any,<:CuArray,<:CuArray},
QRCompactWYQ{<:Any,<:CuArray,<:CuArray}},
Base.similar(Q::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}},
::Type{T}, dims::Dims{N}) where {T,N} =
CuArray{T,N}(undef, dims)

end

function Base.getindex(Q::QRPackedQ{<:Any, <:CuArray}, ::Colon, j::Int)
function Base.getindex(Q::QRPackedQ{<:Any, <:StridedCuArray}, ::Colon, j::Int)
y = CUDA.zeros(eltype(Q), size(Q, 2))
y[j] = 1
lmul!(Q, y)
end


# multiplication by Q
LinearAlgebra.lmul!(A::QRPackedQ{T,<:CuArray,<:CuArray},
LinearAlgebra.lmul!(A::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray},
B::CuVecOrMat{T}) where {T<:BlasFloat} =
ormqr!('L', 'N', A.factors, A.τ, B)
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
B::CuVecOrMat{T}) where {T<:BlasReal} =
ormqr!('L', 'T', parent(adjA).factors, parent(adjA).τ, B)
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
B::CuVecOrMat{T}) where {T<:BlasComplex} =
ormqr!('L', 'C', parent(adjA).factors, parent(adjA).τ, B)
LinearAlgebra.lmul!(trA::Transpose{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
LinearAlgebra.lmul!(trA::Transpose{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
B::CuVecOrMat{T}) where {T<:BlasFloat} =
ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B)

LinearAlgebra.rmul!(A::CuVecOrMat{T},
B::QRPackedQ{T,<:CuArray,<:CuArray}) where {T<:BlasFloat} =
B::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}) where {T<:BlasFloat} =
ormqr!('R', 'N', B.factors, B.τ, A)
LinearAlgebra.rmul!(A::CuVecOrMat{T},
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}}) where {T<:BlasReal} =
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasReal} =
ormqr!('R', 'T', parent(adjB).factors, parent(adjB).τ, A)
LinearAlgebra.rmul!(A::CuVecOrMat{T},
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}}) where {T<:BlasComplex} =
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasComplex} =
ormqr!('R', 'C', parent(adjB).factors, parent(adjB).τ, A)
# Right-multiply `A` in place by the transpose of a packed Q factor via
# `ormqr!` with side 'R' and trans 'T'.
# Both `factors` and `τ` are taken from `parent(trA)`; the previous body read `τ`
# from `adjB`, a name that does not exist in this method.
LinearAlgebra.rmul!(A::CuVecOrMat{T},
                    trA::Transpose{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasFloat} =
    ormqr!('R', 'T', parent(trA).factors, parent(trA).τ, A)

else

# QR factorization object for strided CUDA matrices.
#   factors :: the packed factor matrix (built from the output of `geqrf!` by `qr!`)
#   τ       :: the accompanying coefficient vector from the same call
# Both fields share the element type `T`; `factors` is annotated with `{T}` so the
# field type agrees with the inner constructor and with `CuQRPackedQ`.
struct CuQR{T} <: Factorization{T}
    factors::StridedCuMatrix{T}
    τ::StridedCuVector{T}
    CuQR{T}(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T} = new(factors, τ)
end

# Lazy Q factor of a `CuQR`: stores the packed `factors` matrix and the `τ` vector
# and is applied through `ormqr!` rather than being materialized.
struct CuQRPackedQ{T} <: AbstractQ{T}
    factors::StridedCuMatrix{T}
    τ::StridedCuVector{T}

    function CuQRPackedQ{T}(factors::StridedCuMatrix{T},
                            τ::StridedCuVector{T}) where {T}
        return new(factors, τ)
    end
end

# Convenience constructor: infer `T` from the element type of the arguments.
function CuQR(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
    return CuQR{T}(factors, τ)
end
# Convenience constructor: infer `T` from the element type of the arguments.
function CuQRPackedQ(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T}
    return CuQRPackedQ{T}(factors, τ)
end

# `size` of an AbstractQ reports the full (square) matrix, whereas `Matrix(Q)`
# produces only the compact Q. The device conversions below go through `Matrix(Q)`.
# See JuliaLang/julia#26591 and JuliaGPU/CUDA.jl#969.
function CuMatrix{T}(Q::AbstractQ{S}) where {T,S}
    return convert(CuArray{T}, Matrix(Q))
end
function CuMatrix{T, B}(Q::AbstractQ{S}) where {T, B, S}
    return CuMatrix{T}(Q)
end
function CuMatrix(Q::AbstractQ{T}) where {T}
    return CuMatrix{T}(Q)
end
function CuArray{T}(Q::AbstractQ) where {T}
    return CuMatrix{T}(Q)
end
function CuArray(Q::AbstractQ)
    return CuMatrix(Q)
end

# `collect` extracts the full matrix (and, like `collect` in general, ends up as an
# `Array`): fill a `similar` buffer with the identity, apply Q to it in place, then
# collect that buffer.
function Base.collect(src::CuQRPackedQ)
    full = similar(src)
    copyto!(full, I)
    lmul!(src, full)
    return collect(full)
end

# Without this method the generic `similar` fallback would hand back a CPU array;
# allocate an uninitialized `CuArray` of the requested eltype and dims instead.
function Base.similar(Q::CuQRPackedQ, ::Type{T}, dims::Dims{N}) where {T,N}
    return CuArray{T,N}(undef, dims)
end

# In-place QR of a strided CUDA matrix: the values returned by `geqrf!(A)` are
# splatted into the `CuQR` constructor.
function LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T
    return CuQR(geqrf!(A)...)
end

# The factorization has the same size as its packed `factors` matrix.
function Base.size(A::CuQR)
    return size(A.factors)
end
# Size along one dimension: dimensions 1 and 2 both report the row count of
# `factors`, higher dimensions report 1, and non-positive ones throw `BoundsError`.
function Base.size(A::CuQRPackedQ, dim::Integer)
    dim > 0 || throw(BoundsError())
    return dim <= 2 ? size(A.factors, 1) : 1
end
# Materialize Q: `orgqr!` is run on a copy of `factors`, so the stored factors are
# not overwritten. `CuArray` forwards to `CuMatrix`, and `Matrix` converts the
# device result with `Matrix`.
function CUDA.CuMatrix(A::CuQRPackedQ)
    scratch = copy(A.factors)
    return orgqr!(scratch, A.τ)
end
function CUDA.CuArray(A::CuQRPackedQ)
    return CuMatrix(A)
end
function Base.Matrix(A::CuQRPackedQ)
    return Matrix(CuMatrix(A))
end

# Property access for `CuQR`:
#   :R        -> upper-triangular part of the leading min(m, n) × n block of `factors`
#   :Q        -> a lazy `CuQRPackedQ` sharing `factors` and `τ`
#   otherwise -> the raw field
function Base.getproperty(A::CuQR, d::Symbol)
    factors = getfield(A, :factors)
    if d === :R
        m, n = size(factors)
        # Range indexing produces a fresh array, so `triu!` zeroes the lower triangle
        # of the copy only. Running `triu!` on a `view` of `factors` would wipe the
        # entries below the diagonal that `Q` (via `ormqr!`/`orgqr!`) still reads.
        return triu!(factors[1:min(m, n), 1:n])
    elseif d === :Q
        return CuQRPackedQ(factors, getfield(A, :τ))
    else
        return getfield(A, d)
    end
end

# Iteration protocol used for destructuring (`Q, R = F`): the first step yields
# `S.Q`, the second yields `S.R`, and the third ends the iteration.
function Base.iterate(S::CuQR)
    return (S.Q, Val(:R))
end
function Base.iterate(S::CuQR, ::Val{:R})
    return (S.R, Val(:done))
end
function Base.iterate(S::CuQR, ::Val{:done})
    return nothing
end

# Multiply by Q from the left, in place on `B`, by calling `ormqr!` with side 'L'.
# The trans flag is 'N' for Q itself, 'T' for a transpose or a real adjoint, and
# 'C' for a complex adjoint.
function LinearAlgebra.lmul!(A::CuQRPackedQ{T},
                             B::StridedCuVecOrMat{T}) where {T<:BlasFloat}
    return ormqr!('L', 'N', A.factors, A.τ, B)
end
function LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasReal}
    Q = parent(adjA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end
function LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasComplex}
    Q = parent(adjA)
    return ormqr!('L', 'C', Q.factors, Q.τ, B)
end
function LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T}},
                             B::StridedCuVecOrMat{T}) where {T<:BlasFloat}
    Q = parent(trA)
    return ormqr!('L', 'T', Q.factors, Q.τ, B)
end

# Multiply by Q from the right, in place on `A`, by calling `ormqr!` with side 'R'.
# The trans flag is 'N' for Q itself, 'T' for a real adjoint, and 'C' for a complex
# adjoint.
function LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                             B::CuQRPackedQ{T}) where {T<:BlasFloat}
    return ormqr!('R', 'N', B.factors, B.τ, A)
end
function LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                             adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasReal}
    Q = parent(adjB)
    return ormqr!('R', 'T', Q.factors, Q.τ, A)
end
function LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                             adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasComplex}
    Q = parent(adjB)
    return ormqr!('R', 'C', Q.factors, Q.τ, A)
end
# Right-multiply `A` in place by the transpose of a packed Q factor via
# `ormqr!` with side 'R' and trans 'T'.
# Both `factors` and `τ` come from `parent(trA)`; the previous body read `τ` from
# `adjB`, which is not defined in this method and would raise `UndefVarError`.
function LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
                             trA::Transpose{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasFloat}
    Q = parent(trA)
    return ormqr!('R', 'T', Q.factors, Q.τ, A)
end

# Scalar element access `Q[i, j]`. Guarded by `assertscalar`; builds a zero vector
# with a one at position `j`, applies Q to it in place, and reads entry `i`.
function Base.getindex(A::CuQRPackedQ{T}, i::Int, j::Int) where {T}
    assertscalar("CuQRPackedQ getindex")
    col = CUDA.zeros(T, size(A, 2))
    col[j] = 1
    lmul!(A, col)
    return col[i]
end

# Print a header line with the concrete type, then the Q factor, a line break,
# and the R factor (in that order).
function Base.show(io::IO, F::CuQR)
    header = "$(typeof(F)) with factors Q and R:"
    println(io, header)
    show(io, F.Q)
    println(io)
    return show(io, F.R)
end

# Determinant of Q computed from `τ` alone; formulas follow
# https://github.com/JuliaLang/julia/pull/32887
# Real case: -1 when the number of nonzero `τ` entries is odd, otherwise 1.
function LinearAlgebra.det(Q::CuQRPackedQ{<:Real})
    return isodd(count(!iszero, Q.τ)) ? -1 : 1
end
# General case: product over `τ` of `one(t)` for zero entries and `-sign(t)^2`
# for the others.
function LinearAlgebra.det(Q::CuQRPackedQ)
    return prod(t -> iszero(t) ? one(t) : -sign(t)^2, Q.τ)
end

# Solve with the factorization for a vector right-hand side. The first `n`
# entries of `b` are overwritten with the solution, and a copy of those entries
# (`b[1:n]`) is returned.
function LinearAlgebra.ldiv!(_qr::CuQR, b::StridedCuVector)
    m, n = size(_qr)
    # R is fetched before Q, in the same order as the one-line formulation.
    Rtri = UpperTriangular(_qr.R[1:min(m,n), 1:n])
    rhs = (_qr.Q' * b)[1:n]
    sol = Rtri \ rhs
    b[1:n] .= sol
    unsafe_free!(sol)   # release the temporary eagerly
    return b[1:n]
end

# Solve with the factorization for a matrix right-hand side. The first `n` rows
# of `B` are overwritten with the solution, and a copy of that block is returned.
function LinearAlgebra.ldiv!(_qr::CuQR, B::StridedCuMatrix)
    m, n = size(_qr)
    # R is fetched before Q, in the same order as the one-line formulation.
    Rtri = UpperTriangular(_qr.R[1:min(m,n), 1:n])
    rhs = (_qr.Q' * B)[1:n, 1:size(B, 2)]
    sol = Rtri \ rhs
    B[1:n, 1:size(B, 2)] .= sol
    unsafe_free!(sol)   # release the temporary eagerly
    return B[1:n, 1:size(B, 2)]
end

# Three-argument solve: run the two-argument `ldiv!` on `b` (which also modifies
# `b`), broadcast the flattened result into `x`, free the temporary, and return `x`.
function LinearAlgebra.ldiv!(x::StridedCuArray, _qr::CuQR, b::StridedCuArray)
    sol = ldiv!(_qr, b)
    x .= vec(sol)
    unsafe_free!(sol)
    return x
end

end

## SVD

Expand Down
Loading