From ebbe4763e17031a64a59ebdc8d83243e0ecd29b3 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 26 Sep 2024 14:42:50 -0700 Subject: [PATCH 1/2] Add support for int64_t indices in TBE inference [1/N] Summary: - Add support for int64_t indices in TBE inference [1/N] Differential Revision: D61813383 --- ...ward_quantized_split_nbit_host_template.cu | 233 +++++++++++++----- .../include/fbgemm_gpu/utils/tensor_utils.h | 40 +++ 2 files changed, 207 insertions(+), 66 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index bc4e7ba74..5dd5c30b1 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -7,7 +7,7 @@ */ // clang-format off -{% set wdesc = "weighted" if weighted else "unweighted" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor.h" @@ -22,7 +22,7 @@ namespace nbit { `Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)` later in the same generated source file. */ -{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} +{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( @@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const pta::PackedTensorAccessor32 weights_placements, const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 weights_tys, - {% if not nobag %} + {%- if not nobag %} const pta::PackedTensorAccessor32 D_offsets, - {% else %} + {%- else %} const int64_t D, - {% endif %} + {%- endif %} FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread)) const pta::PackedTensorAccessor32 indices, const pta::PackedTensorAccessor32 offsets, - {% if not nobag %} + {%- if not nobag %} const int64_t pooling_mode, - {% endif %} + {%- endif %} const int64_t row_alignment, - {% if weighted %} + {%- if weighted %} pta::PackedTensorAccessor32 indice_weights, - {% endif %} - {% if type_map[emb_weight_type].enum_name == "FP8" %} + {%- endif %} + {%- if type_map[emb_weight_type].enum_name == "FP8" %} const int fp8_exponent_bits, const int fp8_exponent_bias, - {% endif %} + {%- endif %} pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations ); -{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] +{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] } @@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no C10_CUDA_KERNEL_LAUNCH_CHECK(); \ {%- endmacro %} - -Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( - Tensor dev_weights, - Tensor uvm_weights, - Tensor weights_placements, - Tensor weights_offsets, - Tensor weights_tys, - {% if not nobag %} - Tensor D_offsets, - const int64_t total_D, - {% else %} - const int64_t D, - {% endif %} - const int64_t max_int2_D, - const int64_t max_int4_D, - 
const int64_t max_int8_D, - const int64_t max_float16_D, - const int64_t max_float32_D, - Tensor indices, - Tensor offsets, - {% if not nobag %} - const int64_t pooling_mode, - {% endif %} - const int64_t row_alignment, - {% if weighted %} - Tensor indice_weights, - {% endif %} - const int64_t output_dtype, - Tensor lxu_cache_weights, - Tensor lxu_cache_locations, - const int64_t max_float8_D, - const int64_t fp8_exponent_bits, - const int64_t fp8_exponent_bias -) { - TENSOR_ON_CUDA_GPU(dev_weights); - TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); - {% if not nobag %} - TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); - {% endif %} - TENSORS_ON_SAME_DEVICE(indices, dev_weights); - TENSORS_ON_SAME_DEVICE(offsets, dev_weights); - {% if weighted %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); - {% endif %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - - CUDA_DEVICE_GUARD(dev_weights); - +{%- macro construct_and_return_output_tensor() %} // kernels assume indices are contiguous. indices = indices.contiguous(); @@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ TORCH_CHECK(D > 0); {%- endif %} + // Construct output tensor Tensor output; const int kINT8QparamsBytes = 8; + SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); @@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if (B == 0 || indices.numel() == 0) { return output; } +{%- endmacro %} - using index_t = int32_t; +template +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - constexpr int32_t kWarpsPerBlock = 4; + CUDA_DEVICE_GUARD(dev_weights); + + {{- 
construct_and_return_output_tensor() }} + constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; #define Y(...) \ if (device_only) { \ @@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ })); #undef X + return output; +} + +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + // All argument tensors need to be on the same CUDA device + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); + + // indices and offsets need to have the same scalar type + TENSORS_HAVE_SAME_TYPE(indices, offsets); + // Only int32_t and int64_t indices are supported at the moment + TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int); + + CUDA_DEVICE_GUARD(dev_weights); + + // Create output tensor ref + Tensor output; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] { + output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + dev_weights, + uvm_weights, + weights_placements, + weights_offsets, + weights_tys, + {%- if not nobag %} + D_offsets, + total_D, + {%- else %} + D, + {%- endif %} + max_int2_D, + max_int4_D, + max_int8_D, + max_float16_D, + max_float32_D, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + {%- endif %} + row_alignment, + {%- if weighted %} + indice_weights, + {%- endif %} + output_dtype, + lxu_cache_weights, + lxu_cache_locations, + max_float8_D, + fp8_exponent_bits, + fp8_exponent_bias); + }); + return output; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index b1ab0306c..f64205b7e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h @@ -299,3 +299,43 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards( } return aligned_grad_output; } + +template +std::string tensor_scalar_type_is_one_of( + const at::Tensor& ten, + const 
ScalarTypes&... ttypes) { + auto has_match = false; + + // Collect the GPU index of the first non-empty optional tensor and make sure + // that all tensors are on this same index. + ( + [&](const auto& ttype) { + if (ten.scalar_type() == ttype) { + has_match = true; + } + }(ttypes), + ...); + + if (has_match) { + return ""; + } + + std::string msg = "Tensor's scalar type ("; + msg.append(toString(ten.scalar_type())); + msg.append(") did not match any one of the following types: ["); + ( + [&](const auto& ttype) { + msg.append(toString(ttype)); + msg.append(", "); + }(ttypes), + ...); + + msg.append("]"); + return msg; +} + +#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...) \ + do { \ + const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \ + TORCH_CHECK(has_match.empty(), has_match); \ + } while (false) From 0466806d944a0db931a1a83f13c09bc28d33bbc6 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 26 Sep 2024 15:01:42 -0700 Subject: [PATCH 2/2] Add support for int64_t indices in TBE inference [2/N] (#3125) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/214 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3125 - Add support for int64_t indices in TBE inference [2/N] - Convert `pruned_array_lookup_cuda` to use index_t Reviewed By: jianyuh Differential Revision: D62271409 --- ...mbedding_forward_quantized_split_lookup.cu | 55 ++++++++++--------- .../include/fbgemm_gpu/utils/tensor_utils.h | 38 ++++++++++++- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 7d4eebcce..52f2a49dd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -89,19 +89,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru } } +template __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 offsets, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 index_remappings, const pta::PackedTensorAccessor32 index_remappings_offsets, const int32_t B, const int32_t T, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 dense_indices) { const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; const int32_t t = b_t / B; @@ -109,22 +110,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru if (b_t >= B * T) { return; } - const int32_t indices_start = offsets[t * B + b]; - const int32_t indices_end = offsets[t * B + b + 1]; - const int32_t L = indices_end - indices_start; + const index_t indices_start = offsets[t * B + b]; + const index_t indices_end = offsets[t * B + b + 1]; + const index_t L = indices_end - indices_start; const int64_t index_remappings_start = index_remappings_offsets[t]; const int64_t index_remappings_end = index_remappings_offsets[t + 1]; const int64_t capacity = index_remappings_end - index_remappings_start; if (capacity > 0) { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { - int32_t idx = indices[indices_start + l]; + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { + index_t idx = indices[indices_start + l]; dense_indices[indices_start + l] = 
index_remappings[index_remappings_start + idx]; } } else { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { dense_indices[indices_start + l] = indices[indices_start + l]; } } @@ -178,6 +179,7 @@ Tensor pruned_array_lookup_cuda( Tensor index_remappings_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, index_remappings, index_remappings_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings); CUDA_DEVICE_GUARD(indices); @@ -204,23 +206,26 @@ Tensor pruned_array_lookup_cuda( TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim()); constexpr size_t kForwardMaxThreads = 256; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + const auto func_name = + "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< - nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), - B, - T, - MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< + nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), + B, + T, + MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return dense_indices; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index f64205b7e..60cca19ef 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h @@ -306,8 +306,6 @@ std::string tensor_scalar_type_is_one_of( const ScalarTypes&... ttypes) { auto has_match = false; - // Collect the GPU index of the first non-empty optional tensor and make sure - // that all tensors are on this same index. ( [&](const auto& ttype) { if (ten.scalar_type() == ttype) { @@ -339,3 +337,39 @@ std::string tensor_scalar_type_is_one_of( const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \ TORCH_CHECK(has_match.empty(), has_match); \ } while (false) + +template +std::string tensors_have_same_scalar_type(const Tensors&... 
tensors) { + std::optional dtype; + bool have_same_type = true; + + ( + [&](const auto& tensor) { + if (!dtype) { + dtype = tensor.scalar_type(); + } else if (*dtype != tensor.scalar_type()) { + have_same_type = false; + } + }(tensors), + ...); + + if (have_same_type) { + return ""; + } + + std::string msg = "Tensors' scalar types ("; + ( + [&](const auto& tensor) { + msg.append(toString(tensor.scalar_type())); + msg.append(", "); + }(tensors), + ...); + msg.append(") are not one and the same!"); + return msg; +} + +#define TENSORS_HAVE_SAME_SCALAR_TYPE(...) \ + do { \ + const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__); \ + TORCH_CHECK(have_same_type.empty(), have_same_type); \ + } while (false)
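
For reference, below is a minimal sketch (not part of either patch) of how the scalar-type checks and the index-type dispatch introduced above are meant to compose at a call site. The function and tensor names are illustrative placeholders; only TENSORS_HAVE_SAME_SCALAR_TYPE, TENSOR_SCALAR_TYPE_IS_ONE_OF, and AT_DISPATCH_INDEX_TYPES come from the patches and from ATen:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include "fbgemm_gpu/utils/tensor_utils.h"

    // Templated implementation, compiled once per supported index type.
    template <typename index_t>
    at::Tensor example_lookup_impl(at::Tensor indices, at::Tensor offsets) {
      // Placeholder body: a real implementation would launch a kernel that reads
      // indices/offsets as index_t, as the patched kernels above do.
      return at::empty_like(indices);
    }

    at::Tensor example_lookup(at::Tensor indices, at::Tensor offsets) {
      // indices and offsets must share a scalar type, and only int32/int64 are accepted.
      TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
      TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);

      at::Tensor output;
      // AT_DISPATCH_INDEX_TYPES binds index_t to int32_t or int64_t based on the
      // runtime dtype of `indices` and invokes the lambda with that type.
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "example_lookup", [&] {
        output = example_lookup_impl<index_t>(indices, offsets);
      });
      return output;
    }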