MX4 ops front-end API (pytorch#2777)
Summary:
X-link: facebookresearch/FBGEMM#12

Pull Request resolved: pytorch#2777

- Add shim functions to expose the MX4 ops API for use in A2A comm (see the usage sketch after the error message below)
- Register the ops only once to avoid the following error:

```
 the model failed with error “Tried to register an operator (fbgemm::quantize_mx(Tensor input, int scale_bits, int elem_ebits, int elem_mbits, float elem_max_norm, int mx_group_size) -> Tensor) with the same name and overload name multiple times. Each overload's schema should only be registered with a single call to def(). Duplicate registration: registered at /dev/null:393
```
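
With the shim registered, the MX4 ops are callable through `torch.ops.fbgemm`. A minimal usage sketch (the tensor size, device, and group size below are illustrative and not part of this diff; the e2m1 parameters match the shim's defaults):

```
import torch
import fbgemm_gpu.quantize.quantize_ops  # noqa F401  # triggers the once-only registration

x = torch.randn(4096, device="cuda")  # any FP32 tensor whose numel is divisible by the group size

# FP32 -> packed MX4 (int8 storage), using the MX4 e2m1 parameters.
packed = torch.ops.fbgemm.quantize_mx(
    x,
    scale_bits=8,
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=32,
)

# Packed MX4 -> FP32.
y = torch.ops.fbgemm.dequantize_mx(packed, mx_group_size=32)
```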

Reviewed By: sryap

Differential Revision: D58627286

fbshipit-source-id: c659e46f2828ce6f8b9c703c3a72ba1832b2a73a
spcyppt authored and facebook-github-bot committed Jul 12, 2024
1 parent 3d6e0c7 commit 468673d
Showing 6 changed files with 200 additions and 29 deletions.
68 changes: 68 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx # noqa F401


def op_registeration(
lib, # pyre-ignore[2]
op_name, # pyre-ignore[2]
fn, # pyre-ignore[2]
dispatch_key, # pyre-ignore[2]
) -> None:
"""
Registers an op with the given name and dispatch key only once.
Args:
lib: torch.library (e.g., torch.library.Library("fbgemm", "FRAGMENT"))
op_name: operator name
fn: function that's the operator implementation for the input dispatch key
dispatch_key: dispatch key that the function should be registered for (e.g., "CUDA")
Returns:
None
Example:
lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(...)
op_registeration(lib, "quantize_mx", quantize_mx, "CUDA")
"""
full_op_name = "fbgemm::" + op_name
if not torch._C._dispatch_has_kernel_for_dispatch_key(full_op_name, dispatch_key):
lib.impl(op_name, fn, dispatch_key)


lib = torch.library.Library("fbgemm", "FRAGMENT")

if "fbgemm::quantize_mx" not in torch.library._defs:
lib.define(
"""quantize_mx(
Tensor input,
int scale_bits,
int elem_ebits,
int elem_mbits,
float elem_max_norm,
int mx_group_size
) -> Tensor
"""
)

if "fbgemm::dequantize_mx" not in torch.library._defs:
lib.define(
"""dequantize_mx(
Tensor input,
int mx_group_size
) -> Tensor
"""
)

op_registeration(lib, "quantize_mx", quantize_mx, "CUDA")
op_registeration(lib, "quantize_mx", quantize_mx, "CPU")
op_registeration(lib, "dequantize_mx", dequantize_mx, "CUDA")
op_registeration(lib, "dequantize_mx", dequantize_mx, "CPU")
59 changes: 59 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32


def quantize_mx(
input: torch.Tensor,
scale_bits: int = 8,
elem_ebits: int = 2,
elem_mbits: int = 3,
elem_max_norm: float = 6.0,
mx_group_size: int = 32,
) -> torch.Tensor:
"""
Registered quantize_mx op for E2E comm.
(Registration is done in __init__.py.)
We use the Triton implementation for quantization.
Args:
input: FP32 tensor of size total_elems to be quantized
scale_bits: number of bits of the shared exponent (i.e., 8 for MX4 e2m1)
elem_ebits: number of bits of the exponent (i.e., 2 for MX4 e2m1)
elem_mbits: number of bits of the mantissa, incl. sign and implicit bits
(i.e., 3 for MX4 e2m1)
elem_max_norm: max value of the element format (i.e., 6.0 for MX4 e2m1)
mx_group_size: number of elements that share the max shared_exponent
Return:
output: MX4 tensor packed into int8 values with size
(total_elems / 2 + total_elems / group_size);
the shared exponent of each group is stored in the last byte
of that group's output
"""
return fp32_to_mx4(input, mx_group_size, use_triton=True)


def dequantize_mx(
input: torch.Tensor,
mx_group_size: int = 32,
) -> torch.Tensor:
"""
Registered dequantize_mx op for E2E comm.
(Registration is done in __init__.py to prevent registering it multiple times.)
We use the Triton implementation for dequantization.
Args:
input: MX4 tensor (packed into int8 values)
mx_group_size: number of elements that share the same max shared_exponent
Return:
output: FP32 tensor with total elements (total_elems)
"""
return mx4_to_fp32(input, mx_group_size, use_triton=True)
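
A quick sketch of the packed layout described in the docstrings above, assuming the element-count formula they state (two 4-bit elements per byte plus one shared-exponent byte per group); the sizes here are illustrative:

```
import torch

from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx

group_size = 32
x = torch.randn(4096)  # 4096 FP32 elements, divisible by the group size

packed = quantize_mx(x, mx_group_size=group_size)
# Expected packed size: 4096 / 2 + 4096 / 32 = 2048 + 128 = 2176 int8 values.
assert packed.numel() == x.numel() // 2 + x.numel() // group_size

restored = dequantize_mx(packed, mx_group_size=group_size)
assert restored.numel() == x.numel()  # lossy in values, but the element count round-trips
```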
20 changes: 16 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -23,14 +23,14 @@
fp32_to_bf16_with_clamp,
fp32_to_fp16_with_clamp,
fp32_to_hfp8_with_clamp,
fp32_to_mx4,
hfp8_to_fp32,
mx4_to_fp32,
)

from fbgemm_gpu.split_embedding_configs import SparseType
from torch.autograd.profiler import record_function # usort:skip
from dataclasses import dataclass

import fbgemm_gpu.quantize.quantize_ops # noqa F401

logger: logging.Logger = logging.getLogger()

@@ -100,7 +100,15 @@ def _quantize_tensor(
return input_quant_all2all
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
return fp32_to_mx4(input_tensor, mx_group_size)
quantized_output = torch.ops.fbgemm.quantize_mx(
input=input_tensor,
scale_bits=8,
elem_ebits=2,
elem_mbits=3,
elem_max_norm=6.0,
mx_group_size=mx_group_size,
)
return quantized_output
else:
raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -141,7 +149,11 @@ def _dequantize_tensor(
return dequant_tensor.view(-1)
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
return mx4_to_fp32(quantized_tensor, mx_group_size)
dequant_tensor = torch.ops.fbgemm.dequantize_mx(
input=quantized_tensor,
mx_group_size=mx_group_size,
)
return dequant_tensor.view(-1)
else:
raise ValueError(f"comm_precision={comm_precision} is not supported")

27 changes: 18 additions & 9 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -8,6 +8,7 @@
# pyre-strict

import logging
import math

import torch

@@ -44,13 +45,23 @@ def fp32_to_mx4(
output: MX4 tensor packed into int8 values with total elements (M / 2 + M / groupsize)
"""
# Accelerated MX4 is only available on CUDA; if the input is on CPU, use the Python implementation.
# For CPU and Triton, reshape so the second dim is 2048, or the largest power of 2 <= numel for smaller inputs.
dim = (
2048 if tensor.numel() >= 2048 else 2 ** (math.floor(math.log2(tensor.numel())))
)
input = (
tensor.view(-1)
if (tensor.is_cuda and not use_triton) or tensor.numel() % dim != 0
else tensor.view(-1, dim)
)
if not tensor.is_cuda:
return py_quantize_mx4(tensor, group_size)
return py_quantize_mx4(input, group_size)

if use_triton:
return quantize_mx4(tensor, group_size)
return quantize_mx4(input, group_size)
else:
out = torch.ops.fbgemm.quantize_mx_cuda(
tensor.view(-1),
input,
scale_bits=8,
elem_ebits=2,
elem_mbits=3,
@@ -75,16 +86,14 @@ def mx4_to_fp32(
Return:
output: FP32 tensor with total elements (M).
"""
flatten_tensor = tensor.view(-1)
# Accelerated MX4 dequantize is only available on CUDA; if the input is on CPU, use the Python implementation.
if not tensor.is_cuda:
return py_dequantize_mx4(tensor, group_size)
return py_dequantize_mx4(flatten_tensor, group_size)
if use_triton:
return dequantize_mx4(tensor, group_size)
return dequantize_mx4(flatten_tensor, group_size)
else:
out = torch.ops.fbgemm.dequantize_mx_cuda(tensor.view(-1), group_size)
# Perserve input dimensions.
output_shape = list(tensor.shape[:-1]) + [-1]
return out.view(output_shape)
return torch.ops.fbgemm.dequantize_mx_cuda(flatten_tensor, group_size)


def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
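The reshape heuristic added to fp32_to_mx4 above picks a second dimension of 2048, or the largest power of two not exceeding the element count, and only reshapes when the element count divides evenly. A standalone sketch of just that computation (the function name is hypothetical; the logic mirrors the lines above):

```
import math

import torch


def mx4_input_view(tensor: torch.Tensor, use_triton: bool) -> torch.Tensor:
    # 2048, or the largest power of two <= numel for smaller inputs.
    dim = 2048 if tensor.numel() >= 2048 else 2 ** math.floor(math.log2(tensor.numel()))
    if (tensor.is_cuda and not use_triton) or tensor.numel() % dim != 0:
        return tensor.view(-1)  # flat input: CUDA kernel path or ragged sizes
    return tensor.view(-1, dim)  # 2-D input: CPU and Triton paths


print(mx4_input_view(torch.randn(4096), use_triton=True).shape)  # torch.Size([2, 2048])
print(mx4_input_view(torch.randn(1024), use_triton=True).shape)  # torch.Size([1, 1024])
print(mx4_input_view(torch.randn(1000), use_triton=True).shape)  # 1000 % 512 != 0 -> torch.Size([1000])
```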
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/quantize/comm_codec_test.py
@@ -18,7 +18,7 @@


class QuantizedCommCodecTest(unittest.TestCase):
@settings(deadline=4000)
@settings(deadline=8000)
# pyre-ignore
@given(
comm_precisions_loss_scale=st.sampled_from(
53 changes: 38 additions & 15 deletions fbgemm_gpu/test/quantize/mx4_test.py
@@ -9,9 +9,12 @@
import unittest
from typing import List

import fbgemm_gpu.quantize.quantize_ops # noqa F401

import hypothesis.strategies as st

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

@@ -81,10 +84,6 @@ def fake_quantize_mx(
) -> torch.Tensor:
"""Function used for MX* fake quantization"""

####################
# Python Quantize
####################

# Make sure axes is a list of non-negative numbers
axes = [x + A.ndim if x < 0 else x for x in axes]

@@ -161,6 +160,7 @@ def test_mx4(self, power: int, sizes: int) -> None:
ebits, mbits, emax, max_norm, _ = _get_format_params(element_format_str)
scale_bits = 8

# Reference from mx_github
output_ref = fake_quantize_mx(
input,
scale_bits,
@@ -172,7 +172,8 @@
group_size=group_size,
)

output = fake_quantize_mx_cuda(
# Test CUDA implementation
output_cuda = fake_quantize_mx_cuda(
input,
scale_bits,
ebits,
@@ -183,16 +184,38 @@
)

# Test intercompatibility between implementations.
py_mx_q_input = py_quantize_mx4(input, group_size)
py_mx_output = py_dequantize_mx4(py_mx_q_input, group_size)
triton_mx_q_input = fp32_to_mx4(input, group_size, use_triton=True)
cuda_mx_output = mx4_to_fp32(triton_mx_q_input, group_size, use_triton=False)
triton_mx_output = mx4_to_fp32(triton_mx_q_input, group_size, use_triton=True)

check_diff_quantize(input, py_mx_output, output_ref)
check_diff_quantize(input, cuda_mx_output, output_ref)
check_diff_quantize(input, triton_mx_output, output_ref)
check_diff_quantize(input, output, output_ref)
# Test CPU implementation
quantized_cpu = py_quantize_mx4(input, group_size)
output_cpu = py_dequantize_mx4(quantized_cpu, group_size)

# Test Triton implementation
quantized_triton = fp32_to_mx4(input, group_size, use_triton=True)
output_triton = mx4_to_fp32(quantized_triton, group_size, use_triton=True)

# Test shim functions
output_cuda_from_quantized_triton = mx4_to_fp32(
quantized_triton, group_size, use_triton=False
)

# Test torch.ops
quantized_from_ops = torch.ops.fbgemm.quantize_mx(
input,
scale_bits,
ebits,
mbits,
max_norm,
mx_group_size=group_size,
)
output_from_ops = torch.ops.fbgemm.dequantize_mx(
quantized_from_ops,
mx_group_size=group_size,
)

check_diff_quantize(input, output_ref, output_cuda)
check_diff_quantize(input, output_cuda, output_cuda_from_quantized_triton)
check_diff_quantize(input, output_cuda_from_quantized_triton, output_triton)
check_diff_quantize(input, output_triton, output_cpu)
check_diff_quantize(input, output_cuda, output_from_ops)


if __name__ == "__main__":
