Call cudaGetDeviceProperties once in masked_index_impl (pytorch#3136)

Summary:
Pull Request resolved: pytorch#3136

X-link: facebookresearch/FBGEMM#229

This diff caches the default number of SMs for pipeline prefetching to
avoid calling `cudaGetDeviceProperties` in every `masked_index_*` call
(see the caching sketch after the commit metadata below).

Reviewed By: chrisxcai

Differential Revision: D62672190

fbshipit-source-id: a5865f602f394ba4909ab1334661df4689b8cfd5
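
For illustration, here is a minimal, self-contained sketch of the caching
pattern this diff applies: a function-local static whose initializer performs
the device query exactly once. The helper name `query_default_pipeline_sms`
and the `/ 8` heuristic are hypothetical stand-ins, not the FBGEMM code.

// Minimal sketch (not the FBGEMM source): cache an expensive CUDA
// device query in a function-local static so it runs once per process.
#include <cuda_runtime.h>

// Hypothetical helper standing in for the real per-device query.
int query_default_pipeline_sms(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device); // the call the diff stops repeating
  return prop.multiProcessorCount / 8;    // placeholder heuristic
}

int cached_default_pipeline_sms(int device) {
  // C++11 "magic statics": the initializer runs exactly once, even if
  // several threads reach the first call concurrently.
  static int cached = query_default_pipeline_sms(device);
  return cached;
}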
sryap authored and facebook-github-bot committed Sep 16, 2024
1 parent a90aac1 commit 49fa9a5
Showing 1 changed file with 20 additions and 20 deletions.
@@ -22,12 +22,27 @@
 #include "fbgemm_gpu/utils/tensor_utils.h"
 #include "fbgemm_gpu/utils/vec4.cuh"
 
-constexpr int ALL_TO_PREFETCH_SM_RATIO = 8;
-
 using Tensor = at::Tensor;
 
 using namespace fbgemm_gpu;
 
+int get_masked_index_default_pipeline_sms(int device) {
+  cudaDeviceProp prop;
+  cudaGetDeviceProperties(&prop, device);
+
+  // The default number of SMs for use_pipeline=true is set based on an
+  // empirical study
+  if (prop.major == 8) {
+    // Assume A100
+    return 4;
+  } else if (prop.major == 9) {
+    // Assume H100
+    return 16;
+  }
+  constexpr int ALL_TO_PREFETCH_SM_RATIO = 8;
+  return div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO);
+}
+
 template <typename scalar_t>
 DEVICE_INLINE void
 vec4_copy(scalar_t* dst, const scalar_t* src, const int32_t D) {
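
As an aside, the fallback branch above is a ceiling division over the
device's SM count. A tiny standalone check of that idiom (the `div_round_up`
here is a local re-implementation written for this example, assuming the
usual (a + b - 1) / b definition):

#include <cassert>

// Local re-implementation of the ceiling-division idiom for illustration.
int div_round_up(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  // e.g. a 108-SM device with ALL_TO_PREFETCH_SM_RATIO = 8:
  assert(div_round_up(108, 8) == 14); // ceil(108 / 8) = 14
  return 0;
}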
@@ -103,26 +118,11 @@ Tensor masked_index_impl(
 
   const auto full_grid_size = div_round_up(N, kMaxThreads / tx);
 
-  // The default number of SMs for use_pipeline=true is set based on an
-  // empirical study
-
-  cudaDeviceProp prop;
-  cudaGetDeviceProperties(&prop, at::cuda::current_device());
-
-  int DEFAULT_PIPELINE_SMS;
-  if (prop.major == 8) {
-    // Assume A100
-    DEFAULT_PIPELINE_SMS = 4;
-  } else if (prop.major == 9) {
-    // Assume H100
-    DEFAULT_PIPELINE_SMS = 16;
-  } else {
-    DEFAULT_PIPELINE_SMS =
-        div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO);
-  }
+  static int masked_index_default_pipeline_sms =
+      get_masked_index_default_pipeline_sms(at::cuda::current_device());
 
   const int pipeline_grid_size =
-      preferred_sms == -1 ? DEFAULT_PIPELINE_SMS : preferred_sms;
+      preferred_sms == -1 ? masked_index_default_pipeline_sms : preferred_sms;
   TORCH_CHECK(
       !use_pipeline || pipeline_grid_size >= 1, "preferred_sms must >= 1");
 
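
One behavioral note on the `static` local introduced above: its initializer
runs on the first call only, so the cached value reflects whichever device
was current at that moment and is reused process-wide afterward. A small
GPU-free sketch (hypothetical `expensive_query`) demonstrating that later
calls with a different device ID do not re-query:

#include <cstdio>

// Hypothetical stand-in for a per-device property query.
int expensive_query(int device) {
  std::printf("queried device %d\n", device);
  return device * 4;
}

int cached_value(int device) {
  // Initialized once, using the device passed by the *first* caller.
  static int v = expensive_query(device);
  return v;
}

int main() {
  std::printf("%d\n", cached_value(0)); // prints "queried device 0", then 0
  std::printf("%d\n", cached_value(1)); // no new query; still the device-0 value
  return 0;
}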
