MX4 check smem and fixes (pytorch#2703)
Summary:
Pull Request resolved: pytorch#2703

Add check for shared memory size.
Fixes:
- calc_quantized_size for MX4
- CUDA stream used for kernel launches

Reviewed By: sryap

Differential Revision: D58316467

fbshipit-source-id: 83f969b9bff47dc88fd070432f74be98a3d78a7b
spcyppt authored and facebook-github-bot committed Jun 10, 2024
1 parent 08e641c commit c7eea48
Showing 5 changed files with 95 additions and 46 deletions.
10 changes: 9 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -48,9 +48,10 @@ def none_throws(


class QuantizationContext:
    def __init__(self, row_dim: int = ROW_DIM_DEFAULT) -> None:
    def __init__(self, row_dim: int = ROW_DIM_DEFAULT, mx_group_size: int = 32) -> None:
        self.row_dim = row_dim
        self.row_dim_quant: int = -1
        self.mx_group_size = mx_group_size


def _quantize_tensor(
@@ -202,6 +203,13 @@ def calc_quantized_size(
            nrows = input_len // ctx.row_dim
            ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
            return nrows * ncols
        elif self._comm_precision == SparseType.MX4:
            assert (
                input_len % 32 == 0
            ), f"input_len {input_len} needs to be multiple of group_size 32"
            # quantized output size = half input size + number of groups (shared exp)
            ctx = none_throws(ctx)
            return (input_len // 2) + (input_len // ctx.mx_group_size)
        else:
            return input_len
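
For reference, a short worked example of the MX4 size formula introduced above (the input length is an arbitrary illustration, not a value taken from this change):

    # MX4 quantized size = packed 4-bit payload + one shared-exponent byte per group
    mx_group_size = 32
    input_len = 1024                            # assumed example; must be a multiple of the group size
    packed_bytes = input_len // 2               # two 4-bit elements per byte
    num_groups = input_len // mx_group_size     # one shared exponent per group
    quantized_size = packed_bytes + num_groups  # 512 + 32 = 544 bytes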

12 changes: 0 additions & 12 deletions fbgemm_gpu/src/quantize_ops/mx/common.cuh
@@ -197,16 +197,4 @@ __host__ __device__ __forceinline__ bool shift_left_mantissa(
return overflow;
}

#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void
gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
if (code != cudaSuccess) {
fprintf(
stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort)
exit(code);
}
}

#endif
14 changes: 1 addition & 13 deletions fbgemm_gpu/src/quantize_ops/mx_common.cuh
@@ -8,6 +8,7 @@

#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "mx/common.cuh"

@@ -33,19 +34,6 @@ __constant__ float MX4_values[16] = {
-4.0f,
-6.0f};

//-----------------------------------------------------------------------
// Misc. helper functions
//-----------------------------------------------------------------------

inline uint32_t align(int a, int b) {
return (a + b - 1) / b * b;
}

// Refactor to use FBGEMM's
__host__ __device__ __forceinline__ uint32_t round_up(uint32_t a, uint32_t b) {
return ((a + b - 1) / b);
}

//---------------------------------------------------------
// Helper functions for quantization
//---------------------------------------------------------
94 changes: 83 additions & 11 deletions fbgemm_gpu/src/quantize_ops/quantize_mx.cu
@@ -24,6 +24,63 @@

namespace fbgemm_gpu {

// from codegen/training/backward/embedding_backward_split_template.cu
template <typename func_t>
int32_t compute_num_groups_and_dynamic_smem_bytes(
uint32_t* num_groups_per_block,
const int64_t mx_group_size,
const int device,
const func_t kernel_func_name) {
int32_t smem_bytes = 0;
// V100: 96 KB; A100: 160 KB; H100: 228 KB.
int max_shared_bytes = 0;
#ifndef USE_ROCM
cudaDeviceGetAttribute(
&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
#else
// MI100 has 64 KB local memory (shared memory) per workgroup
max_shared_bytes = 64 << 10;
#endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
int shared_kb = max_shared_bytes >> 10;
// V100: 64 KB; A100: 96 KB; H100: 144 KB
#ifndef USE_ROCM
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK_GT(used_shared_kb, 0);
#else
// MI100 has independent shared mem and L1
int used_shared_kb = shared_kb;
#endif
const int used_shared_bytes = used_shared_kb << 10;
// Stay under used_shared_kb of shared memory (V100: 64 KB;
// A100: 96 KB; H100: 144 KB), num_groups must be a power
// of two.
// max(num_elem_in_block * sizeof(int), num_elem_in_block /2 * sizeof(uint8))
while ((smem_bytes = *num_groups_per_block * mx_group_size * (sizeof(int))) >=
used_shared_bytes) {
*num_groups_per_block /= 2;
}
TORCH_CHECK_GE(*num_groups_per_block, 1);

// Check
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
// "Compute capability 7.x devices allow a single thread block to
// address the full capacity of shared memory: 96 KB on Volta,
// 64 KB on Turing. Kernels relying on shared memory allocations
// over 48 KB per block are architecture-specific, as such they
// must use dynamic shared memory (rather than statically sized
// arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
#ifndef USE_ROCM
cudaFuncSetAttribute(
kernel_func_name,
cudaFuncAttributeMaxDynamicSharedMemorySize,
used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB
C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
return smem_bytes;
}
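
As a rough Python sketch of the sizing logic in compute_num_groups_and_dynamic_smem_bytes above (assuming sizeof(int) is 4 bytes; the capacity figures come from the comments in the function, and the helper names here are invented for illustration):

    def shared_mem_budget_bytes(max_shared_kb):
        # Non-ROCm path: use 2/3 of the opt-in capacity, rounded down to a 16 KB multiple
        used_shared_kb = (max_shared_kb * 2 // 3) // 16 * 16
        return used_shared_kb << 10

    def shrink_groups_to_fit(num_groups_per_block, mx_group_size, used_shared_bytes):
        # Halve the groups per block until one int of scratch per element fits the budget
        smem_bytes = num_groups_per_block * mx_group_size * 4
        while smem_bytes >= used_shared_bytes:
            num_groups_per_block //= 2
            smem_bytes = num_groups_per_block * mx_group_size * 4
        return num_groups_per_block, smem_bytes

    # e.g. V100 (96 KB opt-in) -> 64 KB budget, A100 (160 KB) -> 96 KB, H100 (228 KB) -> 144 KB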

//-----------------------------------------------------------------------
// quantize_mx_cuda
//-----------------------------------------------------------------------
Expand All @@ -38,6 +95,12 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(
const int64_t rounding_mode = 0) {
TORCH_CHECK((mx_group_size % 32 == 0), "Group size needs to be power of 2");
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input);
// Currently we only support MX4 E2M1, for other MX types, we will dispatch
// different kernels
TORCH_CHECK(
scale_bits == 8 && elem_ebits == 2 && elem_mbits == 3 &&
elem_max_norm == 6.0,
"FBGEMM currently only supports MX4 E2M1.");

at::Device device = input.device();
const at::cuda::CUDAGuard device_guard{device};
Expand All @@ -46,25 +109,33 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(

RoundingMode rd = static_cast<RoundingMode>(rounding_mode);

const int num_groups_per_block = MAX_THREADS / mx_group_size;
const auto gridDim_x = round_up(total_num_groups, num_groups_per_block);
uint32_t num_groups_per_block = MAX_THREADS / mx_group_size;
const auto kernel_func = quantize_float_to_mx4_kernel<float>;

// Use shmem to find max exponent (int) and temporarily store output (unint8)
const int32_t smem_size = compute_num_groups_and_dynamic_smem_bytes(
&num_groups_per_block, mx_group_size, input.get_device(), kernel_func);

const auto gridDim_x = div_round_up(total_num_groups, num_groups_per_block);

const dim3 gridDim(gridDim_x);
const dim3 blockDim(mx_group_size, num_groups_per_block);

// Use shmem to find max exponent (int) and temporarily store output (unint8)
// max(num_elem_in_block * sizeof(int), num_elem_in_block /2 * sizeof(uint8))
const int smem_size = num_groups_per_block * mx_group_size * (sizeof(int));
auto output = at::empty(
(total_elems / 2) + total_num_groups, input.options().dtype(at::kByte));

// Call CUDA kernel
if (input.dtype() == torch::ScalarType::Half) {
TORCH_CHECK(0, " fp16 not supported for MX");
} else {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "quantize_float_to_mx4_kernel";
#endif
quantize_float_to_mx4_kernel<<<gridDim, blockDim, smem_size>>>(
kernel_func<<<
gridDim,
blockDim,
smem_size,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, input, float, 1, 64),
mx_group_size,
total_elems,
Expand All @@ -73,8 +144,6 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(
MAKE_PTA_WITH_NAME(func_name, output, uint8_t, 1, 64));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

gpuErrchk(cudaPeekAtLastError());
return output;
}
Expand All @@ -96,7 +165,7 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(
total_elems, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
const int num_groups_per_block = MAX_THREADS / mx_group_size;
const auto gridDim_x = round_up(total_num_groups, num_groups_per_block);
const auto gridDim_x = div_round_up(total_num_groups, num_groups_per_block);
const dim3 gridDim(gridDim_x);
const dim3 blockDim(mx_group_size, num_groups_per_block);
Expand All @@ -108,15 +177,18 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "dequantize_mx4_to_float_kernel";
#endif
dequantize_mx4_to_float_kernel<<<gridDim, blockDim>>>(
dequantize_mx4_to_float_kernel<<<
gridDim,
blockDim,
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, input, uint8_t, 1, 64),
mx_group_size,
total_elems,
MAKE_PTA_WITH_NAME(func_name, output, float, 1, 64));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
gpuErrchk(cudaPeekAtLastError());
return output;
}
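
A small sketch of the launch geometry used by both launches above; MAX_THREADS and the tensor size are assumptions for illustration, not values taken from this change:

    MAX_THREADS = 1024                    # assumed per-block thread budget
    mx_group_size = 32
    total_elems = 1_000_000               # illustrative tensor size (multiple of the group size)
    total_num_groups = total_elems // mx_group_size                # 31,250 groups

    num_groups_per_block = MAX_THREADS // mx_group_size            # 32 groups per block
    # div_round_up: one block handles num_groups_per_block groups
    grid_x = (total_num_groups + num_groups_per_block - 1) // num_groups_per_block  # 977 blocks
    block_dim = (mx_group_size, num_groups_per_block)              # one thread per element in a group

Note that quantize_mx_cuda may first halve num_groups_per_block to satisfy the shared-memory budget (see the sketch after compute_num_groups_and_dynamic_smem_bytes), which increases grid_x accordingly.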
11 changes: 2 additions & 9 deletions fbgemm_gpu/src/quantize_ops/quantize_mx.cuh
@@ -6,10 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/

/*
* Microsoft Confidential
*/

#ifndef PYT_MX_MX_CUH
#define PYT_MX_MX_CUH

@@ -213,14 +209,12 @@ __global__ void quantize_float_to_mx4_kernel(
}
__syncthreads();

const uint32_t data_size_per_group = half_group_size + 1;

// Let each thread write 1 byte of output data
if (threadIdx.x < half_group_size) {
// write data output using uint8_t (1 bytes)

uint8_t* smem_ptr = reinterpret_cast<uint8_t*>(smem_base);
const uint32_t start_output_idx = (data_size_per_group)*linear_group_id;
const uint32_t start_output_idx = (half_group_size + 1) * linear_group_id;
uint8_t* output_base = &output[start_output_idx];

output_base[threadIdx.x] = smem_ptr[threadIdx.x];
@@ -250,9 +244,8 @@ __global__ void dequantize_mx4_to_float_kernel(
return;

const uint32_t half_group_size = group_size / 2;
const uint32_t data_size_per_group = half_group_size + 1;

const uint32_t start_output_idx = (data_size_per_group)*linear_group_id;
const uint32_t start_output_idx = (half_group_size + 1) * linear_group_id;
uint8_t elem = input[start_output_idx + (threadIdx.x / 2)];
const uint32_t shared_exp_idx = start_output_idx + half_group_size;
const uint8_t shared_exp = input[shared_exp_idx];
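
To make the indexing above concrete, a hypothetical walk-through of the per-group MX4 byte layout that both kernels share (group_size of 32 is just an example):

    group_size = 32
    half_group_size = group_size // 2      # 16 bytes of packed 4-bit elements per group
    bytes_per_group = half_group_size + 1  # plus one trailing shared-exponent byte

    def group_start(linear_group_id):
        # Matches start_output_idx = (half_group_size + 1) * linear_group_id
        return bytes_per_group * linear_group_id

    def shared_exp_index(linear_group_id):
        # The shared exponent sits right after its group's packed payload
        return group_start(linear_group_id) + half_group_size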
