From eea19e82c4e4af8a5846871a2888eb6d58870f81 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Tue, 11 Jul 2023 17:30:26 -0700
Subject: [PATCH] Re-organize the backward split code generation (#1871)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1871

- Move the determination of `kMaxVecsPerThread` and `kThreadGroupSize` in
  `embedding_backward_split_template.cu` into a dedicated macro
- Update `embedding_backward_split_template.cu` to hardcode the values of
  `kMaxVecsPerThread` and `kThreadGroupSize` for the experimental optimizers case

Reviewed By: sryap

Differential Revision: D47354614

fbshipit-source-id: 85e947734d929d86339a348e3fb4bda1eeb24dce
---
 .../embedding_backward_split_template.cu | 565 ++++++++++--------
 1 file changed, 310 insertions(+), 255 deletions(-)

diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
index de1377d72..be793ca86 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -172,10 +172,77 @@ grad_mean{{ vbe_desc }}_kernel(
     {%- endif %}
 );
 
+
 ////////////////////////////////////////////////////////////////////////////////
-// Operator Code
+// Utility Macros
 ////////////////////////////////////////////////////////////////////////////////
 
+{%- if experimental_optimizer %}
+/*
+  For the experimental optimizers, kMaxVecsPerThread and kThreadGroupSize are
+  fixed to 1024 and kWarpSize, respectively.
+*/
+#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+  [&] { \
+    constexpr auto kMaxVecsPerThread = 1024; \
+    constexpr auto kThreadGroupSize = kWarpSize; \
+    return __VA_ARGS__(); \
+  }()
+
+{%- else %}
+
+/*
+  For the non-experimental optimizers, we determine the kernel template
+  instantiation that is best optimized for MAX_D and invoke it.
+
+  kMaxElemPerThread is the number of elements handled by each thread if we use
+  a full warp for a row. We consider kMaxElemPerThread values of 1, 2, and
+  multiples of 4.
+
+  The macro definitions for both cases are almost the same except for the
+  definition of kThreadGroupSize. In the FBGEMM_USE_SUBWARP_SHUFFLE case, if
+  MAX_D is small, then we use fewer threads than kWarpSize.
+
+  NOTE: kMaxVecsPerThread is computed using the ternary operator because HIPCC
+  is unable to use std::max in a constexpr context.
+*/
+#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
+#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+  [&] { \
+    {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
+    {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
+    if (MAX_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { \
+      constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; \
+      constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); \
+      return __VA_ARGS__(); \
+    } \
+    {%- endif %}
+    {%- endfor %}
+    return; \
+  }()
+
+#else
+#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+  [&] { \
+    constexpr int kThreadGroupSize = kWarpSize; \
+    {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
+    {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
+    if (MAX_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { \
+      constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ?
{{ kMaxElemPerThread }} / 4 : 1; \ + return __VA_ARGS__(); \ + } \ + {%- endif %} + {%- endfor %} + return; \ + }() + +#endif +{%- endif %} + + +//////////////////////////////////////////////////////////////////////////////// +// Kernel Definition +//////////////////////////////////////////////////////////////////////////////// {%- set func_name0 = "split_embedding{}_backward_codegen_{}_{}_exact{}_cuda".format( "_nobag" if nobag else "", @@ -234,7 +301,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- endif %} ) { - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( {%- if optimizer != "none" %} dev_weights, {%- endif %} @@ -344,9 +411,8 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi #endif int used_shared_bytes = used_shared_kb << 10; - Tensor linear_indices, linear_indices_sorted; - Tensor infos_sorted; - Tensor sorted_linear_indices_run, sorted_linear_indices_run_lengths, + Tensor linear_indices, linear_indices_sorted, infos_sorted, + sorted_linear_indices_run, sorted_linear_indices_run_lengths, sorted_linear_indices_num_runs, sorted_linear_indices_cumulative_run_lengths; std::tie( @@ -459,14 +525,14 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- if not nobag %} Tensor grad_output_mean; if (static_cast(pooling_mode) == PoolingMode::MEAN) { - grad_output_mean = at::empty_like(grad_output); - {%- if not dense or not vbe %} + grad_output_mean = at::empty_like(grad_output); + {%- if not dense or not vbe %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name1 = "grad_mean{{ vbe_desc }}_kernel"; + const auto func_name1 = "grad_mean{{ vbe_desc }}_kernel"; #endif - grad_mean{{ vbe_desc }}_kernel<<< + grad_mean{{ vbe_desc }}_kernel<<< div_round_up(total_B, kMaxThreads / kWarpSize), dim3(kWarpSize, kMaxThreads / kWarpSize), 0, @@ -485,11 +551,11 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi FixedDivisor(total_B / T) {%- endif %} ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - {%- endif %} // if not dense or not vbe + C10_CUDA_KERNEL_LAUNCH_CHECK(); + {%- endif %} // if not dense or not vbe - grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.2", grad_output_mean, grad_t, 2, 64); + grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.2", grad_output_mean, grad_t, 2, 64); } {%- endif %} @@ -505,280 +571,269 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi } {%- endif %} - - // kMaxElemPerThread is # of elements handled by thread if we use a full warp for a row - // We consider kMaxElemPerThread 1 and 2, and then a multiple of 4. - {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} - {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} - if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { - // hipcc can't use max in constexpr - constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; - // If max_D is small, use fewer number of threads than kWarpSize. -#ifdef FBGEMM_USE_SUBWARP_SHUFFLE - constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); -#else - constexpr int kThreadGroupSize = kWarpSize; -#endif - // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB; H100: 144 KB), BT_block_size must be a power of two. 
- while (BT_block_size * sizeof(at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread >= used_shared_bytes) { - BT_block_size /= 2; - } - TORCH_CHECK(BT_block_size >= 1); - if (std::is_same::value) { - // Otherwise we see CUDA kernel launch failures despite the above checks. - BT_block_size = 1; - } - - auto long_run_ids = at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options()); - auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); - - const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - Tensor long_run_id_to_really_long_run_ids; - if (use_deterministic_algorithms) { - long_run_id_to_really_long_run_ids = - at::empty(0, sorted_linear_indices_run_lengths.options()); - } else { - long_run_id_to_really_long_run_ids = - at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options()); - } - auto num_really_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); - auto grad_accum_counter = at::empty( - use_deterministic_algorithms ? 0 : (indices.numel() / max_segment_length_per_cta), - indices.options().dtype(at::kInt)); + DISPATCH_KERNEL_BODY(max_D, [&] { + // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB; H100: 144 KB), BT_block_size must be a power of two. + while (BT_block_size * sizeof(at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread >= used_shared_bytes) { + BT_block_size /= 2; + } + TORCH_CHECK(BT_block_size >= 1); + if (std::is_same::value) { + // Otherwise we see CUDA kernel launch failures despite the above checks. + BT_block_size = 1; + } + + auto long_run_ids = at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options()); + auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); + + const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + + Tensor long_run_id_to_really_long_run_ids; + if (use_deterministic_algorithms) { + long_run_id_to_really_long_run_ids = + at::empty(0, sorted_linear_indices_run_lengths.options()); + } else { + long_run_id_to_really_long_run_ids = + at::empty({indices.numel()}, sorted_linear_indices_run_lengths.options()); + } + + + auto num_really_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); + auto grad_accum_counter = at::empty( + use_deterministic_algorithms ? 
0 : (indices.numel() / max_segment_length_per_cta), + indices.options().dtype(at::kInt)); #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name2 = "split_embedding_backward_codegen_find_long_segments"; + const auto func_name2 = "split_embedding_backward_codegen_find_long_segments"; #endif - split_embedding_backward_codegen_find_long_segments<<< - div_round_up(total_unique_indices, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream() - >>>( - MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_num_runs, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_run_lengths, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, long_run_ids, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, num_long_run_ids, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, long_run_id_to_really_long_run_ids, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, num_really_long_run_ids, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name2, grad_accum_counter, int32_t, 1, 32), - max_segment_length_per_warp, - max_segment_length_per_cta, - use_deterministic_algorithms); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - // A temp buffer to accumulate gradients with atomics. - auto temp_grad_accum = at::zeros( - {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, - grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); - - int32_t grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), - get_max_thread_blocks_()); - - // Check https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x - // "Compute capability 7.x devices allow a single thread block to - // address the full capacity of shared memory: 96 KB on Volta, - // 64 KB on Turing. Kernels relying on shared memory allocations - // over 48 KB per block are architecture-specific, as such they - // must use dynamic shared memory (rather than statically sized - // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". + split_embedding_backward_codegen_find_long_segments<<< + div_round_up(total_unique_indices, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream() + >>>( + MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_num_runs, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_run_lengths, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, long_run_ids, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, num_long_run_ids, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, long_run_id_to_really_long_run_ids, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, num_really_long_run_ids, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, grad_accum_counter, int32_t, 1, 32), + max_segment_length_per_warp, + max_segment_length_per_cta, + use_deterministic_algorithms); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // A temp buffer to accumulate gradients with atomics. + auto temp_grad_accum = at::zeros( + {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, + grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); + + int32_t grid_size = std::min( + div_round_up(total_unique_indices, kMaxThreads), + get_max_thread_blocks_()); + + // Check https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x + // "Compute capability 7.x devices allow a single thread block to + // address the full capacity of shared memory: 96 KB on Volta, + // 64 KB on Turing. 
Kernels relying on shared memory allocations + // over 48 KB per block are architecture-specific, as such they + // must use dynamic shared memory (rather than statically sized + // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". #ifndef __HIP_PLATFORM_HCC__ - cudaFuncSetAttribute( - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize>, - cudaFuncAttributeMaxDynamicSharedMemorySize, - used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB + cudaFuncSetAttribute( + split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< + emb_t, + grad_t, + cache_t, + kMaxVecsPerThread, + kThreadGroupSize>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB #endif - C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name3 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1"; + const auto func_name3 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1"; #endif - // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0] - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize> - <<) * 4 * kWarpSize * - kMaxVecsPerThread, - at::cuda::getCurrentCUDAStream()>>>( - grad_output_accessor, - {%- if optimizer != "none" %} - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name3, dev_weights, emb_t, 1, 64), - MAKE_PTA_WITH_NAME(func_name3, uvm_weights, emb_t, 1, 64), - MAKE_PTA_WITH_NAME(func_name3, lxu_cache_weights, cache_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name3, weights_placements, int32_t, 1, 32), - {%- else %} - MAKE_PTA_WITH_NAME(func_name3, dev_weights, emb_t, 1, 64), - {%- endif %} - {%- endif %} // if optimizer != "none" - MAKE_PTA_WITH_NAME(func_name3, weights_offsets, int64_t, 1, 32), - {%- if not nobag %} - MAKE_PTA_WITH_NAME(func_name3, D_offsets, int32_t, 1, 32), - {%- else %} - D, - {%- endif %} - MAKE_PTA_WITH_NAME(func_name3, hash_size_cumsum, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_run, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, long_run_ids, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, num_long_run_ids, int32_t, 1, 32), - {%- if not nobag %} - MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int32_t, 1, 32), - {%- else %} - MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int64_t, 1, 32), - {%- endif %} - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name3, lxu_cache_locations_sorted, int32_t, 1, 32), - {%- endif %} - {%- if weighted %} - MAKE_PTA_ACC_WITH_NAME(func_name3, indice_weights_sorted, cache_t, 1, 32), - {%- endif %} - {%- if not dense and optimizer != "none" %} - stochastic_rounding, - rng_engine_inputs, - {%- else %} - MAKE_PTA_WITH_NAME(func_name3, grad_dev_weights, emb_t, 1, 64), - {%- if optimizer == "none" %} - max_D, - {%- endif %} - {%- endif %} // if not dense and optimizer != "none" - {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.B_offsets, 
int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.output_offsets, int64_t, 1, 32), - {%- endif %} - {%- if not nobag %} - info_B_num_bits, - info_B_mask, - {%- endif %} - MAKE_PTA_WITH_NAME(func_name3, long_run_id_to_really_long_run_ids, int32_t, 1, 32), - MAKE_PTA_ACC_WITH_NAME(func_name3, temp_grad_accum, cache_t, 2, 32), - MAKE_PTA_WITH_NAME(func_name3, grad_accum_counter, int32_t, 1, 32), - max_segment_length_per_cta, - use_deterministic_algorithms, - {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name3") | join(",\n ") }}); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); - grid_size = std::min( - div_round_up(total_unique_indices, kBackwardMaxThreads / kThreadGroupSize), - get_max_thread_blocks_()); - - // Shared memory is not needed for non uint8_t weights - size_t shmem_bytes = 0; - if (std::is_same::value) { - shmem_bytes = BT_block_size * sizeof( - at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread; -#ifndef __HIP_PLATFORM_HCC__ - cudaFuncSetAttribute( - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< + // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0] + split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< emb_t, grad_t, cache_t, kMaxVecsPerThread, - kThreadGroupSize>, - cudaFuncAttributeMaxDynamicSharedMemorySize, - used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB + kThreadGroupSize> + <<) * 4 * kWarpSize * + kMaxVecsPerThread, + at::cuda::getCurrentCUDAStream()>>>( + grad_output_accessor, + {%- if optimizer != "none" %} + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name3, dev_weights, emb_t, 1, 64), + MAKE_PTA_WITH_NAME(func_name3, uvm_weights, emb_t, 1, 64), + MAKE_PTA_WITH_NAME(func_name3, lxu_cache_weights, cache_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name3, weights_placements, int32_t, 1, 32), + {%- else %} + MAKE_PTA_WITH_NAME(func_name3, dev_weights, emb_t, 1, 64), + {%- endif %} + {%- endif %} // if optimizer != "none" + MAKE_PTA_WITH_NAME(func_name3, weights_offsets, int64_t, 1, 32), + {%- if not nobag %} + MAKE_PTA_WITH_NAME(func_name3, D_offsets, int32_t, 1, 32), + {%- else %} + D, + {%- endif %} + MAKE_PTA_WITH_NAME(func_name3, hash_size_cumsum, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_run, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, long_run_ids, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, num_long_run_ids, int32_t, 1, 32), + {%- if not nobag %} + MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int32_t, 1, 32), + {%- else %} + MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int64_t, 1, 32), + {%- endif %} + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name3, lxu_cache_locations_sorted, int32_t, 1, 32), + {%- endif %} + {%- if weighted %} + MAKE_PTA_ACC_WITH_NAME(func_name3, indice_weights_sorted, cache_t, 1, 32), + {%- endif %} + {%- if not dense and optimizer != "none" %} + stochastic_rounding, + rng_engine_inputs, + {%- else %} + MAKE_PTA_WITH_NAME(func_name3, grad_dev_weights, emb_t, 1, 64), + {%- if optimizer == "none" %} + max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.B_offsets, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.output_offsets, int64_t, 1, 32), + {%- endif %} + {%- if not nobag 
%} + info_B_num_bits, + info_B_mask, + {%- endif %} + MAKE_PTA_WITH_NAME(func_name3, long_run_id_to_really_long_run_ids, int32_t, 1, 32), + MAKE_PTA_ACC_WITH_NAME(func_name3, temp_grad_accum, cache_t, 2, 32), + MAKE_PTA_WITH_NAME(func_name3, grad_accum_counter, int32_t, 1, 32), + max_segment_length_per_cta, + use_deterministic_algorithms, + {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name3") | join(",\n ") }}); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + grid_size = std::min( + div_round_up(total_unique_indices, kBackwardMaxThreads / kThreadGroupSize), + get_max_thread_blocks_()); + + // Shared memory is not needed for non uint8_t weights + size_t shmem_bytes = 0; + if (std::is_same::value) { + shmem_bytes = BT_block_size * sizeof( + at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread; +#ifndef __HIP_PLATFORM_HCC__ + cudaFuncSetAttribute( + split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< + emb_t, + grad_t, + cache_t, + kMaxVecsPerThread, + kThreadGroupSize>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB #endif - } + } #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name4 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1"; + const auto func_name4 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1"; #endif - C10_CUDA_KERNEL_LAUNCH_CHECK(); - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize> - <<>>( - grad_output_accessor, - {%- if optimizer != "none" %} - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name4, dev_weights, emb_t, 1, 64), - MAKE_PTA_WITH_NAME(func_name4, uvm_weights, emb_t, 1, 64), - MAKE_PTA_WITH_NAME(func_name4, lxu_cache_weights, cache_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name4, weights_placements, int32_t, 1, 32), - {%- else %} - MAKE_PTA_WITH_NAME(func_name4, dev_weights, emb_t, 1, 64), - {%- endif %} - {%- endif %} - MAKE_PTA_WITH_NAME(func_name4, weights_offsets, int64_t, 1, 32), - {%- if not nobag %} - MAKE_PTA_WITH_NAME(func_name4, D_offsets, int32_t, 1, 32), - {%- else %} - D, - {%- endif %} - MAKE_PTA_WITH_NAME(func_name4, hash_size_cumsum, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32), - - {%- if not nobag %} - MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int32_t, 1, 32), - {%- else %} - MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int64_t, 1, 32), - {%- endif %} - {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name4, lxu_cache_locations_sorted, int32_t, 1, 32), - {%- endif %} - {%- if weighted %} - MAKE_PTA_ACC_WITH_NAME(func_name4, indice_weights_sorted, cache_t, 1, 32), - {%- endif %} - MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_num_runs, int32_t, 1, 32), - max_segment_length_per_warp, - {%- if not dense and optimizer != "none" %} - stochastic_rounding, - rng_engine_inputs, - {%- else %} - MAKE_PTA_WITH_NAME(func_name4, grad_dev_weights, emb_t, 1, 64), - {%- if optimizer == "none" %} - max_D, - {%- endif %} - {%- endif %} // if not dense and optimizer != "none" - {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name4, vbe_metadata.B_offsets, int32_t, 1, 32), - 
MAKE_PTA_WITH_NAME(func_name4, vbe_metadata.output_offsets, int64_t, 1, 32), - {%- endif %} - {%- if not nobag %} - info_B_num_bits, - info_B_mask, - {%- endif %} - {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name4") | join(",\n ") }}); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - {%- endif %} - {%- endfor %} + C10_CUDA_KERNEL_LAUNCH_CHECK(); + split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< + emb_t, + grad_t, + cache_t, + kMaxVecsPerThread, + kThreadGroupSize> + <<>>( + grad_output_accessor, + {%- if optimizer != "none" %} + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name4, dev_weights, emb_t, 1, 64), + MAKE_PTA_WITH_NAME(func_name4, uvm_weights, emb_t, 1, 64), + MAKE_PTA_WITH_NAME(func_name4, lxu_cache_weights, cache_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name4, weights_placements, int32_t, 1, 32), + {%- else %} + MAKE_PTA_WITH_NAME(func_name4, dev_weights, emb_t, 1, 64), + {%- endif %} + {%- endif %} + MAKE_PTA_WITH_NAME(func_name4, weights_offsets, int64_t, 1, 32), + {%- if not nobag %} + MAKE_PTA_WITH_NAME(func_name4, D_offsets, int32_t, 1, 32), + {%- else %} + D, + {%- endif %} + MAKE_PTA_WITH_NAME(func_name4, hash_size_cumsum, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32), + {%- if not nobag %} + MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int32_t, 1, 32), + {%- else %} + MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int64_t, 1, 32), + {%- endif %} + {%- if not dense %} + MAKE_PTA_WITH_NAME(func_name4, lxu_cache_locations_sorted, int32_t, 1, 32), + {%- endif %} + {%- if weighted %} + MAKE_PTA_ACC_WITH_NAME(func_name4, indice_weights_sorted, cache_t, 1, 32), + {%- endif %} + MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_num_runs, int32_t, 1, 32), + max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + stochastic_rounding, + rng_engine_inputs, + {%- else %} + MAKE_PTA_WITH_NAME(func_name4, grad_dev_weights, emb_t, 1, 64), + {%- if optimizer == "none" %} + max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + MAKE_PTA_WITH_NAME(func_name4, vbe_metadata.B_offsets, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name4, vbe_metadata.output_offsets, int64_t, 1, 32), + {%- endif %} + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name4") | join(",\n ") }}); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return; + }); }); {%- if dense %} return grad_dev_weights; + {%- elif optimizer == "none" %} return at::sparse_coo_tensor( sorted_linear_indices_run.unsqueeze(0), grad_dev_weights.reshape({total_unique_indices, max_D}), {total_hash_size, max_D}, dev_weights.options().layout(at::kSparse)); + {%- else %} return Tensor(); {%- endif %} } -// clang-format on + // clang-format on
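
The refactoring above replaces each inline copy of the kMaxVecsPerThread / kThreadGroupSize selection with the DISPATCH_KERNEL_BODY macro, which wraps the caller-supplied launch code in an immediately invoked lambda so the selected constants are visible as constexpr values inside the body. The standalone C++ sketch below illustrates that dispatch pattern with hypothetical names and a reduced two-entry table; it is not the generated FBGEMM code.

// Minimal sketch of the DISPATCH_KERNEL_BODY pattern: the macro picks
// compile-time constants for a runtime MAX_D and invokes the body with those
// constants in scope. Names and the two-entry table are hypothetical.
#include <cstdio>

constexpr int kWarpSize = 32;

#define DISPATCH_KERNEL_BODY(MAX_D, ...)          \
  [&] {                                           \
    if ((MAX_D) <= 128) {                         \
      constexpr int kMaxVecsPerThread = 1;        \
      constexpr int kThreadGroupSize = kWarpSize; \
      return __VA_ARGS__();                       \
    }                                             \
    if ((MAX_D) <= 256) {                         \
      constexpr int kMaxVecsPerThread = 2;        \
      constexpr int kThreadGroupSize = kWarpSize; \
      return __VA_ARGS__();                       \
    }                                             \
    return; /* no instantiation covers MAX_D */   \
  }()

void launch_backward(int max_D) {
  DISPATCH_KERNEL_BODY(max_D, [&] {
    // After macro expansion this lambda sits inside the branch that declared
    // the constants, so they can be used directly (e.g. as template args).
    std::printf("kMaxVecsPerThread=%d kThreadGroupSize=%d\n",
                kMaxVecsPerThread, kThreadGroupSize);
  });
}

int main() {
  launch_backward(100);  // selects the first entry
  launch_backward(200);  // selects the second entry
  return 0;
}

Because __VA_ARGS__ is expanded textually inside each branch, the body lambda sees the branch-local constexpr values without any explicit parameter passing, which is what lets the template hard-code the experimental-optimizer case with a trivial single-branch version of the same macro.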
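
Both launch paths keep per-block shared memory under the device budget by halving BT_block_size until the request fits. A host-side sketch of that sizing rule follows; the arguments are hypothetical stand-ins for used_shared_bytes and the values chosen by DISPATCH_KERNEL_BODY.

// Halve the block's T dimension until the per-block shared-memory request
//   BT_block_size * sizeof(acc_type) * 4 * kWarpSize * kMaxVecsPerThread
// fits under the device budget (used_shared_bytes in the template).
#include <cassert>
#include <cstddef>
#include <cstdio>

constexpr int kWarpSize = 32;

int compute_BT_block_size(int BT_block_size,
                          std::size_t used_shared_bytes,
                          int kMaxVecsPerThread,
                          std::size_t acc_type_size) {
  while (static_cast<std::size_t>(BT_block_size) * acc_type_size * 4 *
             kWarpSize * kMaxVecsPerThread >=
         used_shared_bytes) {
    BT_block_size /= 2;
  }
  assert(BT_block_size >= 1);  // mirrors TORCH_CHECK(BT_block_size >= 1)
  return BT_block_size;
}

int main() {
  // Example: 64 KB budget, float accumulators, one Vec4 per thread.
  const int bt = compute_BT_block_size(/*BT_block_size=*/32,
                                       /*used_shared_bytes=*/64 << 10,
                                       /*kMaxVecsPerThread=*/1,
                                       /*acc_type_size=*/sizeof(float));
  std::printf("BT_block_size = %d\n", bt);
  return 0;
}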
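
In the cta_per_row path, very long index runs are split across thread blocks and accumulated into temp_grad_accum with atomics only when deterministic algorithms are not requested; with determinism on, max_segment_length_per_cta becomes INT_MAX and the scratch buffers are allocated empty. The condensed ATen-based outline below mirrors that branch; the tensor names come from the template, but the wrapper function and struct are hypothetical.

// Outline of the deterministic-vs-atomic scratch-buffer setup.
#include <ATen/ATen.h>
#include <climits>

struct LongRunScratch {
  at::Tensor long_run_id_to_really_long_run_ids;
  at::Tensor grad_accum_counter;
  int max_segment_length_per_cta;
};

LongRunScratch make_long_run_scratch(
    const at::Tensor& indices,
    const at::Tensor& sorted_linear_indices_run_lengths) {
  const bool use_deterministic_algorithms =
      at::globalContext().deterministicAlgorithms();
  // INT_MAX means no run is ever split across CTAs, so results do not depend
  // on the order of atomic additions.
  const int max_segment_length_per_cta =
      use_deterministic_algorithms ? INT_MAX : 1024;

  auto long_run_id_to_really_long_run_ids = use_deterministic_algorithms
      ? at::empty({0}, sorted_linear_indices_run_lengths.options())
      : at::empty({indices.numel()},
                  sorted_linear_indices_run_lengths.options());

  auto grad_accum_counter = at::empty(
      {use_deterministic_algorithms
           ? 0
           : indices.numel() / max_segment_length_per_cta},
      indices.options().dtype(at::kInt));

  return {long_run_id_to_really_long_run_ids, grad_accum_counter,
          max_segment_length_per_cta};
}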
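
On CUDA builds (the #ifndef __HIP_PLATFORM_HCC__ branches), both kernels are opted in to more than 48 KB of dynamic shared memory via cudaFuncSetAttribute before launch, as required on compute capability 7.x and newer. The standalone sketch below shows that opt-in with a trivial placeholder kernel; only the attribute call and the dynamic shared-memory launch argument mirror the patch.

// Opt a kernel in to large dynamic shared memory, then launch with a dynamic
// shared-memory size above the default 48 KB limit.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void placeholder_kernel(float* out) {
  extern __shared__ float smem[];  // dynamic shared memory
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) {
    out[blockIdx.x] = smem[0];
  }
}

int main() {
  int device = 0;
  cudaGetDevice(&device);

  // Query how much shared memory a single block may opt in to on this device.
  int max_optin_smem = 0;
  cudaDeviceGetAttribute(
      &max_optin_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  // Without this call, requesting more than 48 KB of dynamic shared memory at
  // launch time fails; the template does the same for the cta_per_row and
  // warp_per_row kernels with used_shared_bytes.
  cudaFuncSetAttribute(placeholder_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_optin_smem);

  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  placeholder_kernel<<<1, 32, max_optin_smem>>>(out);
  const cudaError_t err = cudaDeviceSynchronize();
  std::printf("launch status: %s\n", cudaGetErrorString(err));
  cudaFree(out);
  return 0;
}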