Add support for int64_t indices and offsets in TBE inference [3/N] (pytorch#3124)

Summary:
Pull Request resolved: pytorch#3124

X-link: facebookresearch/FBGEMM#213

- Convert `pruned_hashmap_lookup_cuda` to use `index_t`

Reviewed By: jianyuh

Differential Revision: D62277673
q10 authored and facebook-github-bot committed Sep 26, 2024
1 parent fb08393 commit 4ca34a2
Showing 2 changed files with 52 additions and 28 deletions.
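For orientation before the diff: the change follows PyTorch's `AT_DISPATCH_INDEX_TYPES` pattern, which expands a lambda once per supported index dtype (int32 and int64) and binds the alias `index_t` inside it, so a single templated kernel covers both. The sketch below illustrates that pattern only; `copy_indices` and its kernel are hypothetical stand-ins, not code from this commit.

```cpp
// Minimal sketch of the dispatch pattern applied in this commit (hypothetical
// names; the real kernel performs a pruned-hashmap probe, not a copy).
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Kernel templated on the index type, mirroring the new
// `template <typename index_t>` parameter added to the lookup kernel below.
template <typename index_t>
__global__ void copy_indices_kernel(const index_t* in, index_t* out, int64_t n) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    out[i] = in[i];  // stand-in for the hash-table lookup
  }
}

at::Tensor copy_indices(const at::Tensor& indices) {
  auto out = at::empty_like(indices);
  // Expands the lambda for int32_t and int64_t, binding `index_t` to the
  // tensor's element type in each expansion.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "copy_indices", [&] {
    const int64_t n = indices.numel();
    copy_indices_kernel<index_t>
        <<<(n + 255) / 256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            indices.data_ptr<index_t>(), out.data_ptr<index_t>(), n);
  });
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}
```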
@@ -14,19 +14,20 @@ using Tensor = at::Tensor;

 namespace nbit {

+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
         hash_table,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_table_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         dense_indices) {
   // uint32_t capacity = hash_table.size(0);
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
@@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;

   const int64_t table_start = hash_table_offsets[t];
   const int64_t table_end = hash_table_offsets[t + 1];
@@ -51,20 +52,25 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
     return;
   }

+  using hash_t =
+      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
 #ifdef USE_ROCM
   const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
 #else
   const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
 #endif

   for (int32_t l_start = 0; l_start + subwarp_id < L;
        l_start += kWarpSize / 4) {
-    const int32_t idx = indices[indices_start + l_start + subwarp_id];
-    uint32_t slot_start =
-        pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
+    const index_t idx = indices[indices_start + l_start + subwarp_id];
+    hash_t slot_start =
+        pruned_hash_function(static_cast<hash_t>(idx)) % capacity;

     while (true) {
-      const uint32_t slot = (slot_start + subwarp_tid) % capacity;
+      const hash_t slot = (slot_start + subwarp_tid) % capacity;
       const int2 val = *reinterpret_cast<const int2*>(
           &hash_table[table_start + static_cast<int64_t>(slot)][0]);
       const int32_t slot_sparse_idx = val.x;
@@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
         found = true;
         dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
       }
+
       if (__any_sync(subwarp_mask, found)) {
         break;
       } else if (__any_sync(subwarp_mask, empty)) {
@@ -133,13 +140,16 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru

 } // namespace nbit

+using namespace nbit;
+
 Tensor pruned_hashmap_lookup_cuda(
     Tensor indices,
     Tensor offsets,
     Tensor hash_table,
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);

   CUDA_DEVICE_GUARD(indices);

@@ -150,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
   constexpr size_t kForwardMaxThreads = 256;

+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif

-  nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
-      nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
-      MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+        nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
@@ -209,10 +221,10 @@ Tensor pruned_array_lookup_cuda(
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name =
-        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+        "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
 #endif
-    nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+    int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
         nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
         dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
         0,
@@ -224,8 +236,8 @@ Tensor pruned_array_lookup_cuda(
         B,
         T,
         MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return dense_indices;
 }
@@ -88,6 +88,7 @@ __device__ inline int32_t padded_D(

 __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   // MurmurHash3 32-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
   h ^= h >> 16;
   h *= 0x85ebca6b;
   h ^= h >> 13;
@@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }

+__device__ inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmurHash3 64-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 // ---------------------- START cp.async helpers, copied from CUTLASS

 /// CUTLASS helper to get SMEM pointer
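A note on the hash-width selection visible in the kernel diff: `std::conditional_t` maps `int64_t` indices to `uint64_t`, so overload resolution picks the newly added 64-bit MurmurHash3 finalizer, while `int32_t` indices keep the original 32-bit path. The host-side sketch below is not part of the commit; `probe_start` is a hypothetical helper, and the 32-bit finalizer body is the standard MurmurHash3 fmix32, which is partly collapsed in the diff above.

```cpp
// Host-side sketch of the hash-width selection used in the kernel above.
#include <cstdint>
#include <type_traits>

inline uint32_t pruned_hash_function(uint32_t h) {
  // Standard MurmurHash3 32-bit finalizer (fmix32).
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  h ^= h >> 16;
  return h;
}

inline uint64_t pruned_hash_function(uint64_t k) {
  // 64-bit finalizer (fmix64), mirroring the overload added by this commit.
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccdULL;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53ULL;
  k ^= k >> 33;
  return k;
}

// For int64_t indices, hash_t is uint64_t and the 64-bit overload is chosen;
// for int32_t indices, the original 32-bit probing path is unchanged.
template <typename index_t>
uint64_t probe_start(index_t idx, uint64_t capacity) {
  using hash_t =
      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
  return pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
}
```

The effect is that 32-bit tables hash exactly as before, while 64-bit indices get a full-width mix before the modulo by table capacity.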