From 193ff6e8042e38885e7b94697a58c9424104c923 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 12 Nov 2024 22:10:13 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2/N) (#3326) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/420 - Add `index_t` support to TBE training backward kernels Reviewed By: basilwong Differential Revision: D65464554 --- .../embedding_backward_split_grad_template.cu | 20 +-- ..._backward_split_indice_weights_template.cu | 149 +++++++++--------- ...ding_backward_split_kernel_cta_template.cu | 10 +- ...ing_backward_split_kernel_warp_template.cu | 10 +- .../embedding_backward_split_template.cu | 87 +++++----- .../fbgemm_gpu/split_embeddings_utils.cuh | 8 +- .../radix_sort_pairs.cu | 8 +- 7 files changed, 164 insertions(+), 128 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu index f20b1b97bd..4f6ec14ae0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu @@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel {% for vbe in [True, False] %} {% set vdesc = "_vbe" if vbe else "" %} -template +template __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel( pta::PackedTensorAccessor64 grad_output_mean, const pta::PackedTensorAccessor64 grad_output, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 offsets, + const pta::PackedTensorAccessor32 offsets, {% if vbe %} const pta::PackedTensorAccessor32 row_grad_offsets, const pta::PackedTensorAccessor32 b_t_map, @@ -175,12 +175,12 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel( fd_B.DivMod(b_t, &t, &b); {% endif %} - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - int64_t indices_start = offsets[b_t]; - int64_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; + const auto D_start = D_offsets[t]; + const auto D_end = D_offsets[t + 1]; + const auto D = D_end - D_start; + const auto indices_start = offsets[b_t]; + const auto indices_end = offsets[b_t + 1]; + const auto L = indices_end - indices_start; {% if vbe %} const auto grad_offset = row_grad_offsets[b_t]; @@ -212,6 +212,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel( //////////////////////////////////////////////////////////////////////////////// {% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %} +{% for offset_type in ['int32_t', 'int64_t'] %} template __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel <{{ grad_type }}> ( @@ -220,7 +221,7 @@ void grad_mean{{ vdesc }}_kernel const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 offsets, + const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets, {% if vbe %} const pta::PackedTensorAccessor32 row_grad_offsets, const pta::PackedTensorAccessor32 b_t_map, @@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel FixedDivisor fd_B {% endif %} ); +{% endfor %} // for offset_type in ['int32_t', 'int64_t'] {% endfor %} // for grad_type in ['at::Half', 'float'] {% endfor %} // for vbe in [True, False] diff --git 
a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 8e1db36757..f87b20e8c6 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -64,6 +64,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, int32_t kFixedMaxVecsPerThread > __global__ __launch_bounds__(kForwardMaxThreads) void @@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] - const pta::PackedTensorAccessor32 offsets, // [B x T + 1] + const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] + const pta::PackedTensorAccessor32 offsets, // [B x T + 1] {%- if not dense %} const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} @@ -113,13 +114,13 @@ __global__ __launch_bounds__(kForwardMaxThreads) void fd_B.DivMod(b_t, &t, &b); {%- endif %} - int64_t weights_offset = weights_offsets[t]; - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - int64_t indices_start = offsets[b_t]; - int64_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; + const auto weights_offset = weights_offsets[t]; + const auto D_start = D_offsets[t]; + const auto D_end = D_offsets[t + 1]; + const auto D = D_end - D_start; + const auto indices_start = offsets[b_t]; + const auto indices_end = offsets[b_t + 1]; + const auto L = indices_end - indices_start; if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) { // If the table does not require gradient computation, we set the gradient to zero. for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { @@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { int32_t l = l_start + threadIdx.x; - int64_t idx = l < L ? indices[indices_start + l] : 0; + auto idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} const auto {{ locs_or_addrs_idx }} = (placement == PlacementType::MANAGED_CACHING && l < L) ? 
{{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - int64_t idx_j = shfl_sync(idx, j); + auto idx_j = shfl_sync(idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} @@ -354,72 +355,74 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const uint32_t info_B_mask = info_B_mask_int64; {%- endif %} - DISPATCH_EMB_GRAD_CACHE_TYPES( - dev_weights.scalar_type(), - aligned_grad_output.scalar_type(), - {%- if not dense %} - lxu_cache_weights.scalar_type(), - {%- else %} - dev_weights.scalar_type(), - {%- endif %} - "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel", - [&] { - {%- if vbe %} - const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1}); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] { + DISPATCH_EMB_GRAD_CACHE_TYPES( + dev_weights.scalar_type(), + aligned_grad_output.scalar_type(), + {%- if not dense %} + lxu_cache_weights.scalar_type(), {%- else %} - const auto& grad_output_reshaped = aligned_grad_output; + dev_weights.scalar_type(), {%- endif %} + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2", + [&] { + {%- if vbe %} + const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1}); + {%- else %} + const auto& grad_output_reshaped = aligned_grad_output; + {%- endif %} - {%- for use_vec_blocking in [False, True] %} - {%- set vbdesc = "vec_blocking_" if use_vec_blocking else "" %} - {%- set dpdesc = "NON_" if not use_vec_blocking else "" %} - DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] { - {%- set kernel_name = - "{}_embedding_codegen_grad_indice_weights{}_{}kernel".format( - mdesc, vdesc, vbdesc) - %} -#ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "{{ kernel_name }}"; -#endif - {{ kernel_name }}< - emb_t, - grad_t, - cache_t, - kFixedMaxVecsPerThread><<< - div_round_up(total_B, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, grad_output_reshaped, grad_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64), - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64), - MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32), - {%- endif %} - MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), - {%- endif %} - MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32), - MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32), - {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name, vbe_row_output_offsets, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, vbe_b_t_map, int32_t, 1, 32), - info_B_num_bits, - info_B_mask - {%- else %} - FixedDivisor(total_B / T) - {%- endif %} - ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; + {%- for use_vec_blocking in [False, True] %} + {%- set vbdesc = "vec_blocking_" if use_vec_blocking else "" %} + {%- set dpdesc = "NON_" if not use_vec_blocking else "" %} + DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] { + 
{%- set kernel_name = + "{}_embedding_codegen_grad_indice_weights{}_{}kernel".format( + mdesc, vdesc, vbdesc) + %} + #ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "{{ kernel_name }}"; + #endif + {{ kernel_name }}< + emb_t, + grad_t, + cache_t, + index_t, + kFixedMaxVecsPerThread><<< + div_round_up(total_B, kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, grad_output_reshaped, grad_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64), + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64), + MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32), + {%- endif %} + MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), + {%- endif %} + MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32), + MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32), + {%- if vbe %} + MAKE_PTA_WITH_NAME(func_name, vbe_row_output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, vbe_b_t_map, int32_t, 1, 32), + info_B_num_bits, + info_B_mask + {%- else %} + FixedDivisor(total_B / T) + {%- endif %} + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return; + }); + {%- endfor %} {# /* for use_vec_blocking */ #} }); - {%- endfor %} {# /* for use_vec_blocking */ #} }); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 3fb49ed5e7..1cfeb66c94 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -77,6 +77,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t" }}, {%- endfor %} @@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row < {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, + {{ index_type }}, {%- for ph_name in args.placeholder_tensor_names %} {{ ph_type_combo[ph_name].primitive_type }}, {%- endfor %} @@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, 
at::RestrictPtrTraits> sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} + {%- for index_type in ['int32_t', 'int64_t'] %} {%- for ph_type_combo in args.placeholder_type_combos %} {{ template_instantiation( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- endfor %} {%- endfor %} {%- endfor %} + {%- endfor %} {%- endmacro %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 3b230b0100..bc27f15281 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -62,6 +62,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t"}}, {%- endfor %} @@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, {%- if not nobag %} const pta::PackedTensorAccessor32 sorted_infos, @@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row < {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, + {{ index_type }}, {%- for ph_name in args.placeholder_tensor_names %} {{ ph_type_combo[ph_name].primitive_type }}, {%- endfor %} @@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, {%- if not nobag %} const pta::PackedTensorAccessor32 sorted_infos, @@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} + {%- for index_type in ['int32_t', 'int64_t'] %} {%- for ph_type_combo in args.placeholder_type_combos %} {{ template_instantiation( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endfor %} {%- endfor %} {%- endfor %} + {%- endfor %} {%- endmacro %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu 
b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fdd9c0f798..bac5a1b006 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -45,6 +45,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t" }}, {%- endfor %} @@ -73,7 +74,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -134,6 +135,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t" }}, {%- endfor %} @@ -162,7 +164,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, {%- if not nobag %} const pta::PackedTensorAccessor32 sorted_infos, @@ -232,13 +234,13 @@ split_embedding_backward_codegen_find_long_segments( const bool use_deterministic_algorithms); -template +template __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel( pta::PackedTensorAccessor64 grad_output_mean, const pta::PackedTensorAccessor64 grad_output, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 offsets, + const pta::PackedTensorAccessor32 offsets, {%- if vbe %} const pta::PackedTensorAccessor32 grad_offsets, const pta::PackedTensorAccessor32 b_t_map, @@ -742,31 +744,35 @@ Tensor {{ embedding_cuda_op }}( else { {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }}); size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(radix_sort_pairs( - nullptr, - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), - {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), - linear_indices.numel(), - 0, - total_hash_size_bits, - at::cuda::getCurrentCUDAStream())); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(radix_sort_pairs( - temp_storage.data_ptr(), - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), - {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), - linear_indices.numel(), - 0, - total_hash_size_bits, - at::cuda::getCurrentCUDAStream())); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] { + AT_CUDA_CHECK(radix_sort_pairs( + nullptr, + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), + linear_indices.numel(), + 0, + total_hash_size_bits, 
+ at::cuda::getCurrentCUDAStream())); + + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + indices.options().dtype(at::kByte)); + + AT_CUDA_CHECK(radix_sort_pairs( + temp_storage.data_ptr(), + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream())); + }); } } @@ -775,6 +781,8 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -792,8 +800,8 @@ Tensor {{ embedding_cuda_op }}( AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), indice_weights.data_ptr>(), indice_weights_sorted.data_ptr>(), linear_indices.numel(), @@ -806,8 +814,8 @@ Tensor {{ embedding_cuda_op }}( AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), indice_weights.data_ptr>(), indice_weights_sorted.data_ptr>(), linear_indices.numel(), @@ -840,9 +848,9 @@ Tensor {{ embedding_cuda_op }}( grad_output_mean = at::empty_like(grad_output_reshaped); {%- if not dense or not vbe %} -#ifdef FBGEMM_GPU_MEMCHECK + #ifdef FBGEMM_GPU_MEMCHECK const auto func_name1 = "grad_mean{{ vdesc }}_kernel"; -#endif + #endif grad_mean{{ vdesc }}_kernel<<< div_round_up(total_B, kMaxThreads / kWarpSize), @@ -853,7 +861,7 @@ Tensor {{ embedding_cuda_op }}( MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64), MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64), MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32), {%- if vbe %} MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32), @@ -955,6 +963,7 @@ Tensor {{ embedding_cuda_op }}(
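
The host/device pattern this patch applies throughout the backward path: the host-side launcher is wrapped in AT_DISPATCH_INDEX_TYPES so that `index_t` resolves to int32_t or int64_t from `indices.scalar_type()`, and the kernel takes `index_t` as a template parameter instead of hard-coding int64_t accessors. Below is a minimal, self-contained sketch of that pattern, not the FBGEMM code itself; the kernel and function names (`gather_rows_kernel`, `gather_rows`) are illustrative placeholders.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAException.h>

    // Kernel templated on the index type, mirroring how the TBE backward
    // kernels gain a `typename index_t` parameter in this patch.
    template <typename index_t>
    __global__ void gather_rows_kernel(
        const index_t* __restrict__ indices, // int32_t or int64_t
        const float* __restrict__ weights,   // [num_rows, D]
        float* __restrict__ out,             // [N, D]
        int64_t N,
        int64_t D) {
      const int64_t n = blockIdx.x * blockDim.x + threadIdx.x;
      if (n >= N) {
        return;
      }
      const auto row = static_cast<int64_t>(indices[n]);
      for (int64_t d = 0; d < D; ++d) {
        out[n * D + d] = weights[row * D + d];
      }
    }

    at::Tensor gather_rows(const at::Tensor& weights, const at::Tensor& indices) {
      const auto N = indices.numel();
      const auto D = weights.size(1);
      auto out = at::empty({N, D}, weights.options());
      // AT_DISPATCH_INDEX_TYPES instantiates the lambda for at::kInt and
      // at::kLong and defines `index_t` inside it, which is the same
      // mechanism the patch uses around the TBE backward launchers.
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "gather_rows", [&] {
        gather_rows_kernel<index_t>
            <<<(N + 255) / 256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                indices.data_ptr<index_t>(),
                weights.data_ptr<float>(),
                out.data_ptr<float>(),
                N,
                D);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
      return out;
    }

Correspondingly, the device-side Jinja templates instantiate each kernel for both index types (the `{%- for index_type in ['int32_t', 'int64_t'] %}` loops added above), so the dispatched branch always finds a matching explicit instantiation at link time.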
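
The sort of the linearized indices follows the usual CUB two-phase convention: a first call with a null temp-storage pointer only computes the workspace size, and a second call performs the sort; with this patch the key type follows the dispatched `index_t` rather than being fixed to int64_t. The sketch below shows that convention using cub::DeviceRadixSort directly under assumed tensor arguments; FBGEMM's radix_sort_pairs wrapper has the same shape but is not reproduced here, and `sort_linear_indices` is a hypothetical helper for illustration only.

    #include <cub/device/device_radix_sort.cuh>
    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/cuda/Exceptions.h>

    template <typename index_t>
    void sort_linear_indices(
        const at::Tensor& linear_indices,    // 1-D tensor of index_t keys
        const at::Tensor& positions,         // 1-D tensor of int32_t values
        at::Tensor& linear_indices_sorted,   // same shape/dtype as keys
        at::Tensor& positions_sorted,        // same shape/dtype as values
        int total_hash_size_bits) {
      const auto stream = at::cuda::getCurrentCUDAStream();
      size_t temp_storage_bytes = 0;
      // Phase 1: null workspace pointer, only queries the required size.
      AT_CUDA_CHECK(cub::DeviceRadixSort::SortPairs(
          nullptr,
          temp_storage_bytes,
          linear_indices.data_ptr<index_t>(),
          linear_indices_sorted.data_ptr<index_t>(),
          positions.data_ptr<int32_t>(),
          positions_sorted.data_ptr<int32_t>(),
          linear_indices.numel(),
          0,
          total_hash_size_bits,
          stream));
      auto temp_storage = at::empty(
          {static_cast<int64_t>(temp_storage_bytes)},
          linear_indices.options().dtype(at::kByte));
      // Phase 2: run the sort with the allocated workspace. Sorting only the
      // low total_hash_size_bits bits is enough because the linearized
      // indices are bounded by the total hash size.
      AT_CUDA_CHECK(cub::DeviceRadixSort::SortPairs(
          temp_storage.data_ptr(),
          temp_storage_bytes,
          linear_indices.data_ptr<index_t>(),
          linear_indices_sorted.data_ptr<index_t>(),
          positions.data_ptr<int32_t>(),
          positions_sorted.data_ptr<int32_t>(),
          linear_indices.numel(),
          0,
          total_hash_size_bits,
          stream));
    }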