From 4ca34a2c1f298a578cd77abef9cbf2f75d8409b6 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Thu, 26 Sep 2024 14:50:26 -0700
Subject: [PATCH] Add support for int64_t indices and offsets in TBE inference
 [3/N] (#3124)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3124

X-link: https://github.com/facebookresearch/FBGEMM/pull/213

- Convert `pruned_hashmap_lookup_cuda` to use index_t

Reviewed By: jianyuh

Differential Revision: D62277673
---
 ...mbedding_forward_quantized_split_lookup.cu | 68 +++++++++++--------
 .../embedding_forward_template_helpers.cuh    | 12 ++++
 2 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 52f2a49ddc..86165bb39f 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -14,19 +14,20 @@ using Tensor = at::Tensor;
 
 namespace nbit {
 
+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
         hash_table,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_table_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         dense_indices) {
   // uint32_t capacity = hash_table.size(0);
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
@@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t table_start = hash_table_offsets[t];
   const int64_t table_end = hash_table_offsets[t + 1];
@@ -51,6 +52,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
     return;
   }
 
+  using hash_t =
+      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
 #ifdef USE_ROCM
@@ -58,13 +62,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 #else
   const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
 #endif
+
   for (int32_t l_start = 0; l_start + subwarp_id < L;
        l_start += kWarpSize / 4) {
-    const int32_t idx = indices[indices_start + l_start + subwarp_id];
-    uint32_t slot_start =
-        pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
+    const index_t idx = indices[indices_start + l_start + subwarp_id];
+    hash_t slot_start =
+        pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+
     while (true) {
-      const uint32_t slot = (slot_start + subwarp_tid) % capacity;
+      const hash_t slot = (slot_start + subwarp_tid) % capacity;
       const int2 val = *reinterpret_cast<const int2*>(
           &hash_table[table_start + static_cast<int64_t>(slot)][0]);
       const int32_t slot_sparse_idx = val.x;
@@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
         found = true;
         dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
       }
+
       if (__any_sync(subwarp_mask, found)) {
        break;
      } else if (__any_sync(subwarp_mask, empty)) {
@@ -133,6 +140,8 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 
 } // namespace nbit
 
+using namespace nbit;
+
 Tensor pruned_hashmap_lookup_cuda(
     Tensor indices,
     Tensor offsets,
@@ -140,6 +149,7 @@ Tensor pruned_hashmap_lookup_cuda(
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -150,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
   constexpr size_t kForwardMaxThreads = 256;
 
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
 
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
-      nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
-      MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
+    int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+        nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+  });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return dense_indices;
@@ -209,10 +221,10 @@ Tensor pruned_array_lookup_cuda(
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name =
-        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+        "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
 #endif
 
-    nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+    int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
        0,
@@ -224,8 +236,8 @@ Tensor pruned_array_lookup_cuda(
        B,
        T,
        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return dense_indices;
 }
diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
index 97353e03cd..2164afd3e4 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh
@@ -88,6 +88,7 @@ __device__ inline int32_t padded_D(
 
 __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   // MurmorHash3 32-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
   h ^= h >> 16;
   h *= 0x85ebca6b;
   h ^= h >> 13;
@@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) {
   return h;
 }
 
+__device__ inline uint64_t pruned_hash_function(uint64_t k) {
+  // MurmorHash3 64-bit mixing function.
+  // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+  k ^= k >> 33;
+  k *= (0xff51afd7ed558ccd);
+  k ^= k >> 33;
+  k *= (0xc4ceb9fe1a85ec53);
+  k ^= k >> 33;
+  return k;
+}
+
 // ---------------------- START cp.async helpers, copied from CUTLASS
 
 /// CUTLASS helper to get SMEM pointer
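
For reference, the hashing logic this patch introduces can be read in isolation with the following minimal host-side C++ sketch. The two `pruned_hash_function` finalizers mirror the ones in the diff above (same MurmurHash3 constants); the `initial_slot` wrapper is a hypothetical helper used only for illustration and does not exist in FBGEMM, but it follows the same `hash_t` width selection that the kernel performs for 32-bit versus 64-bit indices.

#include <cstdint>
#include <type_traits>

// MurmurHash3 32-bit finalizer (same constants as the existing device function).
inline uint32_t pruned_hash_function(uint32_t h) {
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  h ^= h >> 16;
  return h;
}

// MurmurHash3 64-bit finalizer (the overload added by this patch).
inline uint64_t pruned_hash_function(uint64_t k) {
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccdULL;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53ULL;
  k ^= k >> 33;
  return k;
}

// Hypothetical helper: picks the hash width from the index type, mirroring the
// `hash_t` alias in the kernel, and maps an index to its first probe slot.
template <typename index_t>
uint64_t initial_slot(index_t idx, uint64_t capacity) {
  using hash_t =
      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
  return pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
}

Overload resolution then selects the 32-bit or 64-bit mixer automatically, which is why the kernel only needs the `hash_t` cast rather than separate code paths for int32_t and int64_t indices.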