
Reduce prefetch SM usage when using pipeline prefetching (pytorch#2991)
Summary:
X-link: facebookresearch/FBGEMM#92

Pull Request resolved: pytorch#2991

This diff limits the SM usage of `masked_index_put` and
`masked_index_select` when pipeline prefetching is used, by launching
them with a small grid size.  This reduces the interference of these
kernels with the kernels on the compute stream during prefetch.  The
grid size currently defaults to roughly 1/8 of the total number of SMs
(with architecture-specific defaults for A100 and H100, and an optional
`preferred_sms` override).  However, this default still has to be tuned.

Reviewed By: jianyuh

Differential Revision: D61145930

fbshipit-source-id: 8f8096ba208b41dd537a582b9c539542cbc3ec82
sryap authored and facebook-github-bot committed Aug 16, 2024
1 parent e73b1f8 commit 7a2ec83
Showing 4 changed files with 101 additions and 32 deletions.
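For intuition, the grid-size cap amounts to reserving only a fraction of the device's SMs for these copy kernels when they overlap with compute. Below is a minimal Python sketch of that heuristic; it is illustrative only — the helper name and the flat 1/8 ratio are assumptions, and the actual CUDA code in this diff also hard-codes defaults for A100/H100 and honors an explicit `preferred_sms`.

import torch

def prefetch_copy_grid_size(full_grid_size: int, use_pipeline: bool, ratio: int = 8) -> int:
    # Without pipelining, use the full grid so the copy finishes as quickly as possible.
    if not use_pipeline:
        return full_grid_size
    # With pipelining, cap the grid at (total SMs / ratio) so the masked_index
    # kernels leave most SMs free for kernels on the compute stream.
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    return max(1, min(sm_count // ratio, full_grid_size))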
2 changes: 2 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1053,6 +1053,7 @@ def prefetch( # noqa C901
sp_prev_curr_map_gpu,
inserted_rows_prev,
actions_count_gpu,
+ use_pipeline=self.prefetch_pipeline,
)

# Record the tensors that will be pushed into a queue
@@ -1094,6 +1095,7 @@ def prefetch( # noqa C901
assigned_cache_slots,
inserted_rows,
actions_count_gpu,
+ use_pipeline=self.prefetch_pipeline,
)

if linear_cache_indices.numel() > 0:
@@ -22,6 +22,8 @@
#include "fbgemm_gpu/utils/tensor_accessor.h"
#include "fbgemm_gpu/utils/vec4.cuh"

+ constexpr int ALL_TO_PREFETCH_SM_RATIO = 8;

using Tensor = at::Tensor;

using namespace fbgemm_gpu;
@@ -59,31 +61,30 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
count) {
const int32_t N = indices.size(0);
- const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
- if (n >= N) {
- return;
- }
const auto count_ = count[0];
- if (n >= count_) {
- return;
- }
- // idx == -1 if it is conflict miss
- const auto idx = indices[n];
- if (idx < 0) {
- return;
+ CUDA_KERNEL_ASSERT(count_ <= N);
+ for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < count_;
+ n += blockDim.y * gridDim.x) {
+ // idx == -1 if it is conflict miss
+ const auto idx = indices[n];
+ if (idx < 0) {
+ continue;
+ }
+ const auto D = self.size(1);
+ const auto self_idx = is_index_put ? idx : n;
+ const auto values_idx = is_index_put ? n : idx;
+ vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
}
- const auto D = self.size(1);
- const auto self_idx = is_index_put ? idx : n;
- const auto values_idx = is_index_put ? n : idx;
- vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
}

template <bool is_index_put>
Tensor masked_index_impl(
const Tensor& self,
const Tensor& indices,
const Tensor& values,
- const Tensor& count) {
+ const Tensor& count,
+ const bool use_pipeline,
+ const int preferred_sms) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(self, indices, values, count);
TENSOR_CONTIGUOUS(self);
TENSOR_CONTIGUOUS(indices);
@@ -98,12 +99,43 @@ Tensor masked_index_impl(
const auto D = self.size(1);
TORCH_CHECK_EQ(self.size(1), values.size(1));

+ const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
+ const dim3 threads(tx, kMaxThreads / tx);

+ const auto full_grid_size = div_round_up(N, kMaxThreads / tx);

+ // The default number of SMs for use_pipeline=true is set based on an
+ // empirical study

+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, at::cuda::current_device());

+ int DEFAULT_PIPELINE_SMS;
+ if (prop.major == 8) {
+ // Assume A100
+ DEFAULT_PIPELINE_SMS = 4;
+ } else if (prop.major == 9) {
+ // Assume H100
+ DEFAULT_PIPELINE_SMS = 16;
+ } else {
+ DEFAULT_PIPELINE_SMS =
+ div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO);
+ }

+ const int pipeline_grid_size =
+ preferred_sms == -1 ? DEFAULT_PIPELINE_SMS : preferred_sms;
+ TORCH_CHECK(
+ !use_pipeline || pipeline_grid_size >= 1, "preferred_sms must >= 1");

+ // Use a fraction of SMs if use_pipeline=true
+ const auto grid_size = use_pipeline
+ ? std::min(pipeline_grid_size, full_grid_size)
+ : full_grid_size;

FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
self.scalar_type(),
is_index_put ? "masked_index_put" : "masked_index_select",
[&] {
- const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
- const dim3 threads(tx, kMaxThreads / tx);
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = is_index_put ? "masked_index_put_kernel"
: "masked_index_select_kernel";
@@ -112,7 +144,7 @@ Tensor masked_index_impl(
TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16")
}
masked_index_kernel<scalar_t, is_index_put>
- <<<div_round_up(N, kMaxThreads / tx),
+ <<<grid_size,
dim3(tx, kMaxThreads / tx),
0,
at::cuda::getCurrentCUDAStream()>>>(
@@ -131,17 +163,22 @@ Tensor masked_index_put_cuda(
Tensor self,
Tensor indices,
Tensor values,
- Tensor count) {
- return masked_index_impl</*is_index_put=*/true>(self, indices, values, count);
+ Tensor count,
+ const bool use_pipeline,
+ const int64_t preferred_sms) {
+ return masked_index_impl</*is_index_put=*/true>(
+ self, indices, values, count, use_pipeline, preferred_sms);
}
Tensor masked_index_select_cuda(
Tensor self,
Tensor indices,
Tensor values,
- Tensor count) {
+ Tensor count,
+ const bool use_pipeline,
+ const int64_t preferred_sms) {
return masked_index_impl</*is_index_put=*/false>(
- self, indices, values, count);
+ self, indices, values, count, use_pipeline, preferred_sms);
}
__global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
@@ -50,11 +50,22 @@ ssd_cache_populate_actions_cuda(
/// @param indices The 1D index tensor
/// @param values The 2D input tensor
/// @param count The tensor that contains the length of `indices` to
/// process
+ /// @param use_pipeline A flag that indicates that this kernel will
+ /// overlap with other kernels. If it is true, then use a
+ /// fraction of SMs to reduce resource competition
+ /// @param preferred_sms The number of preferred SMs for the kernel to
+ /// use when use_pipeline=true. This value is ignored when
+ /// use_pipeline=false.
///
/// @return The `self` tensor
- Tensor
- masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
+ Tensor masked_index_put_cuda(
+ Tensor self,
+ Tensor indices,
+ Tensor values,
+ Tensor count,
+ const bool use_pipeline,
+ const int64_t preferred_sms);
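These arguments surface on the Python op registered further down as `masked_index_put(Tensor self, Tensor indices, Tensor values, Tensor count, bool use_pipeline=False, int preferred_sms=-1)`, so existing callers keep the old behavior. A hedged usage sketch follows — the shapes, dtypes, and values are made up for illustration and may need adjusting to what the op actually expects.

import torch
import fbgemm_gpu  # registers the torch.ops.fbgemm operators

# Hypothetical cache of 1024 rows x 128 columns, with 256 staged rows to insert.
cache = torch.zeros(1024, 128, device="cuda")
values = torch.randn(256, 128, device="cuda")
indices = torch.arange(256, device="cuda")                     # an index of -1 would mark a conflict miss
count = torch.tensor([256], dtype=torch.int32, device="cuda")  # number of valid entries in `indices`

# use_pipeline=True launches the copy on a reduced grid so it can overlap with
# kernels on the compute stream; preferred_sms=-1 keeps the built-in default.
cache = torch.ops.fbgemm.masked_index_put(
    cache, indices, values, count, use_pipeline=True, preferred_sms=-1
)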

/// @ingroup embedding-ssd
///
@@ -76,14 +87,22 @@ masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
/// @param indices The 1D index tensor
/// @param values The 2D input tensor (the tensor that is indexed)
/// @param count The tensor that contains the length of `indices` to
/// process
+ /// @param use_pipeline A flag that indicates that this kernel will
+ /// overlap with other kernels. If it is true, then use a
+ /// fraction of SMs to reduce resource competition
+ /// @param preferred_sms The number of preferred SMs for the kernel to
+ /// use when use_pipeline=true. This value is ignored when
+ /// use_pipeline=false.
///
/// @return The `self` tensor
Tensor masked_index_select_cuda(
Tensor self,
Tensor indices,
Tensor values,
- Tensor count);
+ Tensor count,
+ const bool use_pipeline,
+ const int64_t preferred_sms);
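For reference, the gather semantics documented above can be phrased in a few lines of eager PyTorch. This is a simplified sketch in the spirit of the reference path used by the unit tests below; it ignores the vectorized copy and the pipeline/grid-size machinery, and the function name is illustrative.

import torch

def masked_index_select_reference(
    out: torch.Tensor,
    indices: torch.Tensor,
    values: torch.Tensor,
    count: torch.Tensor,
) -> torch.Tensor:
    n = int(count.item())           # only the first `count` indices are processed
    idx = indices[:n]
    valid = idx >= 0                # an index of -1 marks a conflict miss and is skipped
    rows = valid.nonzero().flatten()
    out[rows] = values[idx[valid]]  # out[n] = values[indices[n]]
    return out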

Tensor masked_index_put_byte_cuda(
Tensor self,
@@ -330,15 +349,19 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor self, "
" Tensor indices, "
" Tensor values, "
- " Tensor count"
+ " Tensor count, "
+ " bool use_pipeline=False, "
+ " int preferred_sms=-1"
") -> Tensor");
DISPATCH_TO_CUDA("masked_index_put", masked_index_put_cuda);
m.def(
"masked_index_select("
" Tensor self, "
" Tensor indices, "
" Tensor values, "
- " Tensor count"
+ " Tensor count, "
+ " bool use_pipeline=False, "
+ " int preferred_sms=-1"
") -> Tensor");
DISPATCH_TO_CUDA("masked_index_select", masked_index_select_cuda);
m.def(
11 changes: 9 additions & 2 deletions fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py
@@ -41,9 +41,10 @@ def execute_masked_index_test(
num_output_rows: int,
dtype: torch.dtype,
test_fn: Callable[
- [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
+ [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool], torch.Tensor
],
is_index_put: bool,
+ use_pipeline: bool,
) -> None:
"""
A helper function that generates inputs/outputs, runs
@@ -83,7 +84,7 @@ def execute_masked_index_test(
output_ref = torch.zeros(num_output_rows, D, dtype=dtype, device=device)

# Run test
- output = test_fn(output, indices, values, count)
+ output = test_fn(output, indices, values, count, use_pipeline)

# Run reference
indices = indices[:count_val]
@@ -104,6 +105,7 @@
D=st.integers(min_value=2, max_value=256),
num_output_rows=st.integers(min_value=10, max_value=100),
dtype=st.sampled_from([torch.float, torch.half]),
+ use_pipeline=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_masked_index_put(
@@ -112,6 +114,7 @@ def test_masked_index_put(
D: int,
num_output_rows: int,
dtype: torch.dtype,
+ use_pipeline: bool,
) -> None:
"""
Test correctness of torch.ops.fbgemm.masked_index_put against PyTorch's
@@ -126,6 +129,7 @@ def test_masked_index_put(
dtype=dtype,
test_fn=torch.ops.fbgemm.masked_index_put,
is_index_put=True,
+ use_pipeline=use_pipeline,
)

# pyre-ignore [56]
@@ -134,6 +138,7 @@ def test_masked_index_put(
D=st.integers(min_value=2, max_value=256),
num_value_rows=st.integers(min_value=10, max_value=100),
dtype=st.sampled_from([torch.float, torch.half]),
+ use_pipeline=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_masked_index_select(
@@ -142,6 +147,7 @@ def test_masked_index_select(
D: int,
num_value_rows: int,
dtype: torch.dtype,
+ use_pipeline: bool,
) -> None:
"""
Test correctness of torch.ops.fbgemm.masked_index_select against
@@ -156,6 +162,7 @@ def test_masked_index_select(
dtype=dtype,
test_fn=torch.ops.fbgemm.masked_index_select,
is_index_put=False,
+ use_pipeline=use_pipeline,
)

def expand_tensor(
