Add kernel checks (pytorch#2718)
Summary:
Pull Request resolved: pytorch#2718

Add kernel checks

Validate inputs to quantize_mx_cuda and dequantize_mx_cuda: the group size must be positive and a multiple of 32, the number of elements must be larger than and a multiple of the group size, flush_fp32_subnorms is rejected as not yet supported, the computed launch grid must fit within the device's MaxGridDimX, and the shared-memory size must be non-negative. Empty inputs return an empty tensor early.

Reviewed By: sryap

Differential Revision: D58376771

fbshipit-source-id: 49d50071fa2ce4993e1102fcef2bbcc4124a7a63
spcyppt authored and facebook-github-bot committed Jun 11, 2024
1 parent 032cc02 commit 46d6300
Showing 1 changed file with 75 additions and 4 deletions.
79 changes: 75 additions & 4 deletions fbgemm_gpu/src/quantize_ops/quantize_mx.cu
@@ -32,6 +32,7 @@ int32_t compute_num_groups_and_dynamic_smem_bytes(
const int device,
const func_t kernel_func_name) {
int32_t smem_bytes = 0;

// V100: 96 KB; A100: 160 KB; H100: 228 KB.
int max_shared_bytes = 0;
#ifndef USE_ROCM
@@ -93,9 +94,31 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(
const int64_t mx_group_size,
const bool flush_fp32_subnorms = false,
const int64_t rounding_mode = 0) {
-TORCH_CHECK((mx_group_size % 32 == 0), "Group size needs to be power of 2");
TORCH_CHECK(mx_group_size > 0, "Group size needs to be > 0");
TORCH_CHECK(
mx_group_size % 32 == 0,
"Group size needs to be multiply of 32 but is found to be ",
mx_group_size);
TORCH_CHECK(!flush_fp32_subnorms, "flush_fp32_subnorms is not yet supported");
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input);

const uint32_t total_elems = input.numel();
if (total_elems == 0) {
return at::empty(0, input.options().dtype(at::kByte));
}
TORCH_CHECK(
total_elems > mx_group_size,
"Input needs to be > mx_group_size of ",
mx_group_size,
" but is found to be ",
total_elems);
TORCH_CHECK(
total_elems % mx_group_size == 0,
"Input needs to be multiply of ",
mx_group_size,
"but is found to be ",
total_elems);

// Currently we only support MX4 E2M1; for other MX types, we will dispatch
// different kernels.
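// Note: in MX4 E2M1, each quantized element is 4 bits (1 sign, 2 exponent,
// 1 mantissa), and every group of mx_group_size elements shares one 8-bit
// exponent.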
TORCH_CHECK(
@@ -105,26 +128,42 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(

at::Device device = input.device();
const at::cuda::CUDAGuard device_guard{device};
-const uint32_t total_elems = input.numel();
const uint32_t total_num_groups = input.numel() / mx_group_size;

RoundingMode rd = static_cast<RoundingMode>(rounding_mode);

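// One thread per element; each block processes num_groups_per_block groups
// (see blockDim below).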
uint32_t num_groups_per_block = MAX_THREADS / mx_group_size;
const auto kernel_func = quantize_float_to_mx4_kernel<float>;

int device_id = input.get_device();

// Use shmem to find max exponent (int) and temporarily store output (uint8)
const int32_t smem_size = compute_num_groups_and_dynamic_smem_bytes(
-&num_groups_per_block, mx_group_size, input.get_device(), kernel_func);
&num_groups_per_block, mx_group_size, device_id, kernel_func);

-const auto gridDim_x = div_round_up(total_num_groups, num_groups_per_block);
const auto gridDim_x =
max(1, div_round_up(total_num_groups, num_groups_per_block));

const dim3 gridDim(gridDim_x);
const dim3 blockDim(mx_group_size, num_groups_per_block);

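// Packed output: total_elems / 2 bytes of 4-bit values plus one
// shared-exponent byte per group.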
auto output = at::empty(
(total_elems / 2) + total_num_groups, input.options().dtype(at::kByte));

int max_grid_size = 0;
cudaDeviceGetAttribute(&max_grid_size, cudaDevAttrMaxGridDimX, device_id);

TORCH_CHECK(
gridDim_x > 0 && gridDim_x <= max_grid_size,
"gridDim_x is of bound with value ",
gridDim_x,
". MaxGridDimX is ",
max_grid_size);
TORCH_CHECK(
smem_size >= 0,
"shared memory size needs to be >= 0 but found to be ",
smem_size);

// Call CUDA kernel
if (input.dtype() == torch::ScalarType::Half) {
TORCH_CHECK(0, " fp16 not supported for MX");
@@ -152,6 +191,15 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(
const at::Tensor& input,
const int64_t mx_group_size) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input);
TORCH_CHECK(mx_group_size > 0, "Group size needs to be > 0");
TORCH_CHECK(
mx_group_size % 32 == 0,
"Group size needs to be multiply of 32 but is found to be ",
mx_group_size);
if (input.numel() == 0) {
return at::empty(0, input.options().dtype(at::kFloat));
}
at::Device device = input.device();
const at::cuda::CUDAGuard device_guard{device};
// num quantized elems = half of the total float elems + total number of groups
@@ -160,6 +208,18 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(
// and total_elems need to be passed.
const int64_t total_elems =
(2 * mx_group_size * input.numel()) / (mx_group_size + 2);
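// Derivation: input.numel() = total_elems / 2 + total_elems / mx_group_size
//            = total_elems * (mx_group_size + 2) / (2 * mx_group_size),
// so total_elems = 2 * mx_group_size * input.numel() / (mx_group_size + 2).
// E.g. one group of 32 is stored in 16 + 1 = 17 bytes: 2 * 32 * 17 / 34 = 32.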
TORCH_CHECK(
total_elems > mx_group_size,
"Input needs to be > mx_group_size of ",
mx_group_size,
" but is found to be ",
total_elems);
TORCH_CHECK(
total_elems % mx_group_size == 0,
"Input needs to be multiply of ",
mx_group_size,
" but is found to be ",
total_elems);
const uint32_t total_num_groups = total_elems / mx_group_size;
auto output = at::empty(
@@ -171,6 +231,17 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(
const dim3 gridDim(gridDim_x);
const dim3 blockDim(mx_group_size, num_groups_per_block);
int max_grid_size = 0;
cudaDeviceGetAttribute(
&max_grid_size, cudaDevAttrMaxGridDimX, input.get_device());
TORCH_CHECK(
gridDim_x > 0 && gridDim_x <= max_grid_size,
"gridDim_x is of bound with value ",
gridDim_x,
". MaxGridDimX is ",
max_grid_size);
// Call CUDA kernel
if (input.dtype() == torch::ScalarType::Half) {
TORCH_CHECK(0, " fp16 not supported for MX");

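As an illustration (not part of the commit), the grid-bound validation added above follows a standard CUDA pattern: query cudaDevAttrMaxGridDimX for the target device and refuse to launch if the computed grid exceeds it. Below is a minimal, self-contained sketch of that pattern using the plain CUDA runtime instead of TORCH_CHECK; total_num_groups and num_groups_per_block are made-up stand-ins for the values computed in quantize_mx.cu.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

int main() {
  int device_id = 0;
  cudaGetDevice(&device_id);

  // Same device query the commit adds before launching the MX kernels.
  int max_grid_x = 0;
  cudaDeviceGetAttribute(&max_grid_x, cudaDevAttrMaxGridDimX, device_id);

  // Hypothetical sizes standing in for total_num_groups / num_groups_per_block.
  const uint32_t total_num_groups = 1024;
  const uint32_t num_groups_per_block = 8;
  const uint32_t grid_x =
      (total_num_groups + num_groups_per_block - 1) / num_groups_per_block;

  // Reject out-of-range grids up front instead of failing asynchronously.
  if (grid_x == 0 || grid_x > static_cast<uint32_t>(max_grid_x)) {
    std::fprintf(stderr, "grid_x %u out of range (max %d)\n", grid_x, max_grid_x);
    return 1;
  }

  noop_kernel<<<grid_x, 32>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}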