From 2096a31c0cfb91bb7afd9e4809c4e982ae4d3f7b Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Thu, 14 Nov 2024 11:33:33 -0800
Subject: [PATCH] Add support for `int32_t` indices in TBE training (2E/N)
 (#3375)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/466

- Add `index_t` support to TBE training backward kernels

Differential Revision: D65933410
---
 .../embedding_backward_split_kernel_cta_template.cu | 10 ++++++++--
 .../backward/embedding_backward_split_template.cu   |  6 ++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
index 3fb49ed5e..1cfeb66c9 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
@@ -77,6 +77,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
         emb_type,
         grad_type,
         cache_type,
+        index_type,
         ph_type_combo,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
@@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     < {{ emb_type }},
       {{ grad_type }},
       {{ cache_type }},
+      {{ index_type }},
       {%- for ph_name in args.placeholder_tensor_names %}
       {{ ph_type_combo[ph_name].primitive_type }},
       {%- endfor %}
@@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
     {%- for emb_type in ['float', 'at::Half'] %}
     {%- for cache_type in ['float', 'at::Half'] %}
+    {%- for index_type in ['int32_t', 'int64_t'] %}
     {%- for ph_type_combo in args.placeholder_type_combos %}
     {{ template_instantiation(
         emb_type,
         grad_type,
         cache_type,
+        index_type,
         ph_type_combo,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
@@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     {%- endfor %}
     {%- endfor %}
     {%- endfor %}
+    {%- endfor %}
 {%- endmacro %}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index f168bf79f..c731308e4 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -45,6 +45,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -73,7 +74,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -962,6 +963,7 @@ Tensor {{ embedding_cuda_op }}(
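Illustration (not part of the patch): the change threads an `index_t` template
parameter through the backward CTA-per-row kernel and widens the explicit
instantiation loop to cover both `int32_t` and `int64_t`. The minimal CUDA
sketch below shows that pattern in isolation; the kernel, launcher, and type
names here (gather_rows_kernel, launch_gather_rows, IndexType) are hypothetical
and are not FBGEMM APIs.

#include <cstdint>
#include <cuda_runtime.h>

// A toy kernel templated on index_t, so one kernel body serves both
// int32_t and int64_t index tensors (the pattern the patch applies to
// batch_index_select_dim0_codegen_backward_kernel_cta_per_row).
template <typename emb_t, typename index_t>
__global__ void gather_rows_kernel(
    const emb_t* __restrict__ weights,   // [num_rows, D]
    const index_t* __restrict__ indices, // [N], row ids as int32_t or int64_t
    emb_t* __restrict__ out,             // [N, D]
    int64_t N,
    int64_t D) {
  const int64_t i = blockIdx.x;  // one thread block per output row
  if (i >= N) {
    return;
  }
  const auto row = static_cast<int64_t>(indices[i]);  // widen for addressing
  for (int64_t d = threadIdx.x; d < D; d += blockDim.x) {
    out[i * D + d] = weights[row * D + d];
  }
}

enum class IndexType { I32, I64 };

// Host-side dispatch on the index dtype, analogous to the added
// {%- for index_type in ['int32_t', 'int64_t'] %} instantiation loop:
// both instantiations are compiled, and the matching one is launched.
template <typename emb_t>
void launch_gather_rows(
    const emb_t* weights,
    const void* indices,
    IndexType index_type,
    emb_t* out,
    int64_t N,
    int64_t D,
    cudaStream_t stream) {
  const dim3 grid(static_cast<unsigned int>(N));
  const dim3 block(128);
  if (index_type == IndexType::I32) {
    gather_rows_kernel<emb_t, int32_t><<<grid, block, 0, stream>>>(
        weights, static_cast<const int32_t*>(indices), out, N, D);
  } else {
    gather_rows_kernel<emb_t, int64_t><<<grid, block, 0, stream>>>(
        weights, static_cast<const int64_t*>(indices), out, N, D);
  }
}

// Explicit instantiation for the float embedding path; the real codegen
// template emits the full emb/grad/cache/index cross product, but only
// index_t varies in this sketch.
template void launch_gather_rows<float>(
    const float*, const void*, IndexType, float*, int64_t, int64_t, cudaStream_t);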