Use int64_t for buffer indices to avoid overflow
Summary:
Use `int64_t` instead of `int` for buffer indices
in the FP8 rowwise CUDA kernel to avoid overflow.

Reviewed By: jspark1105, sryap

Differential Revision: D46980577

fbshipit-source-id: 91f896ce93234b974c004c8947de4df6caf88c57
spcyppt authored and facebook-github-bot committed Jul 27, 2023
1 parent 78c60ce commit 3579b4d
Showing 1 changed file with 20 additions and 20 deletions.
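
Why the widening matters: with 32-bit `int` indices, a flat offset such as
`row * ncols` or `row * output_columns` is computed in 32-bit arithmetic and
wraps once it exceeds 2^31 - 1, producing a negative index and out-of-bounds
accesses. A minimal host-side sketch of the failure mode (the shape below is
hypothetical, chosen only to push the offset past INT_MAX; it is not taken
from the commit):

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical rowwise-quantized shape, large enough that the flat
  // offset of the last row no longer fits in a signed 32-bit int.
  const int64_t nrows = 60000;
  const int64_t ncols = 40000;
  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;            // 40000
  const int64_t output_columns = ncols_aligned + 2 * sizeof(float); // 40008

  const int64_t row = nrows - 1;
  const int64_t wide = row * output_columns; // 64-bit math: 2,400,439,992
  const int32_t narrow =
      static_cast<int32_t>(row) * static_cast<int32_t>(output_columns);
  // 59,999 * 40,008 > INT_MAX: the 32-bit product overflows (formally
  // undefined behavior; on typical targets it wraps to a negative value).

  std::printf("64-bit offset: %lld\n", static_cast<long long>(wide));
  std::printf("32-bit offset: %d (INT_MAX = %d)\n", narrow, INT_MAX);
  return 0;
}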
fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu

@@ -21,19 +21,19 @@ namespace {
 template <typename input_t>
 __global__ inline void _float_to_FP8rowwise_cuda_kernel(
     const input_t* __restrict__ input,
-    const int nrows,
-    const int ncols,
+    const int64_t nrows,
+    const int64_t ncols,
     std::uint8_t* __restrict__ output,
     const bool forward) {
   constexpr float kEpsilon = 1e-20f;
   const int ebit = forward ? 4 : 5;
   const int bias = forward ? 15 : 31;
   const float max_pos = forward ? 0.9375 : 0.875;
 
-  const int ncols_aligned = (ncols + 4 - 1) / 4 * 4;
-  const int output_columns = ncols_aligned + 2 * sizeof(float);
+  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);
 
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row = blockIdx.x * blockDim.x + threadIdx.x;
 
   if (row < nrows) {
     const input_t* input_row = input + row * ncols;
@@ -47,7 +47,7 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
     const auto scale =
         max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
     output_row_scale_bias[0] = scale;
-    for (std::size_t col = 0; col < ncols; ++col) {
+    for (int64_t col = 0; col < ncols; ++col) {
       output_row[col] =
           float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
     }
@@ -57,15 +57,15 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
 template <typename input_t>
 __global__ inline void _get_FP8_qparam_cuda_kernel(
     const input_t* __restrict__ input,
-    const int nrows,
-    const int ncols,
+    const int64_t nrows,
+    const int64_t ncols,
     uint8_t* __restrict__ output,
     float* __restrict__ range_list,
     const bool forward) {
-  const int row = (int)blockIdx.x * blockDim.y + threadIdx.y;
+  const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
 
-  const int ncols_aligned = (ncols + 4 - 1) / 4 * 4;
-  const int output_columns = ncols_aligned + 2 * sizeof(float);
+  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);
   float max_pos;
   if (forward) {
     max_pos = 0.9375;
@@ -84,7 +84,7 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
   if (row < nrows) {
     const input_t* const input_row = input + row * ncols;
 
-    for (int col = threadIdx.x; col < ncols; col += lane_width) {
+    for (int64_t col = threadIdx.x; col < ncols; col += lane_width) {
       // Get thread-local minmax. These are the smallest min and max ever seen
       // by this thread.
       maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
@@ -116,8 +116,8 @@ template <typename input_t>
 __global__ inline void _compute_FP8_quantize_cuda_kernel(
     const input_t* const __restrict__ input,
     const float* const __restrict__ range_list,
-    const int nrows,
-    const int ncols,
+    const int64_t nrows,
+    const int64_t ncols,
     std::uint8_t* const __restrict__ output,
     const bool forward) {
   int ebit;
@@ -133,18 +133,18 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
     max_pos = 0.875;
   }
 
-  const int ncols_aligned = (ncols + 4 - 1) / 4 * 4;
-  const int output_columns = ncols_aligned + 2 * sizeof(float);
+  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const int64_t output_columns = ncols_aligned + 2 * sizeof(float);
 
-  int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  const int row_incre = blockDim.y * gridDim.y;
+  int64_t row = blockIdx.y * blockDim.y + threadIdx.y;
+  const int64_t col = blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row_incre = blockDim.y * gridDim.y;
   for (/*row*/; row < nrows; row += row_incre) {
     if (col < ncols) {
       float* row_qparams = reinterpret_cast<float*>(
           output + row * output_columns + ncols_aligned);
       const float scale = row_qparams[0];
-      const int input_idx = row * ncols + col;
+      const auto input_idx = row * ncols + col;
       uint8_t* output_addr = output + row * output_columns + col;
       // TODO: lift range_list into shared memory. However, when nrows is large,
       // it might exceed the size of shared memory.
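
An aside on the fixed lines, not part of the diff itself: `blockIdx.x` and
`blockDim.x` are 32-bit unsigned values, so a product like
`blockIdx.x * blockDim.x` is still evaluated in 32-bit arithmetic before the
result is stored into the `int64_t`. That is fine for the launch shapes these
kernels use, but a defensive variant would widen one operand first. A sketch
under that assumption (`global_thread_index_x` is a hypothetical helper, not
part of fbgemm_gpu):

#include <cstdint>

// Widen one operand before the multiply so the product itself is computed
// in 64-bit arithmetic; without the cast, blockIdx.x * blockDim.x is a
// 32-bit unsigned multiply that can wrap on extremely large 1-D launches.
__device__ inline int64_t global_thread_index_x() {
  return static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
}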
