From 0466806d944a0db931a1a83f13c09bc28d33bbc6 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Thu, 26 Sep 2024 15:01:42 -0700
Subject: [PATCH] Add support for int64_t indices in TBE inference [2/N] (#3125)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/214

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3125

- Add support for int64_t indices in TBE inference [2/N]
- Convert `pruned_array_lookup_cuda` to use index_t

Reviewed By: jianyuh

Differential Revision: D62271409
---
 ...mbedding_forward_quantized_split_lookup.cu | 55 ++++++++++---------
 .../include/fbgemm_gpu/utils/tensor_utils.h   | 38 ++++++++++++-
 2 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 7d4eebcce..52f2a49dd 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -89,19 +89,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         dense_indices) {
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
   const int32_t t = b_t / B;
@@ -109,22 +110,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
   const int64_t index_remappings_start = index_remappings_offsets[t];
   const int64_t index_remappings_end = index_remappings_offsets[t + 1];
   const int64_t capacity = index_remappings_end - index_remappings_start;
 
   if (capacity > 0) {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
-      int32_t idx = indices[indices_start + l];
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
+      index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
           index_remappings[index_remappings_start + idx];
     }
   } else {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       dense_indices[indices_start + l] = indices[indices_start + l];
     }
   }
 }
@@ -178,6 +179,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, index_remappings, index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -204,23 +206,26 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
 
   constexpr size_t kForwardMaxThreads = 256;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
 
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
-      nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+    nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
+
   return dense_indices;
 }
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
index f64205b7e..60cca19ef 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -306,8 +306,6 @@ std::string tensor_scalar_type_is_one_of(
     const ScalarTypes&... ttypes) {
   auto has_match = false;
 
-  // Collect the GPU index of the first non-empty optional tensor and make sure
-  // that all tensors are on this same index.
   (
       [&](const auto& ttype) {
         if (ten.scalar_type() == ttype) {
@@ -339,3 +337,39 @@ std::string tensor_scalar_type_is_one_of(
     const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \
     TORCH_CHECK(has_match.empty(), has_match);                        \
   } while (false)
+
+template <typename... Tensors>
+std::string tensors_have_same_scalar_type(const Tensors&... tensors) {
+  std::optional<at::ScalarType> dtype;
+  bool have_same_type = true;
+
+  (
+      [&](const auto& tensor) {
+        if (!dtype) {
+          dtype = tensor.scalar_type();
+        } else if (*dtype != tensor.scalar_type()) {
+          have_same_type = false;
+        }
+      }(tensors),
+      ...);
+
+  if (have_same_type) {
+    return "";
+  }
+
+  std::string msg = "Tensors' scalar types (";
+  (
+      [&](const auto& tensor) {
+        msg.append(toString(tensor.scalar_type()));
+        msg.append(", ");
+      }(tensors),
+      ...);
+  msg.append(") are not one and the same!");
+  return msg;
+}
+
+#define TENSORS_HAVE_SAME_SCALAR_TYPE(...)                                    \
+  do {                                                                        \
+    const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__);   \
+    TORCH_CHECK(have_same_type.empty(), have_same_type);                      \
+  } while (false)