Skip to content

Commit

Permalink
Start allocator implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jul 8, 2024
1 parent 625e111 commit 466c115
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ authors = [
version = "5.0.0"

[deps]
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Expand All @@ -30,6 +32,7 @@ TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Aqua = "0.6, 0.7, 0.8"
Bumper = "0.6"
CUDA = "5.4.0"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Expand All @@ -38,6 +41,7 @@ LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
PackageExtensionCompat = "1"
PtrArrays = "1.2"
Random = "1"
Strided = "2.0.4"
StridedViews = "0.3"
Expand Down
1 change: 1 addition & 0 deletions src/TensorOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra
using LinearAlgebra: mul!, BlasFloat
using Strided
using StridedViews: isstrided
using PtrArrays, Bumper
using LRUCache

using Base.Meta: isexpr
Expand Down
81 changes: 80 additions & 1 deletion src/implementation/allocator.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ------------------------------------------------------------------------------------------
# DefaultAllocator
# Allocator backends
# ------------------------------------------------------------------------------------------
"""
DefaultAllocator()
Expand Down Expand Up @@ -30,6 +30,17 @@ parameters `Min`, `Mout`, `Mtemp` can be any of the CUDA.jl memory types, i.e.
"""
struct CUDAAllocator{Mout,Min,Mtemp} end

"""
ManualAllocator()
Allocator that bypasses Julia's memory management for temporary tensors by leveraging `Libc.malloc`
and `Libc.free` directly. This can be useful for reducing the pressure on the garbage collector.
This backend will allocate using `DefaultAllocator` for output tensors that escape the `@tensor`
block, which will thus still be managed using Julia's GC. The other tensors will be backed by
`PtrArray` instances, from `PtrArrays.jl`, thus requiring compatibility with that interface.
"""
struct ManualAllocator end

# ------------------------------------------------------------------------------------------
# Generic implementation
# ------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -148,3 +159,71 @@ end
# Fallback method: with a GC-managed allocator (e.g. `DefaultAllocator`) there is
# nothing to release explicitly, so this is a no-op.
tensorfree!(C, allocator=DefaultAllocator()) = nothing

# ------------------------------------------------------------------------------------------
# ManualAllocator implementation
# ------------------------------------------------------------------------------------------
# Temporary tensors bypass the GC via `PtrArrays.malloc`; non-temporary results
# (which escape the `@tensor` block) fall back to the default, GC-managed path.
Base.@constprop :aggressive function tensoralloc_add(TC, A, pA::Index2Tuple, conjA::Bool,
                                                     istemp::Bool,
                                                     ::ManualAllocator)
    sz = tensoradd_structure(A, pA, conjA)
    istemp && return malloc(TC, sz...)
    T = tensoradd_type(TC, A, pA, conjA)
    return tensoralloc(T, sz, istemp)::T
end

# Contraction analogue of `tensoralloc_add` for the `ManualAllocator`: temporaries
# are malloc-backed `PtrArray`s, escaping outputs use the default allocation path.
Base.@constprop :aggressive function tensoralloc_contract(TC,
                                                          A, pA::Index2Tuple, conjA::Bool,
                                                          B, pB::Index2Tuple, conjB::Bool,
                                                          pAB::Index2Tuple, istemp::Bool,
                                                          ::ManualAllocator)
    sz = tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
    istemp && return malloc(TC, sz...)
    T = tensorcontract_type(TC, A, pA, conjA, B, pB, conjB, pAB)
    return tensoralloc(T, sz, istemp)::T
end

# Manually-allocated temporaries are backed by `Libc.malloc`, so they must be
# released explicitly through `PtrArrays.free`.
# NOTE(review): assumes `C` is a `PtrArray` (i.e. was allocated with `istemp=true`);
# confirm callers never pass a GC-managed output tensor here.
tensorfree!(C, ::ManualAllocator) = (free(C); nothing)

# ------------------------------------------------------------------------------------------
# BumperAllocator implementation
# ------------------------------------------------------------------------------------------

# Bumper-backed allocation: temporaries live on the provided slab/alloc buffer and
# are reclaimed in bulk by a checkpoint restore (see `@butensor`); escaping outputs
# still go through the default, GC-managed path.
Base.@constprop :aggressive function tensoralloc_add(TC, A::AbstractArray, pA::Index2Tuple,
                                                     conjA::Bool, istemp::Bool,
                                                     buf::Union{SlabBuffer,AllocBuffer})
    sz = tensoradd_structure(A, pA, conjA)
    istemp && return Bumper.alloc!(buf, TC, sz...)
    T = tensoradd_type(TC, A, pA, conjA)
    return tensoralloc(T, sz, istemp)::T
end

# Contraction analogue of the Bumper-backed `tensoralloc_add`: temporaries are
# bump-allocated on `buf`, escaping outputs use the default allocation path.
Base.@constprop :aggressive function tensoralloc_contract(TC,
                                                          A::AbstractArray, pA::Index2Tuple,
                                                          conjA::Bool,
                                                          B::AbstractArray, pB::Index2Tuple,
                                                          conjB::Bool,
                                                          pAB::Index2Tuple, istemp::Bool,
                                                          buf::Union{SlabBuffer,AllocBuffer})
    sz = tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
    istemp && return Bumper.alloc!(buf, TC, sz...)
    T = tensorcontract_type(TC, A, pA, conjA, B, pB, conjB, pAB)
    return tensoralloc(T, sz, istemp)::T
end
23 changes: 23 additions & 0 deletions src/indexnotation/tensormacros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,26 @@ macro cutensor(ex::Expr)
Expr(:call, GlobalRef(TensorOperations, :CUDAAllocator))))
return esc(parser(ex))
end

"""
@butensor tensor_expr
Use Bumper.jl to handle allocation of temporary tensors. This macro will use the default
buffer and automatically reset it after the tensor expression has been evaluated. This macro
is equivalent to `@no_escape @tensor tensor_expr` with all temporary allocations handled by
Bumper.jl.
"""
macro butensor(ex...)
buf_sym = gensym("buffer")
cp_sym = gensym("checkpoint")
res_sym = gensym("result")
return esc(quote
$buf_sym = $(Expr(:call, GlobalRef(Bumper, :default_buffer)))
$cp_sym = $(Expr(:call, GlobalRef(Bumper, :checkpoint_save), buf_sym))
$res_sym = $(Expr(:macrocall,
GlobalRef(TensorOperations, Symbol("@tensor")),
__source__, :(allocator = $buf_sym), ex...))
$(Expr(:call, GlobalRef(Bumper, :checkpoint_restore!), cp_sym))
$res_sym
end)
end

0 comments on commit 466c115

Please sign in to comment.