diff --git a/Project.toml b/Project.toml index f3d7cdd2..ce73163a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "4.0.2" +version = "4.0.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/implementation/strided.jl b/src/implementation/strided.jl index 84046926..a63b149f 100644 --- a/src/implementation/strided.jl +++ b/src/implementation/strided.jl @@ -11,7 +11,24 @@ function tensoradd!(C::StridedView, pC::Index2Tuple, if !istrivialpermutation(pC) && Base.mightalias(C, A) throw(ArgumentError("output tensor must not be aliased with input tensor")) end - add!(C, permutedims(flag2op(conjA)(A), linearize(pC)), α, β) + A′ = permutedims(flag2op(conjA)(A), linearize(pC)) + if isone(α) + if iszero(β) + return Strided._mapreducedim!(identity, +, zero, size(C), (C, A′)) + elseif isone(β) + Strided._mapreducedim!(identity, +, nothing, size(C), (C, A′)) + else + Strided._mapreducedim!(identity, +, y -> β * y, size(C), (C, A′)) + end + else + if iszero(β) + Strided._mapreducedim!(x -> α * x, +, zero, size(C), (C, A′)) + elseif isone(β) + Strided._mapreducedim!(x -> α * x, +, nothing, size(C), (C, A′)) + else + Strided._mapreducedim!(x -> α * x, +, y -> β * y, size(C), (C, A′)) + end + end return C end @@ -29,22 +46,22 @@ function tensortrace!(C::StridedView, pC::Index2Tuple, tracesize = sizeA.(pA[1]) newstrides = (strideA.(linearize(pC))..., (strideA.(pA[1]) .+ strideA.(pA[2]))...) newsize = (size(C)..., tracesize...) - A2 = flag2op(conjA)(StridedView(A.parent, newsize, newstrides, A.offset, A.op)) - if α != 1 - if β == 0 - Strided._mapreducedim!(x -> α * x, +, zero, newsize, (C, A2)) - elseif β == 1 - Strided._mapreducedim!(x -> α * x, +, nothing, newsize, (C, A2)) + A′ = flag2op(conjA)(StridedView(A.parent, newsize, newstrides, A.offset, A.op)) + if isone(α) + if iszero(β) + return Strided._mapreducedim!(identity, +, zero, newsize, (C, A′)) + elseif isone(β) + Strided._mapreducedim!(identity, +, nothing, newsize, (C, A′)) else - Strided._mapreducedim!(x -> α * x, +, y -> β * y, newsize, (C, A2)) + Strided._mapreducedim!(identity, +, y -> β * y, newsize, (C, A′)) end else - if β == 0 - return Strided._mapreducedim!(identity, +, zero, newsize, (C, A2)) - elseif β == 1 - Strided._mapreducedim!(identity, +, nothing, newsize, (C, A2)) + if iszero(β) + Strided._mapreducedim!(x -> α * x, +, zero, newsize, (C, A′)) + elseif isone(β) + Strided._mapreducedim!(x -> α * x, +, nothing, newsize, (C, A′)) else - Strided._mapreducedim!(identity, +, y -> β * y, newsize, (C, A2)) + Strided._mapreducedim!(x -> α * x, +, y -> β * y, newsize, (C, A′)) end end return C @@ -77,6 +94,30 @@ function tensorcontract!(C::StridedView{T}, pC::Index2Tuple, end end +# reduce overhead for the case where it is just matrix multiplication +function tensorcontract!(C::StridedView{T,2}, pC::Index2Tuple{1,1}, + A::StridedView{T,2}, pA::Index2Tuple{1,1}, conjA::Symbol, + B::StridedView{T,2}, pB::Index2Tuple{1,1}, conjB::Symbol, + α, β, + backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat} + argcheck_tensorcontract(C, pC, A, pA, B, pB) + dimcheck_tensorcontract(C, pC, A, pA, B, pB) + + (Base.mightalias(C, A) || Base.mightalias(C, B)) && + throw(ArgumentError("output tensor must not be aliased with input tensor")) + + opA = flag2op(conjA) + opB = flag2op(conjB) + A′ = pA == ((1,), (2,)) ? opA(A) : opA(permutedims(A, (pA[1][1], pA[2][1]))) + B′ = pB == ((1,), (2,)) ? opB(B) : opB(permutedims(B, (pB[1][1], pB[2][1]))) + if pC == ((1,), (2,)) + mul!(C, A′, B′, α, β) + elseif pC == ((2,), (1,)) + mul!(C, transpose(A′), transpose(B′), α, β) + end + return C +end + function tensorcontract!(C::StridedView, pC::Index2Tuple, A::StridedView, pA::Index2Tuple, conjA::Symbol, B::StridedView, pB::Index2Tuple, conjB::Symbol, @@ -177,7 +218,7 @@ function _unsafe_blas_contract!(C::StridedView{T}, ipC, return C end -function makeblascontractable(A, pA, conjA, TC) +@inline function makeblascontractable(A, pA, conjA, TC) flagA = isblascontractable(A, pA, conjA) && eltype(A) == TC if !flagA A_ = StridedView(TensorOperations.tensoralloc_add(TC, pA, A, conjA, true)) @@ -211,7 +252,7 @@ end _canfuse(::Dims{0}, ::Dims{0}) = true, 1, 1 _canfuse(dims::Dims{1}, strides::Dims{1}) = true, dims[1], strides[1] -function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N} +@inline function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N} if dims[1] == 0 return true, 0, 1 elseif dims[1] == 1