Support stochastic rounding for FP32/BF16 -> FP8 conversion (pytorch#2677)

Summary:
Pull Request resolved: pytorch#2677

Support stochastic rounding, using the assemble-float algorithm, for FP32/BF16 -> FP8 conversions.
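For intuition, below is a minimal host-side sketch of the assemble-float trick, under the assumption that it matches the commit's approach: build a float that shares the input's sign and exponent but carries random bits in the mantissa positions the narrow format discards, add it, then truncate. The committed kernels are CUDA device code and also handle scaling and e4m3 saturation; every name here is hypothetical.

```cpp
#include <cstdint>
#include <cstring>
#include <random>

static inline uint32_t bits_of(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

static inline float float_of(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// mantissa_bits: mantissa bits kept by the target format (3 for FP8 e4m3).
float stochastic_round_fp32(float x, uint32_t rand32, int mantissa_bits) {
  const int drop_bits = 23 - mantissa_bits; // FP32 carries 23 mantissa bits
  const uint32_t sign_exp = bits_of(x) & 0xFF800000u; // sign + exponent of x
  // The assembled float minus its zero-mantissa counterpart is a uniform
  // random offset in [0, ulp_of_target(x)), auto-scaled by x's exponent.
  const uint32_t rand_mant = rand32 >> (32 - drop_bits);
  const float noise = float_of(sign_exp | rand_mant) - float_of(sign_exp);
  // Add the offset, then truncate the dropped bits (round toward zero).
  const uint32_t y = bits_of(x + noise) & ~((1u << drop_bits) - 1u);
  return float_of(y);
}

int main() {
  std::mt19937 rng(42);
  double sum = 0.0;
  // 0.3 lies between the e4m3 neighbors 0.28125 and 0.3125; the sample
  // mean converges to 0.3, which round-to-nearest cannot achieve.
  for (int i = 0; i < 100000; ++i) {
    sum += stochastic_round_fp32(0.3f, static_cast<uint32_t>(rng()), 3);
  }
  // sum / 100000 is approximately 0.3 in expectation
  return 0;
}
```

This unbiasedness in expectation is the reason to prefer stochastic rounding over round-to-nearest when quantization error would otherwise accumulate systematically.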

Reviewed By: xintwfb

Differential Revision: D58061352

fbshipit-source-id: 3532be37986f401984b3c32dfe0d56b4d0c43582
jianyuh authored and facebook-github-bot committed Jun 17, 2024
1 parent e5d0c94 commit 74fc8d3
Showing 3 changed files with 370 additions and 53 deletions.
21 changes: 13 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -75,13 +75,15 @@ std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);
 std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
-    std::optional<at::Tensor> scale_ub); // scale upperbound
+    std::optional<at::Tensor> scale_ub, // scale upperbound
+    const bool stochastic_rounding); // whether to apply stochastic rounding

 std::vector<at::Tensor> quantize_fp8_per_row(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
     std::optional<at::Tensor> scale_ub, // scale upperbound
-    std::optional<c10::ScalarType> output_dtype); // output dtype
+    std::optional<c10::ScalarType> output_dtype, // output dtype
+    bool stochastic_rounding); // whether to apply stochastic rounding

 #if CUDART_VERSION >= 12000
 std::vector<at::Tensor> quantize_fp8_per_col(
@@ -93,7 +95,8 @@ std::vector<at::Tensor> quantize_fp8_per_col(
 at::Tensor quantize_fp8_per_tensor_fixed_scale(
     at::Tensor input,
     at::Tensor scale,
-    std::optional<at::Tensor> bs);
+    std::optional<at::Tensor> bs,
+    bool stochastic_rounding);

 at::Tensor get_fp8_per_tensor_scale(
     at::Tensor input,
@@ -146,10 +149,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   // quantize_ops with
   // torch.ops.load_library
   m.def(
-      "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]");
+      "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, bool stochastic_rounding=False) -> Tensor[]");
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.def(
-      "quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
+      "quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None, bool stochastic_rounding=False) -> Tensor[]");
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);

 #if CUDART_VERSION >= 12000
@@ -163,7 +166,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.impl("get_fp8_per_tensor_scale", get_fp8_per_tensor_scale);

   m.def(
-      "quantize_fp8_per_tensor_fixed_scale(Tensor input, Tensor scale, Tensor? bs=None) -> Tensor");
+      "quantize_fp8_per_tensor_fixed_scale(Tensor input, Tensor scale, Tensor? bs=None, bool stochastic_rounding=False) -> Tensor");
   m.impl(
       "quantize_fp8_per_tensor_fixed_scale",
       quantize_fp8_per_tensor_fixed_scale);
@@ -209,7 +212,8 @@ at::Tensor f8f8bf16_rowwise_meta(
 std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
     at::Tensor X,
     std::optional<at::Tensor> bs,
-    std::optional<at::Tensor> scale_ub) {
+    std::optional<at::Tensor> /*scale_ub*/,
+    const bool /*stochastic_rounding*/) {
   auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
   auto scale = at::empty({}, X.options().dtype(at::kBFloat16));
   return {Y, scale};
@@ -277,7 +281,8 @@ std::vector<at::Tensor> quantize_fp8_per_row_meta(
     at::Tensor input,
     std::optional<at::Tensor> bs,
     std::optional<at::Tensor> scale_ub,
-    std::optional<c10::ScalarType> output_dtype) {
+    std::optional<c10::ScalarType> /* output_dtype */,
+    bool /* stochastic_rounding */) {
   const at::SymInt M = input.sym_size(0);
   auto Y = at::empty_like(input, input.options().dtype(at::kFloat8_e4m3fn));
   auto scale = at::empty_symint({M}, input.options().dtype(at::kFloat));
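Note that the meta implementations above ignore the new flag: they only produce output shapes and dtypes for tracing and compilation, so the rounding mode cannot affect them. As a hedged sketch of how a caller could exercise the updated schema through the PyTorch dispatcher (the op name and signature follow the m.def lines above; the wrapper function itself is hypothetical):

```cpp
#include <optional>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical helper: invoke the op registered above via the dispatcher.
// Assumes the fbgemm extension library is already loaded in the process.
std::vector<at::Tensor> quantize_fp8_with_sr(const at::Tensor& input) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::quantize_fp8_per_tensor", "")
          .typed<std::vector<at::Tensor>(
              at::Tensor,
              std::optional<at::Tensor>,
              std::optional<at::Tensor>,
              bool)>();
  // bs and scale_ub keep their None defaults; only the new flag is set.
  return op.call(input, std::nullopt, std::nullopt,
                 /*stochastic_rounding=*/true);
}
```

The returned vector holds the FP8 tensor and its scale, matching the Tensor[] schema; passing stochastic_rounding=false (the default) preserves the previous behavior.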
(Diffs for the remaining two changed files are not shown.)
