diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 3f5486828..7df4a658a 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1053,6 +1053,7 @@ def prefetch(  # noqa C901
                     sp_prev_curr_map_gpu,
                     inserted_rows_prev,
                     actions_count_gpu,
+                    use_pipeline=self.prefetch_pipeline,
                 )
 
             # Record the tensors that will be pushed into a queue
@@ -1094,6 +1095,7 @@ def prefetch(  # noqa C901
             assigned_cache_slots,
             inserted_rows,
             actions_count_gpu,
+            use_pipeline=self.prefetch_pipeline,
         )
 
         if linear_cache_indices.numel() > 0:
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
index 13002be70..67623720d 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
@@ -22,6 +22,8 @@
 #include "fbgemm_gpu/utils/tensor_accessor.h"
 #include "fbgemm_gpu/utils/vec4.cuh"
 
+constexpr int ALL_TO_PREFETCH_SM_RATIO = 8;
+
 using Tensor = at::Tensor;
 
 using namespace fbgemm_gpu;
@@ -59,23 +61,20 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         count) {
   const int32_t N = indices.size(0);
-  const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
-  if (n >= N) {
-    return;
-  }
   const auto count_ = count[0];
-  if (n >= count_) {
-    return;
-  }
-  // idx == -1 if it is conflict miss
-  const auto idx = indices[n];
-  if (idx < 0) {
-    return;
+  CUDA_KERNEL_ASSERT(count_ <= N);
+  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < count_;
+       n += blockDim.y * gridDim.x) {
+    // idx == -1 if it is conflict miss
+    const auto idx = indices[n];
+    if (idx < 0) {
+      continue;
+    }
+    const auto D = self.size(1);
+    const auto self_idx = is_index_put ? idx : n;
+    const auto values_idx = is_index_put ? n : idx;
+    vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
   }
-  const auto D = self.size(1);
-  const auto self_idx = is_index_put ? idx : n;
-  const auto values_idx = is_index_put ? n : idx;
-  vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
 }
 
 template <bool is_index_put>
@@ -83,7 +82,9 @@ Tensor masked_index_impl(
     const Tensor& self,
     const Tensor& indices,
     const Tensor& values,
-    const Tensor& count) {
+    const Tensor& count,
+    const bool use_pipeline,
+    const int preferred_sms) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(self, indices, values, count);
   TENSOR_CONTIGUOUS(self);
   TENSOR_CONTIGUOUS(indices);
@@ -98,12 +99,43 @@ Tensor masked_index_impl(
   const auto D = self.size(1);
   TORCH_CHECK_EQ(self.size(1), values.size(1));
 
+  const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
+  const dim3 threads(tx, kMaxThreads / tx);
+
+  const auto full_grid_size = div_round_up(N, kMaxThreads / tx);
+
+  // The default number of SMs for use_pipeline=true is set based on an
+  // empirical study
+
+  cudaDeviceProp prop;
+  cudaGetDeviceProperties(&prop, at::cuda::current_device());
+
+  int DEFAULT_PIPELINE_SMS;
+  if (prop.major == 8) {
+    // Assume A100
+    DEFAULT_PIPELINE_SMS = 4;
+  } else if (prop.major == 9) {
+    // Assume H100
+    DEFAULT_PIPELINE_SMS = 16;
+  } else {
+    DEFAULT_PIPELINE_SMS =
+        div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO);
+  }
+
+  const int pipeline_grid_size =
+      preferred_sms == -1 ? DEFAULT_PIPELINE_SMS : preferred_sms;
+  TORCH_CHECK(
+      !use_pipeline || pipeline_grid_size >= 1, "preferred_sms must >= 1");
+
+  // Use a fraction of SMs if use_pipeline=true
+  const auto grid_size = use_pipeline
+      ? std::min(pipeline_grid_size, full_grid_size)
+      : full_grid_size;
+
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       self.scalar_type(),
       is_index_put ? "masked_index_put" : "masked_index_select",
       [&] {
-        const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
-        const dim3 threads(tx, kMaxThreads / tx);
 #ifdef FBGEMM_GPU_MEMCHECK
         const auto func_name = is_index_put ? "masked_index_put_kernel"
                                             : "masked_index_select_kernel";
@@ -112,7 +144,7 @@ Tensor masked_index_impl(
           TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16")
         }
         masked_index_kernel<scalar_t, is_index_put>
-            <<<div_round_up(N, kMaxThreads / tx),
+            <<<grid_size,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
@@ -131,17 +163,22 @@ Tensor masked_index_put_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count) {
-  return masked_index_impl<true>(self, indices, values, count);
+    Tensor count,
+    const bool use_pipeline,
+    const int64_t preferred_sms) {
+  return masked_index_impl<true>(
+      self, indices, values, count, use_pipeline, preferred_sms);
 }
 
 Tensor masked_index_select_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count) {
+    Tensor count,
+    const bool use_pipeline,
+    const int64_t preferred_sms) {
   return masked_index_impl<false>(
-      self, indices, values, count);
+      self, indices, values, count, use_pipeline, preferred_sms);
 }
 
 __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index 1cc67815a..14b467941 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -50,11 +50,22 @@ ssd_cache_populate_actions_cuda(
 /// @param indices The 1D index tensor
 /// @param values The 2D input tensor
 /// @param count The tensor that contains the length of `indices` to
-/// process
+///        process
+/// @param use_pipeline A flag that indicates that this kernel will
+///        overlap with other kernels. If it is true, then use a
+///        fraction of SMs to reduce resource competition
+/// @param preferred_sms The number of preferred SMs for the kernel to
+///        use when use_pipeline=true. This value is ignored when
+///        use_pipeline=false.
 ///
 /// @return The `self` tensor
-Tensor
-masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
+Tensor masked_index_put_cuda(
+    Tensor self,
+    Tensor indices,
+    Tensor values,
+    Tensor count,
+    const bool use_pipeline,
+    const int64_t preferred_sms);
 
 /// @ingroup embedding-ssd
 ///
@@ -76,14 +87,22 @@ masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
 /// @param indices The 1D index tensor
 /// @param values The 2D input tensor (the tensor that is indexed)
 /// @param count The tensor that contains the length of `indices` to
-/// process
+///        process
+/// @param use_pipeline A flag that indicates that this kernel will
+///        overlap with other kernels. If it is true, then use a
+///        fraction of SMs to reduce resource competition
+/// @param preferred_sms The number of preferred SMs for the kernel to
+///        use when use_pipeline=true. This value is ignored when
+///        use_pipeline=false.
 ///
 /// @return The `self` tensor
 Tensor masked_index_select_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count);
+    Tensor count,
+    const bool use_pipeline,
+    const int64_t preferred_sms);
 
 Tensor masked_index_put_byte_cuda(
     Tensor self,
@@ -330,7 +349,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor self, "
      " Tensor indices, "
      " Tensor values, "
-     " Tensor count"
+     " Tensor count, "
+     " bool use_pipeline=False, "
+     " int preferred_sms=-1"
      ") -> Tensor");
   DISPATCH_TO_CUDA("masked_index_put", masked_index_put_cuda);
   m.def(
@@ -338,7 +359,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      " Tensor self, "
      " Tensor indices, "
      " Tensor values, "
-     " Tensor count"
+     " Tensor count, "
+     " bool use_pipeline=False, "
+     " int preferred_sms=-1"
      ") -> Tensor");
   DISPATCH_TO_CUDA("masked_index_select", masked_index_select_cuda);
   m.def(
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py
index ef434dd8e..cbd2ccd98 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py
@@ -41,9 +41,10 @@ def execute_masked_index_test(
         num_output_rows: int,
         dtype: torch.dtype,
         test_fn: Callable[
-            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
+            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool], torch.Tensor
         ],
         is_index_put: bool,
+        use_pipeline: bool,
     ) -> None:
         """
        A helper function that generates inputs/outputs, runs
@@ -83,7 +84,7 @@ def execute_masked_index_test(
         output_ref = torch.zeros(num_output_rows, D, dtype=dtype, device=device)
 
         # Run test
-        output = test_fn(output, indices, values, count)
+        output = test_fn(output, indices, values, count, use_pipeline)
 
         # Run reference
         indices = indices[:count_val]
@@ -104,6 +105,7 @@ def execute_masked_index_test(
         D=st.integers(min_value=2, max_value=256),
         num_output_rows=st.integers(min_value=10, max_value=100),
         dtype=st.sampled_from([torch.float, torch.half]),
+        use_pipeline=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_masked_index_put(
@@ -112,6 +114,7 @@ def test_masked_index_put(
         D: int,
         num_output_rows: int,
         dtype: torch.dtype,
+        use_pipeline: bool,
     ) -> None:
         """
        Test correctness of torch.ops.fbgemm.masked_index_put against PyTorch's
@@ -126,6 +129,7 @@ def test_masked_index_put(
             dtype=dtype,
             test_fn=torch.ops.fbgemm.masked_index_put,
             is_index_put=True,
+            use_pipeline=use_pipeline,
         )
 
     # pyre-ignore [56]
@@ -134,6 +138,7 @@ def test_masked_index_put(
         D=st.integers(min_value=2, max_value=256),
         num_value_rows=st.integers(min_value=10, max_value=100),
         dtype=st.sampled_from([torch.float, torch.half]),
+        use_pipeline=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_masked_index_select(
@@ -142,6 +147,7 @@ def test_masked_index_select(
         D: int,
         num_value_rows: int,
         dtype: torch.dtype,
+        use_pipeline: bool,
     ) -> None:
         """
        Test correctness of torch.ops.fbgemm.masked_index_select aginst
@@ -156,6 +162,7 @@ def test_masked_index_select(
             dtype=dtype,
             test_fn=torch.ops.fbgemm.masked_index_select,
             is_index_put=False,
+            use_pipeline=use_pipeline,
         )
 
     def expand_tensor(
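
Usage sketch (illustrative, not part of the patch): with the schema defaults above (use_pipeline=False, preferred_sms=-1) existing callers keep the old full-grid launch, and training.py opts in by passing use_pipeline=self.prefetch_pipeline. The snippet below shows how the op could be exercised directly. It assumes a CUDA build of fbgemm_gpu that includes this change; the tensor shapes, the index/count dtypes, and the preferred_sms value are illustrative assumptions, not requirements taken from the diff.

import torch

# Assumption: fbgemm_gpu is installed with CUDA support and registers the
# masked_index_put schema shown above.
device = torch.device("cuda")
D = 64  # embedding dimension; a multiple of 16 keeps the vec4 copy path happy

cache = torch.zeros(100, D, dtype=torch.half, device=device)
values = torch.randn(8, D, dtype=torch.half, device=device)
# An index of -1 marks a conflict miss; the kernel skips those rows.
# The int64/int32 dtypes below are assumptions; see ssd_utils_test.py for the
# exact conventions used by the tests.
indices = torch.tensor([3, 7, -1, 42, 11, 0, 5, 9], dtype=torch.long, device=device)
count = torch.tensor([8], dtype=torch.int, device=device)  # only indices[:count] are processed

# Default behavior: full grid, identical to the op before this change.
torch.ops.fbgemm.masked_index_put(cache, indices, values, count)

# When the copy is meant to overlap with other kernels (prefetch pipelining),
# restrict it to a handful of SMs instead of the whole device. preferred_sms=8
# is an arbitrary example value; -1 falls back to the per-architecture default.
torch.ops.fbgemm.masked_index_put(
    cache, indices, values, count, use_pipeline=True, preferred_sms=8
)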