Split Quantize Ops, pt. 1 (pytorch#1863)
Summary:
Pull Request resolved: pytorch#1863

- Move quantize ops files over to their own directory

- Break up `quantize_ops.cu`, pt. 1

Reviewed By: sryap

Differential Revision: D47237288

fbshipit-source-id: 48d5c4fbb00978919bfa93b5948d2459f5639b63
q10 authored and facebook-github-bot committed Jul 7, 2023
1 parent ed4fe6e commit f94db2c
Showing 11 changed files with 820 additions and 702 deletions.
10 changes: 7 additions & 3 deletions fbgemm_gpu/CMakeLists.txt
@@ -555,7 +555,7 @@ set(fbgemm_gpu_sources_static_cpu
     src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
     src/input_combine_cpu.cpp
     src/layout_transform_ops_cpu.cpp
-    src/quantize_ops_cpu.cpp
+    src/quantize_ops/quantize_ops_cpu.cpp
     src/sparse_ops/sparse_ops_cpu.cpp
     src/sparse_ops/sparse_ops_meta.cpp
     src/embedding_inplace_update_cpu.cpp)
@@ -569,7 +569,7 @@ if(NOT FBGEMM_CPU_ONLY)
     src/layout_transform_ops_gpu.cpp
     src/permute_pooled_embedding_ops_gpu.cpp
     src/permute_pooled_embedding_ops_split_gpu.cpp
-    src/quantize_ops_gpu.cpp
+    src/quantize_ops/quantize_ops_gpu.cpp
     src/sparse_ops/sparse_ops_gpu.cpp
     src/split_embeddings_utils.cpp
     src/split_table_batched_embeddings.cpp
@@ -631,7 +631,11 @@ if(NOT FBGEMM_CPU_ONLY)
     src/metric_ops.cu
     src/permute_pooled_embedding_ops_split.cu
     src/permute_pooled_embedding_ops.cu
-    src/quantize_ops.cu
+    src/quantize_ops/quantize_bfloat16.cu
+    src/quantize_ops/quantize_hfp8.cu
+    src/quantize_ops/quantize_msfp.cu
+    src/quantize_ops/quantize_ops.cu
+    src/quantize_ops/quantize_padded_fp8_rowwise.cu
     src/sparse_ops/sparse_async_cumsum.cu
     src/sparse_ops/sparse_block_bucketize_features.cu
     src/sparse_ops/sparse_bucketize_features.cu
1 change: 1 addition & 0 deletions fbgemm_gpu/include/fbgemm_gpu/ops_utils.h
@@ -9,6 +9,7 @@
 #pragma once

 #include <ATen/ATen.h>
+#include <torch/library.h>

 /*
  * We annotate the public FBGEMM functions and hide the rest. Those
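The new <torch/library.h> include supports the operator-registration machinery in this header, presumably the FBGEMM_OP_DISPATCH macro that the new quantize_ops/*.cu files invoke below. As a rough illustration only (a hypothetical macro, not the actual definition in ops_utils.h), such a macro can be built on TORCH_LIBRARY_IMPL from torch/library.h:

// Hypothetical sketch, not the fbgemm_gpu definition: a registration macro in
// the spirit of FBGEMM_OP_DISPATCH, shown only to illustrate why
// <torch/library.h> is now needed in this header.
#include <torch/library.h>

#define EXAMPLE_OP_DISPATCH(DISPATCH_KEY, EXPORT_NAME, FUNC_NAME) \
  TORCH_LIBRARY_IMPL(fbgemm, DISPATCH_KEY, m) {                   \
    m.impl(EXPORT_NAME, TORCH_FN(FUNC_NAME));                     \
  }

// Used the same way as the registrations at the bottom of the new .cu files:
//   EXAMPLE_OP_DISPATCH(CUDA, "FloatToBfloat16Quantized",
//                       fbgemm_gpu::_float_to_bfloat16_gpu);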
32 changes: 32 additions & 0 deletions fbgemm_gpu/src/quantize_ops/common.cuh
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/TensorIterator.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef __HIP_PLATFORM_HCC__
#include <math_constants.h>
#endif

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/quantize_ops.cuh"
#include "fbgemm_gpu/quantize_ops_utils.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#define QUANTIZE_OPS_MAX(a, b) ((a) > (b) ? (a) : (b))
#define QUANTIZE_OPS_MIN(a, b) ((a) < (b) ? (a) : (b))

using Tensor = at::Tensor;
72 changes: 72 additions & 0 deletions fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "common.cuh"

using Tensor = at::Tensor;

/// @defgroup quantize-data-cuda Quantization Data CUDA Operators
/// The following are CUDA Operators

namespace fbgemm_gpu {

///@ingroup quantize-data-cuda
DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) {
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
  // NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16
  auto output = at::empty({}, input.options().dtype(at::kHalf));
  output.resize_(0);

  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(output)
                  .add_input(input)
                  .build();
  at::native::gpu_kernel(iter, [] GPU_LAMBDA(float in) -> at::Half {
    fbgemm_gpu::fint32 temp;
    temp.F = in;
    return at::Half((temp.I + (1 << 15)) >> 16, at::Half::from_bits());
  });

  return output;
}

///@ingroup quantize-data-cuda
DLL_PUBLIC at::Tensor _bfloat16_to_float_gpu(const at::Tensor& input) {
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  auto output = at::empty({}, input.options().dtype(at::kFloat));
  output.resize_(0);
  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(output)
                  .add_input(input)
                  .build();

  at::native::gpu_kernel(iter, [] GPU_LAMBDA(at::Half in) -> float {
    fbgemm_gpu::fint32 temp;
    temp.I = in.x << 16;
    return temp.F;
  });
  return output;
}

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
    CUDA,
    "Bfloat16QuantizedToFloat",
    fbgemm_gpu::_bfloat16_to_float_gpu);
FBGEMM_OP_DISPATCH(
    CUDA,
    "FloatToBfloat16Quantized",
    fbgemm_gpu::_float_to_bfloat16_gpu);
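For reference, the two kernels above are pure bit manipulation: encoding adds 1 << 15 to the float32 bit pattern (rounding to the nearest representable value) and keeps the top 16 bits (sign, 8 exponent bits, 7 mantissa bits); decoding shifts those 16 bits back into the upper half of a float32. Below is a host-side C++ sketch of the same arithmetic, using std::memcpy in place of the fbgemm_gpu::fint32 union (illustrative only, not part of this commit):

#include <cstdint>
#include <cstring>

// Encode: add half of the dropped low 16 bits, then truncate to the top 16 bits.
uint16_t float_to_bfloat16_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>((bits + (1u << 15)) >> 16);
}

// Decode: place the stored 16 bits into the upper half of a float32 pattern.
float bfloat16_bits_to_float(uint16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}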
80 changes: 80 additions & 0 deletions fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu
@@ -0,0 +1,80 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "common.cuh"

using Tensor = at::Tensor;

/// @defgroup quantize-data-cuda Quantization Data CUDA Operators
/// The following are CUDA Operators

namespace fbgemm_gpu {

DLL_PUBLIC at::Tensor _float_to_hfp8_gpu(
    const at::Tensor& input,
    const int64_t ebits,
    const int64_t exponent_bias,
    const double max_pos) {
  TORCH_CHECK(ebits > 0);
  TORCH_CHECK(exponent_bias > 0);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  auto output = at::empty({}, input.options().dtype(at::kByte));
  output.resize_(0);

  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(output)
                  .add_input(input)
                  .build();

  at::native::gpu_kernel(iter, [=] GPU_LAMBDA(float in) -> uint8_t {
    return float_to_hfp8(in, ebits, exponent_bias, max_pos);
  });

  return output;
}

DLL_PUBLIC at::Tensor _hfp8_to_float_gpu(
    const at::Tensor& input,
    const int64_t ebits,
    const int64_t exponent_bias) {
  TORCH_CHECK(ebits > 0);
  TORCH_CHECK(exponent_bias > 0);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(input.get_device());

  auto output = at::empty({}, input.options().dtype(at::kFloat));
  output.resize_(0);

  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(output)
                  .add_input(input)
                  .build();

  at::native::gpu_kernel(iter, [=] GPU_LAMBDA(uint8_t in) -> float {
    return hfp8_to_float(in, ebits, exponent_bias);
  });

  return output;
}

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
    CUDA,
    "FloatToHFP8Quantized",
    fbgemm_gpu::_float_to_hfp8_gpu);
FBGEMM_OP_DISPATCH(
    CUDA,
    "HFP8QuantizedToFloat",
    fbgemm_gpu::_hfp8_to_float_gpu);
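The float_to_hfp8 and hfp8_to_float device helpers called above live in fbgemm_gpu/quantize_ops.cuh and are not part of this diff. To illustrate what a format parameterized by ebits and exponent_bias means, here is a simplified host-side decoder for normal (non-denormal) values, assuming a layout of 1 sign bit, ebits exponent bits, and 7 - ebits mantissa bits with a hidden leading 1 (an illustrative sketch, not the fbgemm_gpu implementation, which also handles denormals and the max_pos saturation applied on the encode side):

#include <cmath>
#include <cstdint>

// Sketch only: decode a normal hfp8-style value under the layout assumption above.
float hfp8_to_float_sketch(uint8_t v, int ebits, int exponent_bias) {
  const int mbits = 7 - ebits;                        // mantissa bits
  const int sign = v >> 7;                            // assumed sign bit
  const int exp = (v >> mbits) & ((1 << ebits) - 1);  // biased exponent field
  const int mant = v & ((1 << mbits) - 1);            // mantissa field
  // value = (-1)^sign * 2^(exp - bias) * (1 + mant / 2^mbits)
  const float frac = 1.0f + static_cast<float>(mant) / (1 << mbits);
  const float mag = std::ldexp(frac, exp - exponent_bias);
  return sign ? -mag : mag;
}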