Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative Allocators #182

Merged
merged 25 commits on Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ authors = [
version = "5.0.0"

[deps]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Expand All @@ -30,6 +32,7 @@ TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Aqua = "0.6, 0.7, 0.8"
Bumper = "0.6"
CUDA = "5.4.0"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Expand All @@ -38,6 +41,7 @@ LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
PackageExtensionCompat = "1"
PtrArrays = "1.2"
Random = "1"
Strided = "2.0.4"
StridedViews = "0.3"
Expand Down
1 change: 1 addition & 0 deletions src/TensorOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra
using LinearAlgebra: mul!, BlasFloat
using Strided
using StridedViews: isstrided
using PtrArrays, Bumper
using LRUCache

using Base.Meta: isexpr
Expand Down
62 changes: 51 additions & 11 deletions src/implementation/allocator.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ------------------------------------------------------------------------------------------
# DefaultAllocator
# Allocator backends
# ------------------------------------------------------------------------------------------
"""
DefaultAllocator()
Expand Down Expand Up @@ -30,6 +30,17 @@ parameters `Min`, `Mout`, `Mtemp` can be any of the CUDA.jl memory types, i.e.
"""
struct CUDAAllocator{Mout,Min,Mtemp} end

"""
    ManualAllocator()

Allocator that bypasses Julia's memory management for temporary tensors by leveraging `Libc.malloc`
and `Libc.free` directly. This can be useful for reducing the pressure on the garbage collector.
This backend will allocate using `DefaultAllocator` for output tensors that escape the `@tensor`
block, which will thus still be managed using Julia's GC. The other tensors will be backed by
`PtrArray` instances, from `PtrArrays.jl`, thus requiring compatibility with that interface.
Temporaries obtained through this allocator must be released explicitly via `tensorfree!`,
since nothing else will reclaim the `malloc`-ed memory.
"""
struct ManualAllocator end

# ------------------------------------------------------------------------------------------
# Generic implementation
# ------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -61,11 +72,11 @@ used to implement different allocation strategies.

See also [`tensoralloc`](@ref) and [`tensorfree!`](@ref).
"""
# Allocate the output of a tensor-addition/permutation operation.
# `istemp` is a `Val` so the temporary-vs-output decision is made at compile time;
# `allocator` selects the allocation backend.
function tensoralloc_add(TC, A, pA::Index2Tuple, conjA::Bool, istemp::Val=Val(false),
                         allocator=DefaultAllocator())
    # Determine the concrete array type and the size of the permuted tensor.
    ttype = tensoradd_type(TC, A, pA, conjA)
    structure = tensoradd_structure(A, pA, conjA)
    # No `::ttype` return assertion: allocator backends (e.g. Bumper, ManualAllocator)
    # may legitimately return a different but compatible array type for temporaries.
    return tensoralloc(ttype, structure, istemp, allocator)
end

"""
Expand All @@ -84,11 +95,11 @@ See also [`tensoralloc`](@ref) and [`tensorfree!`](@ref).
# Allocate the output of a pairwise tensor contraction.
# `istemp` is a `Val` so the temporary-vs-output decision is made at compile time;
# `allocator` selects the allocation backend.
function tensoralloc_contract(TC,
                              A, pA::Index2Tuple, conjA::Bool,
                              B, pB::Index2Tuple, conjB::Bool,
                              pAB::Index2Tuple, istemp::Val=Val(false),
                              allocator=DefaultAllocator())
    # Determine the concrete array type and the size of the contracted tensor.
    ttype = tensorcontract_type(TC, A, pA, conjA, B, pB, conjB, pAB)
    structure = tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
    # No `::ttype` return assertion: allocator backends may return a different
    # but compatible array type for temporaries.
    return tensoralloc(ttype, structure, istemp, allocator)
end

# ------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -118,8 +129,9 @@ function tensoradd_structure(A::AbstractArray, pA::Index2Tuple, conjA::Bool)
return size.(Ref(A), linearize(pA))
end

function tensorcontract_type(TC, A::AbstractArray, pA, conjA,
B::AbstractArray, pB, conjB, pAB)
function tensorcontract_type(TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple)
T1 = tensoradd_type(TC, A, pAB, conjA)
T2 = tensoradd_type(TC, B, pAB, conjB)
if T1 == T2
Expand All @@ -129,14 +141,15 @@ function tensorcontract_type(TC, A::AbstractArray, pA, conjA,
end
end

# Compute the size tuple of the contraction output: open indices of `A` come
# first (pA[1]), followed by the open indices of `B` (pB[2]), reordered by `pAB`.
function tensorcontract_structure(A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                                  B::AbstractArray, pB::Index2Tuple, conjB::Bool,
                                  pAB::Index2Tuple)
    return let lA = length(pA[1])
        # indices `n <= lA` refer to open dimensions of A, the rest to open dimensions of B
        map(n -> n <= lA ? size(A, pA[1][n]) : size(B, pB[2][n - lA]), linearize(pAB))
    end
end

function tensoralloc(ttype, structure, istemp=false, allocator=DefaultAllocator())
function tensoralloc(ttype, structure, ::Val=Val(false), allocator=DefaultAllocator())
C = ttype(undef, structure)
# fix an issue with undefined references for strided arrays
if !isbitstype(scalartype(ttype))
Expand All @@ -145,6 +158,33 @@ function tensoralloc(ttype, structure, istemp=false, allocator=DefaultAllocator(
return C
end

"""
    tensorfree!(C, allocator=DefaultAllocator())

Release the allocated memory of `C`, which was obtained from [`tensoralloc`](@ref).
For the default GC-managed allocator this is a no-op; allocator backends that manage
memory manually (e.g. `ManualAllocator`) overload this method.
"""
tensorfree!(C, allocator=DefaultAllocator()) = nothing

# ------------------------------------------------------------------------------------------
# ManualAllocator implementation
# ------------------------------------------------------------------------------------------
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
# ManualAllocator: temporaries bypass the GC via `PtrArrays.malloc`; outputs that
# escape the `@tensor` block fall back to the default GC-managed allocation.
function tensoralloc(::Type{A}, structure, ::Val{istemp},
                     ::ManualAllocator) where {A<:AbstractArray,istemp}
    if istemp
        # `malloc` returns a `PtrArray` backed by Libc.malloc; must be released
        # with `tensorfree!` (see below).
        return malloc(eltype(A), structure...)
    else
        # Re-wrap `istemp` in a `Val`: the generic fallback method only accepts a
        # `::Val` third argument, so passing the raw Bool would be a MethodError.
        return tensoralloc(A, structure, Val(istemp))
    end
end

# `PtrArray`s handed out by the `ManualAllocator` wrap raw `Libc.malloc` memory,
# so they must be released explicitly instead of being reclaimed by the GC.
tensorfree!(C::PtrArray, ::ManualAllocator) = (free(C); nothing)

# ------------------------------------------------------------------------------------------
# BumperAllocator implementation
# ------------------------------------------------------------------------------------------
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
# Bumper backend: temporaries are bump-allocated on the given buffer; outputs that
# escape the `@tensor` block fall back to the default GC-managed allocation.
function tensoralloc(::Type{A}, structure, ::Val{istemp},
                     buf::Union{SlabBuffer,AllocBuffer}) where {A<:AbstractArray,istemp}
    if istemp
        # bump-allocate on `buf`; memory is reclaimed when the buffer is reset
        # (see `@butensor`), not by `tensorfree!`.
        return Bumper.alloc!(buf, eltype(A), structure...)
    else
        # Re-wrap `istemp` in a `Val`: the generic fallback method only accepts a
        # `::Val` third argument, so passing the raw Bool would be a MethodError.
        return tensoralloc(A, structure, Val(istemp))
    end
end
85 changes: 46 additions & 39 deletions src/implementation/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ function tensoradd!(C::AbstractArray,

# can we assume that C is mutable?
# is there more functionality in base that we can use?
Atemp = tensoralloc_add(eltype(A), A, pA, conjA, true, allocator)
= permutedims!(Atemp, A, linearize(pA))
Atemp = tensoralloc_add(eltype(A), A, pA, conjA, Val(true), allocator)
Atemp = permutedims!(Atemp, A, linearize(pA))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, there is no guarantee that the lhs is the same object as the one we allocated. In that case, we created a memory leak here. In practice this probably does not really happen, but this is definitely why I used different variable names in my initial change.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, I do think permutedims! does guarantee to store the result in the first argument, i.e. this should work without the Atemp = part.

  permutedims!(dest, src, perm)

  Permute the dimensions of array src and store the result in the array dest. perm is a vector specifying a permutation of length ndims(src). The preallocated array dest should have size(dest) == size(src)[perm] and is completely
  overwritten. No in-place permutation is supported and unexpected results will happen if src and dest have overlapping memory regions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know the docstring guarantees this, but for example, in TensorOperations with AD, we abuse the fact that the macro automatically adds C = tensoradd!(C,...) to hack into the system and make the lhs a copy. In principle anyone could do something like this for a custom type, and then the memory management chain would be broken (doom thinking here)

if conjA
if iszero(β)
C .= α .* conj.()
C .= α .* conj.(Atemp)
else
C .= β .* C .+ α .* conj.()
C .= β .* C .+ α .* conj.(Atemp)
end
else
if iszero(β)
C .= α .*
C .= α .* Atemp
else
C .= β .* C .+ α .*
C .= β .* C .+ α .* Atemp
end
end
tensorfree!(Atemp, allocator)
return C
end

Expand Down Expand Up @@ -100,28 +101,28 @@ function tensortrace!(C::AbstractArray,
so = TupleTools.getindices(szA, linearize(p))
st = prod(TupleTools.getindices(szA, q[1]))
perm = (linearize(p)..., linearize(q)...)
Atemp = tensoralloc_add(eltype(A), A, (perm, ()), conjA, true, allocator)
à = reshape(permutedims!(Atemp, A, perm), (prod(so), st, st))

Atemp′ = tensoralloc_add(eltype(A), A, (perm, ()), conjA, Val(true), allocator)
Ãtemp = reshape(permutedims!(Atemp′, A, perm), (prod(so), st, st))
if conjA
if iszero(β)
C .= α .* conj.(reshape(view(, :, 1, 1), so))
C .= α .* conj.(reshape(view(Ãtemp, :, 1, 1), so))
else
C .= β .* C .+ α .* conj.(reshape(view(, :, 1, 1), so))
C .= β .* C .+ α .* conj.(reshape(view(Ãtemp, :, 1, 1), so))
end
for i in 2:st
C .+= α .* conj.(reshape(view(, :, i, i), so))
C .+= α .* conj.(reshape(view(Ãtemp, :, i, i), so))
end
else
if iszero(β)
C .= α .* reshape(view(, :, 1, 1), so)
C .= α .* reshape(view(Ãtemp, :, 1, 1), so)
else
C .= β .* C .+ α .* reshape(view(, :, 1, 1), so)
C .= β .* C .+ α .* reshape(view(Ãtemp, :, 1, 1), so)
end
for i in 2:st
C .+= α .* reshape(view(, :, i, i), so)
C .+= α .* reshape(view(Ãtemp, :, i, i), so)
end
end
tensorfree!(Atemp′, allocator)
return C
end

Expand Down Expand Up @@ -182,44 +183,50 @@ function tensorcontract!(C::AbstractArray,
soB1 = prod(soB)
sc1 = prod(sc)

AB = tensoralloc_contract(eltype(C), A, pA, conjA, B, pB, conjB,
trivialpermutation(pAB), true, allocator)
ÃB̃ = reshape(AB, (soA1, soB1))
AB = tensoralloc_contract(eltype(C), A, pA, conjA, B, pB, conjB,
trivialpermutation(pAB), Val(true), allocator)
ÃB̃ = reshape(AB, (soA1, soB1))
if conjA && conjB
Atemp = tensoralloc_add(eltype(C), A, reverse(pA), conjA, true, allocator)
Btemp = tensoralloc_add(eltype(C), B, reverse(pB), conjB, true, allocator)
à = reshape(permutedims!(Atemp, A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims!(Btemp, B, linearize(reverse(pB))), (soB1, sc1))
mul!(ÃB̃, adjoint(Ã), adjoint(B̃))
Atemp′ = tensoralloc_add(eltype(C), A, reverse(pA), conjA, Val(true), allocator)
Btemp′ = tensoralloc_add(eltype(C), B, reverse(pB), conjB, Val(true), allocator)
Ãtemp = adjoint(reshape(permutedims!(Atemp′, A, linearize(reverse(pA))),
(sc1, soA1)))
B̃temp = adjoint(reshape(permutedims!(Btemp′, B, linearize(reverse(pB))),
(soB1, sc1)))
elseif conjA
Atemp = tensoralloc_add(eltype(C), A, reverse(pA), conjA, true, allocator)
Btemp = tensoralloc_add(eltype(C), B, pB, conjB, true, allocator)
= reshape(permutedims!(Atemp, A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims!(Btemp, B, linearize(pB)), (sc1, soB1))
mul!(ÃB̃, adjoint(Ã), B̃)
Atemp = tensoralloc_add(eltype(C), A, reverse(pA), conjA, Val(true), allocator)
Btemp = tensoralloc_add(eltype(C), B, pB, conjB, Val(true), allocator)
Ãtemp = adjoint(reshape(permutedims!(Atemp, A, linearize(reverse(pA))),
(sc1, soA1)))
B̃temp = reshape(permutedims!(Btemp′, B, linearize(pB)), (sc1, soB1))
elseif conjB
Atemp = tensoralloc_add(eltype(C), A, pA, conjA, true, allocator)
Btemp = tensoralloc_add(eltype(C), B, reverse(pB), conjB, true, allocator)
à = reshape(permutedims!(Atemp, A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims!(Btemp, B, linearize(reverse(pB))), (soB1, sc1))
Atemp = tensoralloc_add(eltype(C), A, pA, conjA, Val(true), allocator)
Btemp = tensoralloc_add(eltype(C), B, reverse(pB), conjB, Val(true), allocator)
à = reshape(permutedims!(Atemp, A, linearize(pA)), (soA1, sc1))
B̃ = adjoint(reshape(permutedims!(Btemp, B, linearize(reverse(pB))), (soB1, sc1)))
mul!(ÃB̃, Ã, adjoint(B̃))
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
else
Atemp = tensoralloc_add(eltype(C), A, pA, conjA, true, allocator)
Btemp = tensoralloc_add(eltype(C), B, pB, conjB, true, allocator)
à = reshape(permutedims!(Atemp, A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims!(Btemp, B, linearize(pB)), (sc1, soB1))
mul!(ÃB̃, Ã, B̃)
Atemp′ = tensoralloc_add(eltype(C), A, pA, conjA, Val(true), allocator)
Btemp′ = tensoralloc_add(eltype(C), B, pB, conjB, Val(true), allocator)
à = reshape(permutedims!(Atemp′, A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims!(Btemp′, B, linearize(pB)), (sc1, soB1))
end
if istrivialpermutation(linearize(pAB))
mul!(ÃB̃, adjoint(Ãtemp), adjoint(B̃temp))
tensorfree!(Atemp′, allocator)
tensorfree!(Btemp′, allocator)
ABtemp = !istrivialpermutation(linearize(pAB))
if ABtemp
pAB = AB
else
pABtemp = tensoralloc_add(eltype(C), AB, pAB, false, true, allocator)
pABtemp = tensoralloc_add(eltype(C), AB, pAB, false, Val(true), allocator)
pAB = permutedims!(pABtemp, AB, linearize(pAB))
end
if iszero(β)
C .= α .* pAB
else
C .= β .* C .+ α .* pAB
end
ABtemp || tensorfree!(pAB, allocator)
tensorfree!(AB′, allocator)
return C
end
8 changes: 4 additions & 4 deletions src/implementation/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
# Allocate a fresh (non-temporary) output through `allocator` and copy the
# (conjugated, scaled) permutation of `A` into it.
function tensorcopy(A, pA::Index2Tuple, conjA::Bool=false, α::Number=One(),
                    backend=DefaultBackend(), allocator=DefaultAllocator())
    TC = promote_add(scalartype(A), scalartype(α))
    # `Val(false)`: the result escapes to the caller, so it must be GC-manageable.
    C = tensoralloc_add(TC, A, pA, conjA, Val(false), allocator)
    return tensorcopy!(C, A, pA, conjA, α, backend, allocator)
end

Expand Down Expand Up @@ -105,7 +105,7 @@ function tensoradd(A, pA::Index2Tuple, conjA::Bool,
α::Number=One(), β::Number=One(),
backend=DefaultBackend(), allocator=DefaultAllocator())
TC = promote_add(scalartype(A), scalartype(B), scalartype(α), scalartype(β))
C = tensoralloc_add(TC, A, pA, conjA, false, allocator)
C = tensoralloc_add(TC, A, pA, conjA, Val(false), allocator)
C = tensorcopy!(C, A, pA, conjA, α, backend, allocator)
return tensoradd!(C, B, pB, conjB, β, One(), backend, allocator)
end
Expand Down Expand Up @@ -151,7 +151,7 @@ end
# Allocate a fresh (non-temporary) output through `allocator` and compute the
# partial trace of `A` over the index pairs in `q`.
function tensortrace(A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number=One(),
                     backend=DefaultBackend(), allocator=DefaultAllocator())
    TC = promote_contract(scalartype(A), scalartype(α))
    # `Val(false)`: the result escapes to the caller, so it must be GC-manageable.
    C = tensoralloc_add(TC, A, p, conjA, Val(false), allocator)
    return tensortrace!(C, A, p, q, conjA, α, Zero(), backend, allocator)
end

Expand Down Expand Up @@ -203,7 +203,7 @@ function tensorcontract(A, pA::Index2Tuple, conjA::Bool,
pAB::Index2Tuple, α::Number=One(),
backend=DefaultBackend(), allocator=DefaultAllocator())
TC = promote_contract(scalartype(A), scalartype(B), scalartype(α))
C = tensoralloc_contract(TC, A, pA, conjA, B, pB, conjB, pAB, false, allocator)
C = tensoralloc_contract(TC, A, pA, conjA, B, pB, conjB, pAB, Val(false), allocator)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend,
allocator)
end
Expand Down
4 changes: 2 additions & 2 deletions src/implementation/strided.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ function blas_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, allocator)
C_ = C
_unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β)
else
C_ = StridedView(TensorOperations.tensoralloc_add(TC, C, ipAB, false, true,
C_ = StridedView(TensorOperations.tensoralloc_add(TC, C, ipAB, false, Val(true),
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
allocator))
_unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
one(TC), zero(TC))
Expand Down Expand Up @@ -257,7 +257,7 @@ end
@inline function makeblascontractable(A, pA, TC, allocator)
flagA = isblascontractable(A, pA) && eltype(A) == TC
if !flagA
A_ = StridedView(TensorOperations.tensoralloc_add(TC, A, pA, false, true,
A_ = StridedView(TensorOperations.tensoralloc_add(TC, A, pA, false, Val(true),
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
allocator))
A = tensoradd!(A_, A, pA, false, One(), Zero())
pA = trivialpermutation(pA)
Expand Down
4 changes: 2 additions & 2 deletions src/indexnotation/instantiators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function instantiate_generaltensor(dst, β, ex::Expr, α, leftind::Vector{Any},
end
if alloc ∈ (NewTensor, TemporaryTensor)
TC = gensym("T_" * string(dst))
istemporary = (alloc === TemporaryTensor)
istemporary = Val(alloc === TemporaryTensor)
if scaltype === nothing
TCval = α === One() ? instantiate_scalartype(src) :
instantiate_scalartype(Expr(:call, :*, α, src))
Expand Down Expand Up @@ -260,7 +260,7 @@ function instantiate_contraction(dst, β, ex::Expr, α, leftind::Vector{Any},
else
TCval = scaltype
end
istemporary = alloc === TemporaryTensor
istemporary = Val(alloc === TemporaryTensor)
initC = Expr(:block, Expr(:(=), TCsym, TCval),
:($dst = tensoralloc_contract($TCsym, $A, $pA, $conjA, $B, $pB,
$conjB, $pAB, $istemporary)))
Expand Down
23 changes: 23 additions & 0 deletions src/indexnotation/tensormacros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,26 @@ macro cutensor(ex::Expr)
Expr(:call, GlobalRef(TensorOperations, :CUDAAllocator))))
return esc(parser(ex))
end

"""
    @butensor tensor_expr

Use Bumper.jl to handle allocation of temporary tensors. This macro will use the default
buffer and automatically reset it after the tensor expression has been evaluated. This macro
is equivalent to `@no_escape @tensor tensor_expr` with all temporary allocations handled by
Bumper.jl.
"""
macro butensor(ex...)
    # gensyms keep the injected locals from clashing with user variables
    buf_sym = gensym("buffer")
    cp_sym = gensym("checkpoint")
    res_sym = gensym("result")
    return esc(quote
        # grab Bumper's default buffer for this evaluation
        $buf_sym = $(Expr(:call, GlobalRef(Bumper, :default_buffer)))
        # record the buffer state so all temporaries can be rewound afterwards
        $cp_sym = $(Expr(:call, GlobalRef(Bumper, :checkpoint_save), buf_sym))
        # expand `@tensor` with `allocator = <buffer>` prepended to the user
        # arguments; `__source__` preserves line info for error reporting
        $res_sym = $(Expr(:macrocall,
                          GlobalRef(TensorOperations, Symbol("@tensor")),
                          __source__, :(allocator = $buf_sym), ex...))
        # release everything bump-allocated since the checkpoint
        # NOTE(review): if the tensor expression throws, `checkpoint_restore!` is
        # skipped and the temporaries stay live until the buffer is reset — confirm
        # whether a try/finally is wanted here.
        $(Expr(:call, GlobalRef(Bumper, :checkpoint_restore!), cp_sym))
        $res_sym
    end)
end
Loading