Skip to content

Commit

Permalink
important performance fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Aug 8, 2023
1 parent 269be86 commit 9a1c961
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
version = "4.0.2"
version = "4.0.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
71 changes: 56 additions & 15 deletions src/implementation/strided.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,24 @@ function tensoradd!(C::StridedView, pC::Index2Tuple,
if !istrivialpermutation(pC) && Base.mightalias(C, A)
throw(ArgumentError("output tensor must not be aliased with input tensor"))
end
add!(C, permutedims(flag2op(conjA)(A), linearize(pC)), α, β)
A′ = permutedims(flag2op(conjA)(A), linearize(pC))
if isone(α)
if iszero(β)
return Strided._mapreducedim!(identity, +, zero, size(C), (C, A′))
elseif isone(β)
Strided._mapreducedim!(identity, +, nothing, size(C), (C, A′))
else
Strided._mapreducedim!(identity, +, y -> β * y, size(C), (C, A′))
end
else
if iszero(β)
Strided._mapreducedim!(x -> α * x, +, zero, size(C), (C, A′))
elseif isone(β)
Strided._mapreducedim!(x -> α * x, +, nothing, size(C), (C, A′))
else
Strided._mapreducedim!(x -> α * x, +, y -> β * y, size(C), (C, A′))
end
end
return C
end

Expand All @@ -29,22 +46,22 @@ function tensortrace!(C::StridedView, pC::Index2Tuple,
tracesize = sizeA.(pA[1])
newstrides = (strideA.(linearize(pC))..., (strideA.(pA[1]) .+ strideA.(pA[2]))...)
newsize = (size(C)..., tracesize...)
A2 = flag2op(conjA)(StridedView(A.parent, newsize, newstrides, A.offset, A.op))
if α != 1
if β == 0
Strided._mapreducedim!(x -> α * x, +, zero, newsize, (C, A2))
elseif β == 1
Strided._mapreducedim!(x -> α * x, +, nothing, newsize, (C, A2))
A′ = flag2op(conjA)(StridedView(A.parent, newsize, newstrides, A.offset, A.op))
if isone(α)
if iszero(β)
return Strided._mapreducedim!(identity, +, zero, newsize, (C, A′))
elseif isone(β)
Strided._mapreducedim!(identity, +, nothing, newsize, (C, A′))
else
Strided._mapreducedim!(x -> α * x, +, y -> β * y, newsize, (C, A2))
Strided._mapreducedim!(identity, +, y -> β * y, newsize, (C, A′))
end
else
if β == 0
return Strided._mapreducedim!(identity, +, zero, newsize, (C, A2))
elseif β == 1
Strided._mapreducedim!(identity, +, nothing, newsize, (C, A2))
if iszero(β)
Strided._mapreducedim!(x -> α * x, +, zero, newsize, (C, A′))
elseif isone(β)
Strided._mapreducedim!(x -> α * x, +, nothing, newsize, (C, A′))
else
Strided._mapreducedim!(identity, +, y -> β * y, newsize, (C, A2))
Strided._mapreducedim!(x -> α * x, +, y -> β * y, newsize, (C, A′))
end
end
return C
Expand Down Expand Up @@ -77,6 +94,30 @@ function tensorcontract!(C::StridedView{T}, pC::Index2Tuple,
end
end

# Reduce overhead for the case where the contraction is just a matrix multiplication:
# `C`, `A` and `B` are all rank-2 views sharing one BLAS element type `T`, so the work
# is handed straight to `mul!` instead of going through the general contraction path.
#
# - `pA` / `pB`: when equal to `((1,), (2,))` the operand is used as given, otherwise its
#   two dimensions are permuted to `(p[1][1], p[2][1])` first.
# - `conjA` / `conjB`: converted to an operation with `flag2op` and applied to the operand.
# - `pC`: `((1,), (2,))` stores the product as is, `((2,), (1,))` stores its transpose.
# - `α`, `β`: forwarded to the five-argument `mul!` (scales the product / the old `C`).
# - `backend`: only used for dispatch in this method.
#
# Returns `C`. Throws an `ArgumentError` when `C` may alias `A` or `B`, or when `pC` is
# not one of the two permutations listed above.
function tensorcontract!(C::StridedView{T,2}, pC::Index2Tuple{1,1},
                         A::StridedView{T,2}, pA::Index2Tuple{1,1}, conjA::Symbol,
                         B::StridedView{T,2}, pB::Index2Tuple{1,1}, conjB::Symbol,
                         α, β,
                         backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat}
    argcheck_tensorcontract(C, pC, A, pA, B, pB)
    dimcheck_tensorcontract(C, pC, A, pA, B, pB)

    (Base.mightalias(C, A) || Base.mightalias(C, B)) &&
        throw(ArgumentError("output tensor must not be aliased with input tensor"))

    opA = flag2op(conjA)
    opB = flag2op(conjB)
    A′ = pA == ((1,), (2,)) ? opA(A) : opA(permutedims(A, (pA[1][1], pA[2][1])))
    B′ = pB == ((1,), (2,)) ? opB(B) : opB(permutedims(B, (pB[1][1], pB[2][1])))
    if pC == ((1,), (2,))
        mul!(C, A′, B′, α, β)
    elseif pC == ((2,), (1,))
        # transpose(A′ * B′) == transpose(B′) * transpose(A′): the factors must be
        # swapped, otherwise the dimensions only match by accident for square operands
        # and the result is wrong.
        mul!(C, transpose(B′), transpose(A′), α, β)
    else
        # Do not silently return an untouched `C` for an unsupported permutation.
        throw(ArgumentError("invalid output permutation pC = $pC"))
    end
    return C
end

function tensorcontract!(C::StridedView, pC::Index2Tuple,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
B::StridedView, pB::Index2Tuple, conjB::Symbol,
Expand Down Expand Up @@ -177,7 +218,7 @@ function _unsafe_blas_contract!(C::StridedView{T}, ipC,
return C
end

function makeblascontractable(A, pA, conjA, TC)
@inline function makeblascontractable(A, pA, conjA, TC)
flagA = isblascontractable(A, pA, conjA) && eltype(A) == TC
if !flagA
A_ = StridedView(TensorOperations.tensoralloc_add(TC, pA, A, conjA, true))
Expand Down Expand Up @@ -211,7 +252,7 @@ end

# Base cases of `_canfuse`. Each returns a 3-tuple whose first entry is `true`;
# presumably the tuple is (fusable flag, fused size, fused stride) — confirm against the
# general method.

# No dimensions at all: size 1 and stride 1.
function _canfuse(::Dims{0}, ::Dims{0})
    return (true, 1, 1)
end

# A single dimension: its own size and stride are passed through unchanged.
function _canfuse(dims::Dims{1}, strides::Dims{1})
    return (true, first(dims), first(strides))
end
function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
@inline function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
if dims[1] == 0
return true, 0, 1
elseif dims[1] == 1
Expand Down

0 comments on commit 9a1c961

Please sign in to comment.