Change data layout for MX4 (pytorch#2696)
Summary:
Pull Request resolved: pytorch#2696

Change the data layout to avoid offset-computation overhead. We trade alignment for removing this overhead.

We pack the 4-bit output data for each group (16 bytes for a group_size of 32) together with its 8-bit (1-byte) shared exponent, so each group packs into 17 bytes.
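
For illustration, a minimal host-side sketch (not part of the diff; the helper names are made up) of the new packed layout and resulting buffer size, assuming fp32 inputs and a group_size that is a multiple of 32:

// Each quantized group stores group_size/2 bytes of packed 4-bit elements,
// immediately followed by 1 byte holding the group's shared exponent.
#include <cstdint>

inline int64_t bytes_per_group(int64_t group_size) {
  return group_size / 2 + 1; // 16 + 1 = 17 bytes for group_size = 32
}

inline int64_t quantized_buffer_size(int64_t total_elems, int64_t group_size) {
  const int64_t num_groups = total_elems / group_size;
  return total_elems / 2 + num_groups; // == num_groups * bytes_per_group(group_size)
}

inline int64_t group_data_offset(int64_t group_id, int64_t group_size) {
  return group_id * bytes_per_group(group_size);
}

inline int64_t shared_exp_offset(int64_t group_id, int64_t group_size) {
  return group_data_offset(group_id, group_size) + group_size / 2;
}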

Reviewed By: sryap

Differential Revision: D58143133

fbshipit-source-id: 174afae68ad453b7dd6c32af051f306d0855dd1e
spcyppt authored and facebook-github-bot committed Jun 6, 2024
1 parent ce64e00 commit d2387cb
Showing 5 changed files with 52 additions and 144 deletions.
2 changes: 0 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -402,7 +402,6 @@ at::Tensor fusednbitrowwise_to_float_or_half_cpu(

at::Tensor quantize_mx_cuda(
const at::Tensor& input,
const std::vector<int64_t>& split_sizes,
const int64_t scale_bits,
const int64_t elem_ebits,
const int64_t elem_mbits,
@@ -413,7 +412,6 @@ at::Tensor quantize_mx_cuda(

at::Tensor dequantize_mx_cuda(
const at::Tensor& input,
const std::vector<int64_t>& split_sizes,
const int64_t mx_group_size);

///@ingroup sparse-data-cuda
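With split_sizes gone from both entry points, a call reduces to the per-element format parameters and the group size. A hedged usage sketch of the new signatures (assuming the declarations live in namespace fbgemm_gpu as in quantize_mx.cu; the E2M1 parameter values below are assumptions, not values taken from this commit):

#include <ATen/ATen.h>
#include "fbgemm_gpu/sparse_ops.h"

// Round-trip a CUDA fp32 tensor through MX4. scale_bits=8, elem_ebits=2,
// elem_mbits=3, elem_max_norm=6.0 are assumed E2M1 defaults.
at::Tensor mx4_roundtrip(const at::Tensor& input_fp32_cuda) {
  const int64_t group_size = 32;
  auto packed = fbgemm_gpu::quantize_mx_cuda(
      input_fp32_cuda,
      /*scale_bits=*/8,
      /*elem_ebits=*/2,
      /*elem_mbits=*/3,
      /*elem_max_norm=*/6.0,
      /*mx_group_size=*/group_size,
      /*flush_fp32_subnorms=*/false,
      /*rounding_mode=*/0);
  return fbgemm_gpu::dequantize_mx_cuda(packed, group_size);
}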
5 changes: 4 additions & 1 deletion fbgemm_gpu/src/quantize_ops/mx_common.cuh
@@ -6,13 +6,16 @@
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "mx/common.cuh"

//-----------------------------------------------------------------------
// MX4-Float mapping
//-----------------------------------------------------------------------

__device__ const float MX4_values[16] = {
__constant__ float MX4_values[16] = {
0.0f,
0.5f,
1.0f,
105 changes: 20 additions & 85 deletions fbgemm_gpu/src/quantize_ops/quantize_mx.cu
@@ -18,6 +18,8 @@
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <ATen/core/TensorAccessor.h>
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "quantize_mx.cuh"

namespace fbgemm_gpu {
@@ -27,15 +29,13 @@ namespace fbgemm_gpu {
//-----------------------------------------------------------------------
DLL_PUBLIC at::Tensor quantize_mx_cuda(
const at::Tensor& input,
const std::vector<int64_t>& split_sizes,
const int64_t scale_bits,
const int64_t elem_ebits,
const int64_t elem_mbits,
const double elem_max_norm,
const int64_t mx_group_size,
const bool flush_fp32_subnorms = false,
const int64_t rounding_mode = 0) {
TORCH_CHECK((split_sizes.size() > 0), "Input split sizes cannot be empty");
TORCH_CHECK((mx_group_size % 32 == 0), "Group size needs to be a multiple of 32");
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input);

@@ -44,40 +44,6 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(
const uint32_t total_elems = input.numel();
const uint32_t total_num_groups = input.numel() / mx_group_size;

// Compute offsets to be passed to kernel
auto start_output_cumsum =
at::empty(split_sizes.size() + 1, at::TensorOptions().dtype(at::kUInt32));
auto group_ids =
at::empty(total_num_groups, at::TensorOptions().dtype(at::kUInt32));
auto num_groups_cumsum =
at::empty(split_sizes.size() + 1, at::TensorOptions().dtype(at::kUInt32));
uint32_t offset = 0;
start_output_cumsum[0] = 0;
num_groups_cumsum[0] = 0;
uint32_t num_groups_cumsum_ = 0;
int64_t start_idx = 0;
int64_t end_idx = 0;
for (int i = 0; i < split_sizes.size(); i++) {
const uint32_t split_size = split_sizes[i];

TORCH_CHECK(
split_size % mx_group_size == 0,
" Number of inputs needs to be a multiple of group size");
const uint32_t num_groups = split_size / mx_group_size;
end_idx += num_groups;
offset += align((split_size / 2) + num_groups, 16);
start_output_cumsum[i + 1] = offset;
num_groups_cumsum_ += num_groups;
num_groups_cumsum[i + 1] = num_groups_cumsum_;
group_ids.index_put_({at::indexing::Slice(start_idx, end_idx)}, i);
start_idx = end_idx;
}

// TODO: Search in the kernel
start_output_cumsum = start_output_cumsum.to(device, /*non_blocking=*/true);
group_ids = group_ids.to(device, /*non_blocking=*/true);
num_groups_cumsum = num_groups_cumsum.to(device, /*non_blocking=*/true);

RoundingMode rd = static_cast<RoundingMode>(rounding_mode);

const int num_groups_per_block = MAX_THREADS / mx_group_size;
@@ -90,22 +56,21 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(
// max(num_elem_in_block * sizeof(int), num_elem_in_block /2 * sizeof(uint8))
const int smem_size = num_groups_per_block * mx_group_size * (sizeof(int));
auto output = at::empty(
offset, // 4 = sizeof(float)
input.options().dtype(at::kByte));
(total_elems / 2) + total_num_groups, input.options().dtype(at::kByte));
// Call CUDA kernel
if (input.dtype() == torch::ScalarType::Half) {
AT_ASSERTM(0, " fp16 not supported for MX");
TORCH_CHECK(0, " fp16 not supported for MX");
} else {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "quantize_float_to_mx4_kernel";
#endif
quantize_float_to_mx4_kernel<<<gridDim, blockDim, smem_size>>>(
input.data_ptr<float>(),
MAKE_PTA_WITH_NAME(func_name, input, float, 1, 64),
mx_group_size,
group_ids.data_ptr<uint32_t>(),
start_output_cumsum.data_ptr<uint32_t>(),
num_groups_cumsum.data_ptr<uint32_t>(),
total_elems,
flush_fp32_subnorms,
rd,
output.data_ptr<uint8_t>());
MAKE_PTA_WITH_NAME(func_name, output, uint8_t, 1, 64));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

@@ -115,48 +80,18 @@ DLL_PUBLIC at::Tensor quantize_mx_cuda(

DLL_PUBLIC at::Tensor dequantize_mx_cuda(
const at::Tensor& input,
const std::vector<int64_t>& split_sizes,
const int64_t mx_group_size) {
TORCH_CHECK((split_sizes.size() > 0), "Input sizes cannot be empty");
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input);
at::Device device = input.device();
const at::cuda::CUDAGuard device_guard{device};
// num quantized elems = half of the total float elms + total number of groups
// so, quantized input size = (total_num_elems/2)+(total_num_elems/group_size)
// Note that this formula won't work if there's padding to quantized output
// and total_elems need to be passed.
const int64_t total_elems =
std::accumulate(split_sizes.begin(), split_sizes.end(), 0);
(2 * mx_group_size * input.numel()) / (mx_group_size + 2);
const uint32_t total_num_groups = total_elems / mx_group_size;

auto start_output_cumsum =
at::empty(split_sizes.size() + 1, at::TensorOptions().dtype(at::kUInt32));
auto group_ids =
at::empty(total_num_groups, at::TensorOptions().dtype(at::kUInt32));
auto num_groups_cumsum =
at::empty(split_sizes.size() + 1, at::TensorOptions().dtype(at::kUInt32));
uint32_t offset = 0;
start_output_cumsum[0] = 0;
num_groups_cumsum[0] = 0;
uint32_t num_groups_cumsum_ = 0;
int64_t start_idx = 0;
int64_t end_idx = 0;

for (int i = 0; i < split_sizes.size(); i++) {
const uint32_t split_size = split_sizes[i];

TORCH_CHECK(
split_size % mx_group_size == 0,
" Number of inputs needs to be a multiple of group size");
const uint32_t num_groups = split_size / mx_group_size;
end_idx += num_groups;
offset += align((split_size / 2) + num_groups, 16);
start_output_cumsum[i + 1] = offset;
num_groups_cumsum_ += num_groups;
num_groups_cumsum[i + 1] = num_groups_cumsum_;
group_ids.index_put_({at::indexing::Slice(start_idx, end_idx)}, i);
start_idx = end_idx;
}
start_output_cumsum = start_output_cumsum.to(device, /*non_blocking=*/true);
group_ids = group_ids.to(device, /*non_blocking=*/true);
num_groups_cumsum = num_groups_cumsum.to(device, /*non_blocking=*/true);

auto output = at::empty(
total_elems, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
@@ -168,16 +103,16 @@ DLL_PUBLIC at::Tensor dequantize_mx_cuda(

// Call CUDA kernel
if (input.dtype() == torch::ScalarType::Half) {
AT_ASSERTM(0, " fp16 not supported for MX");
TORCH_CHECK(0, " fp16 not supported for MX");
} else {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "dequantize_mx4_to_float_kernel";
#endif
dequantize_mx4_to_float_kernel<<<gridDim, blockDim>>>(
input.data_ptr<uint8_t>(),
MAKE_PTA_WITH_NAME(func_name, input, uint8_t, 1, 64),
mx_group_size,
total_elems,
group_ids.data_ptr<uint32_t>(),
start_output_cumsum.data_ptr<uint32_t>(),
num_groups_cumsum.data_ptr<uint32_t>(),
output.data_ptr<float>());
MAKE_PTA_WITH_NAME(func_name, output, float, 1, 64));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

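Since the output no longer carries per-split padding, dequantize_mx_cuda can recover the element count from the packed size alone: packed_size = num_groups * (group_size/2 + 1), which inverts to the formula used above. A quick standalone check (a sketch, assuming no padding):

#include <cassert>
#include <cstdint>

// total_elems = 2 * group_size * packed_size / (group_size + 2)
int64_t recover_total_elems(int64_t packed_size, int64_t group_size) {
  return (2 * group_size * packed_size) / (group_size + 2);
}

int main() {
  const int64_t group_size = 32;
  const int64_t total_elems = 1024;          // 32 groups of 32 elements
  const int64_t packed_size = 1024 / 2 + 32; // 544 bytes = 32 groups * 17 bytes
  assert(recover_total_elems(packed_size, group_size) == total_elems);
  return 0;
}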
79 changes: 26 additions & 53 deletions fbgemm_gpu/src/quantize_ops/quantize_mx.cuh
@@ -125,29 +125,17 @@ __device__ __forceinline__ uint8_t quantize_elemwise_4bit(

template <typename T>
__global__ void quantize_float_to_mx4_kernel(
const T* __restrict__ input,
const pta::PackedTensorAccessor64<T, 1, at::RestrictPtrTraits> input,
const int group_size, // can change to Blockdim.x
const uint32_t* __restrict__ group_ids,
const uint32_t* __restrict__ start_output_cumsum,
const uint32_t* __restrict__ num_groups_cumsum,
const uint32_t total_elems,
const bool flush_fp32_subnorms,
const RoundingMode rounding_mode,
uint8_t* __restrict__ output) {
pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output) {
const auto linear_group_id = (blockIdx.x * blockDim.y) + threadIdx.y;
const auto linear_tid = linear_group_id * group_size + threadIdx.x;
if (linear_tid >= total_elems)
return;

const uint32_t rank_id = group_ids[linear_group_id];
const uint32_t accum_num_groups = num_groups_cumsum[rank_id];
const uint32_t num_elems_per_rank =
(num_groups_cumsum[rank_id + 1] - accum_num_groups) * group_size;
const uint32_t start_output_idx = start_output_cumsum[rank_id];

// offsets for within rank to write quantized data and shared_exponent
const uint32_t group_id_in_rank = linear_group_id - accum_num_groups;

// MX4 values
constexpr int scale_bits = 8;
constexpr int elem_ebits = 2;
Expand Down Expand Up @@ -225,27 +213,24 @@ __global__ void quantize_float_to_mx4_kernel(
}
__syncthreads();

// Let only thread0 write output data and shared exponent
if (threadIdx.x == 0) {
// write data output using float4 (16 bytes)
// 1 output is 1 byte, we write 16 outputs in 1 float4
// group_size needs to be multiple of 32 so that the output
// will be multiple of 16 bytes
const int num_vecs = group_size / 32;
// smem is float, move every 4 bytes
// need to move half_group_size / 4
float4* smem_ptr = reinterpret_cast<float4*>(smem_base);
const uint32_t data_size_per_group = half_group_size + 1;

// Let each thread write 1 byte of output data
if (threadIdx.x < half_group_size) {
// write data output using uint8_t (1 bytes)

uint8_t* smem_ptr = reinterpret_cast<uint8_t*>(smem_base);
const uint32_t start_output_idx = (data_size_per_group)*linear_group_id;
uint8_t* output_base = &output[start_output_idx];
float4* output_ptr = reinterpret_cast<float4*>(
output_base + group_id_in_rank * half_group_size);
for (int i = 0; i < num_vecs; i++) {
output_ptr[i] = smem_ptr[i];

output_base[threadIdx.x] = smem_ptr[threadIdx.x];

// write share exp
if (threadIdx.x == 0) {
// shared_exp_idx is stored after data
// need to offset with start_output + output data
output_base[half_group_size] = clamped_shared_exp;
}
// shared_exp_idx is stored after data
// need to offset with start_output + output data
const uint32_t shared_exp_offset =
(num_elems_per_rank / 2) + group_id_in_rank;
output_base[shared_exp_offset] = clamped_shared_exp;
}
}

@@ -255,33 +240,22 @@

template <typename T>
__global__ void dequantize_mx4_to_float_kernel(
const uint8_t* __restrict__ input,
const pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
const int group_size,
const int64_t total_elems,
const uint32_t* __restrict__ group_ids,
const uint32_t* __restrict__ start_output_cumsum,
const uint32_t* __restrict__ num_groups_cumsum,
T* __restrict__ output) {
pta::PackedTensorAccessor64<T, 1, at::RestrictPtrTraits> output) {
const auto linear_group_id = (blockIdx.x * blockDim.y) + threadIdx.y;
const auto linear_tid = linear_group_id * group_size + threadIdx.x;
if (linear_tid >= total_elems)
return;

const uint32_t rank_id = group_ids[linear_group_id];
const uint32_t accum_num_groups = num_groups_cumsum[rank_id];
const uint32_t num_groups_per_rank =
num_groups_cumsum[rank_id + 1] - accum_num_groups;
const uint32_t num_elems_per_rank = num_groups_per_rank * group_size;
const uint32_t start_output_idx = start_output_cumsum[rank_id];

const uint32_t tid_in_rank = linear_tid - (accum_num_groups * group_size);
const uint32_t group_id_in_rank = linear_group_id - accum_num_groups;

const uint32_t shared_exp_idx =
start_output_idx + (num_elems_per_rank / 2) + group_id_in_rank;
const uint32_t output_idx = start_output_idx + round_up(tid_in_rank - 1, 2);
const uint32_t half_group_size = group_size / 2;
const uint32_t data_size_per_group = half_group_size + 1;

uint8_t elem = input[output_idx];
const uint32_t start_output_idx = (data_size_per_group)*linear_group_id;
uint8_t elem = input[start_output_idx + (threadIdx.x / 2)];
const uint32_t shared_exp_idx = start_output_idx + half_group_size;
const uint8_t shared_exp = input[shared_exp_idx];

constexpr uint8_t upper_4bit_mask = 0xF0;
constexpr uint8_t lower_4bit_mask = 0x0F;
Expand All @@ -296,7 +270,6 @@ __global__ void dequantize_mx4_to_float_kernel(
}
CUDA_KERNEL_ASSERT(elem < 16);

const uint8_t shared_exp = input[shared_exp_idx];
output[linear_tid] = MX4_values[elem] * pow(2, shared_exp - FLOAT32_EXP_BIAS);
}

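For reference, a CPU sketch of the per-element indexing the rewritten kernels use: within each group, byte i/2 of the group's data region holds elements i and i+1, and the shared exponent sits right after the half_group_size data bytes. The nibble order and the full MX4 value table are assumptions here, since the diff elides the selection branch and most of the table:

#include <cmath>
#include <cstdint>

// Assumed E2M1 code points; only the first entries of MX4_values appear above.
static const float kMX4Values[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Dequantize element `i` (0 <= i < group_size) of group `g` from the packed
// buffer. FLOAT32_EXP_BIAS is 127 for fp32.
float dequantize_one(
    const uint8_t* packed, int64_t g, int i, int group_size) {
  const int64_t half_group_size = group_size / 2;
  const int64_t group_start = g * (half_group_size + 1);
  const uint8_t byte = packed[group_start + i / 2];
  // Assumed nibble order: even elements in the low nibble, odd in the high.
  const uint8_t code = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  const uint8_t shared_exp = packed[group_start + half_group_size];
  return kMX4Values[code] * std::pow(2.0f, static_cast<int>(shared_exp) - 127);
}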
5 changes: 2 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -472,9 +472,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"PaddedFP8RowwiseQuantizedToFloat(Tensor input, bool forward, int row_dim, int output_last_dim=-1, int output_dtype=0) -> Tensor");
m.def(
"quantize_mx_cuda(Tensor input, int[] split_sizes, int scale_bits, int elem_ebits, int elem_mbits, float elem_max_norm, int mx_group_size, bool flush_fp32_subnorms=False, int rounding_mode=0) -> Tensor");
m.def(
"dequantize_mx_cuda(Tensor input, int[] split_sizes, int mx_group_size) -> Tensor");
"quantize_mx_cuda(Tensor input, int scale_bits, int elem_ebits, int elem_mbits, float elem_max_norm, int mx_group_size, bool flush_fp32_subnorms=False, int rounding_mode=0) -> Tensor");
m.def("dequantize_mx_cuda(Tensor input, int mx_group_size) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
