From 79b365ba750eb2406965b75f97b5980af91569f8 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Sun, 16 Jul 2023 20:55:25 -0700 Subject: [PATCH] Re-organize the forward split code generation (#1879) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1879 - Re-organize the forward split code generation Reviewed By: sryap Differential Revision: D47420155 fbshipit-source-id: 607088818d1801e5fbb8d307c96a75922e8e7188 --- .../embedding_backward_split_template.cu | 8 +- ...embedding_forward_split_kernel_template.cu | 171 +++++++----------- .../embedding_forward_split_template.cu | 111 +++++++++--- 3 files changed, 157 insertions(+), 133 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 74e8e041a..7345be06e 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -183,7 +183,7 @@ grad_mean{{ vbe_desc }}_kernel( For the experimental optimizers, kMaxVecsPerThread and kThreadGroupSize are fixed to 8 (1024 elements) and kWarpSize, respectively. */ -#define DISPATCH_KERNEL_BODY(MAX_D, ...) \ +#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ [&] { \ constexpr auto kMaxVecsPerThread = {{ max_embedding_dim // items_per_warp }}; \ constexpr auto kThreadGroupSize = kWarpSize; \ @@ -208,7 +208,7 @@ grad_mean{{ vbe_desc }}_kernel( is unable to use std::max in constexpr context. */ #ifdef FBGEMM_USE_SUBWARP_SHUFFLE -#define DISPATCH_KERNEL_BODY(MAX_D, ...) \ +#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ [&] { \ {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} @@ -223,7 +223,7 @@ grad_mean{{ vbe_desc }}_kernel( }() #else -#define DISPATCH_KERNEL_BODY(MAX_D, ...) \ +#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ [&] { \ constexpr int kThreadGroupSize = kWarpSize; \ {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} @@ -573,7 +573,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi } {%- endif %} - DISPATCH_KERNEL_BODY(max_D, [&] { + DISPATCH_OPTIMAL_KERNEL(max_D, [&] { // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB; H100: 144 KB), BT_block_size must be a power of two. while (BT_block_size * sizeof(at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread >= used_shared_bytes) { BT_block_size /= 2; diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu index c9fb32407..4bbd424d2 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu @@ -335,15 +335,76 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } } +//////////////////////////////////////////////////////////////////////////////// +// Explicit Template Instantiations +//////////////////////////////////////////////////////////////////////////////// + /* Explicitly instantiate the kernel function template. 
The instantiations are based on the types enumerated by DISPATCH_EMB_CACHE_TYPES
    macro used in embedding_forward_split_template.cu
*/
-{%- for output_type in ['uint8_t', 'at::Half', 'float'] %}
-{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
-{%- for cache_type in ['float', 'at::Half'] %}
+{%- macro template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
+template __launch_bounds__(kForwardMaxThreads) __global__
+void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
+<
+    {{ emb_type }},
+    {{ cache_type }},
+    {{ output_type }},
+    {%- if not dense %}
+    {{ use_cache }},
+    {%- endif %}
+    int64_t,
+    {%- if not nobag %}
+    {{- kMaxVecsPerThread }},
+    {%- endif %}
+    {{ kThreadGroupSize }}
+> (
+    const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
+    {%- if not dense %}
+    const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
+    const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
+    {%- endif %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
+    {%- if not nobag %}
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
+    {%- else %}
+    int64_t D,
+    {%- endif %}
+    {%- if vbe %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
+    const pta::PackedTensorAccessor32<uint32_t, 1, at::RestrictPtrTraits> b_t_map,
+    const int32_t info_B_num_bits,
+    const uint32_t info_B_mask,
+    {%- else %}
+    FixedDivisor fd_B,
+    {%- endif %}
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    {%- if not nobag %}
+    int64_t pooling_mode,
+    {%- endif %}
+    {%- if weighted %}
+    pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> indice_weights,
+    {%- endif %}
+    {%- if not dense %}
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
+    {%- endif %}
+    pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output);
+{%- endmacro %}
+
+{%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
+    {%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
+    {%- for cache_type in ['float', 'at::Half'] %}
+    {%- for output_type in ['uint8_t', 'at::Half', 'float'] %}
+        {{ template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) }}
+    {%- endfor %}
+    {%- endfor %}
+    {%- endfor %}
+{%- endmacro %}
+
 ////////////////////////////////////////////////////////////////////////////////
 #ifdef FBGEMM_USE_SUBWARP_SHUFFLE
@@ -419,55 +480,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
     (NULL,·NULL,·(kWarpSize·/·1)) */ #}
 {%- for (use_cache, kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}
-
-template __launch_bounds__(kForwardMaxThreads) __global__
-void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
-<
-    {{ emb_type }},
-    {{ cache_type }},
-    {{ output_type }},
-    {%- if not dense %}
-    {{ use_cache }},
-    {%- endif %}
-    int64_t,
-    {%- if not nobag %}
-    {{- kMaxVecsPerThread }},
-    {%- endif %}
-    {{ kThreadGroupSize }}
-> (
-    const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
-    {%- if not dense %}
-    const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
-    const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
-    const 
pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - {%- if vbe %} - const pta::PackedTensorAccessor32 output_offsets, - const pta::PackedTensorAccessor32 b_t_map, - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- else %} - FixedDivisor fd_B, - {%- endif %} - const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 offsets, - {%- if not nobag %} - int64_t pooling_mode, - {%- endif %} - {%- if weighted %} - pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, - {%- endif %} - pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output); - + {{ bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// @@ -528,61 +541,9 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } (NULL,·NULL,·kWarpSize) */ #} {%- for (use_cache, kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} - -template __launch_bounds__(kForwardMaxThreads) __global__ -void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel -< - {{ emb_type }}, - {{ cache_type }}, - {{ output_type }}, - {%- if not dense %} - {{ use_cache }}, - {%- endif %} - int64_t, - {%- if not nobag %} - {{- kMaxVecsPerThread }}, - {%- endif %} - {{ kThreadGroupSize }} -> ( - const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, - {%- if not dense %} - const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, - const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - {%- if vbe %} - const pta::PackedTensorAccessor32 output_offsets, - const pta::PackedTensorAccessor32 b_t_map, - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- else %} - FixedDivisor fd_B, - {%- endif %} - const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 offsets, - {%- if not nobag %} - int64_t pooling_mode, - {%- endif %} - {%- if weighted %} - pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, - {%- endif %} - pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output); - + {{ bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// #endif //////////////////////////////////////////////////////////////////////////////// - -{%- endfor %} -{%- endfor %} -{%- endfor %} diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index 600c0280e..f32346b9c 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -22,6 +22,10 @@ using Tensor = at::Tensor; using 
namespace fbgemm_gpu;
 
+////////////////////////////////////////////////////////////////////////////////
+// External Function Declarations
+////////////////////////////////////////////////////////////////////////////////
+
 {%- if not weighted %}
 template <
     typename emb_t,
@@ -83,9 +87,10 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
 #endif
 {% endif %} // if not dense
+
 {%- for nobag in [True, False] %}
 {%- if not nobag or (not weighted and not vbe) %}
-{%- set has_experimental = (not dense and not nobag and not vbe) %}
+
 template <
     typename emb_t,
     typename cache_t,
@@ -135,6 +140,75 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
     pta::PackedTensorAccessor64<output_t, 2, at::RestrictPtrTraits> output // [B][total_D]
     );
+{%- endif %}
+{%- endfor %}
+
+
+////////////////////////////////////////////////////////////////////////////////
+// Utility Macros
+////////////////////////////////////////////////////////////////////////////////
+
+/*
+    The macro definitions for both cases are almost the same except for the
+    definition of kThreadGroupSize. In the FBGEMM_USE_SUBWARP_SHUFFLE case, if
+    MAX_D is small, then we use fewer threads than kWarpSize.
+
+    NOTE: kMaxVecsPerThread is computed using the ternary operator because HIPCC
+    is unable to use std::max in constexpr context.
+*/
+#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
+#define DISPATCH_OPTIMAL_FORWARD_KERNEL(MAX_D, ...) \
+  [&] { \
+    {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
+    {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
+    if (MAX_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { \
+      constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; \
+      constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); \
+      return __VA_ARGS__(); \
+    } \
+    {%- endif %}
+    {%- endfor %}
+    return; \
+  }()
+
+#else
+#define DISPATCH_OPTIMAL_FORWARD_KERNEL(MAX_D, ...) \
+  [&] { \
+    constexpr int kThreadGroupSize = kWarpSize; \
+    {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
+    {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
+    if (MAX_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { \
+      constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; \
+      return __VA_ARGS__(); \
+    } \
+    {%- endif %}
+    {%- endfor %}
+    return; \
+  }()
+
+#endif
+
+
+#define DISPATCH_OPTIMAL_NOBAG_FORWARD_KERNEL(DD_, ...) 
\ + [&] { \ + {%- for kEmbeddingSize in [4, 8, 16, 32] %} + if (DD_ <= {{ kEmbeddingSize }}) { \ + constexpr int kEmbeddingSize = {{ kEmbeddingSize }}; \ + return __VA_ARGS__(); \ + } \ + {%- endfor %} + return; \ + }() + + +//////////////////////////////////////////////////////////////////////////////// +// Kernel Definitions +//////////////////////////////////////////////////////////////////////////////// + +{%- for nobag in [True, False] %} +{%- if not nobag or (not weighted and not vbe) %} +{%- set has_experimental = (not dense and not nobag and not vbe) %} + Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_cuda( Tensor dev_weights, {%- if not dense %} @@ -294,21 +368,8 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" {%- if not dense %} if (use_lxu_cache == {{ use_cache }}) { {%- endif %} - // kMaxElemPerThread is # of elements handled by thread if we use a full warp for a row - // We consider kMaxElemPerThread 1 and 2, and then a multiple of 4. - {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} - {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} - if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { - // hipcc can't use max in constexpr - constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; - // If max_D is small, use fewer number of threads than kWarpSize. - -#ifdef FBGEMM_USE_SUBWARP_SHUFFLE - constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); -#else - constexpr int kThreadGroupSize = kWarpSize; -#endif + DISPATCH_OPTIMAL_FORWARD_KERNEL(max_D, [&] { #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "{{ "dense" if dense else "split" }}_embedding_codegen_forward_{{ wdesc }}_kernel"; #endif @@ -354,27 +415,28 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" output = output.reshape({-1}); {%- endif %} return; - } - {%- endif %} - {%- endfor %} + }); + {%- if not dense %} } // if (use_lxu_cache == {{ use_cache }}) {%- endif %} {%- endif %} // if (not dense) or (use_cache == "true" and not vbe) {%- endfor %} // for use_cache in ["false", "true"] + + {%- else %} - {%- for kEmbeddingSize in [4, 8, 16, 32] %} - if (D <= {{ kEmbeddingSize }}) { + + DISPATCH_OPTIMAL_NOBAG_FORWARD_KERNEL(D, [&] { {%- if not dense %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "split_embedding_nobag_codegen_forward_unweighted_small_kernel"; #endif - split_embedding_nobag_codegen_forward_unweighted_small_kernel<<< + split_embedding_nobag_codegen_forward_unweighted_small_kernel<<< {%- else %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "dense_embedding_nobag_codegen_forward_unweighted_small_kernel"; #endif - dense_embedding_nobag_codegen_forward_unweighted_small_kernel<<< + dense_embedding_nobag_codegen_forward_unweighted_small_kernel<<< {%- endif %} div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -398,8 +460,9 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" ); C10_CUDA_KERNEL_LAUNCH_CHECK(); return; - } - {%- endfor %} + }); + + {%- for use_cache in ["false", "true"] %} // The dense case does not have cache so we have to generate code for // only one case (value of use_cache/vbe does not matter)