forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MX4 ops front-end API (pytorch#2777)
Summary: X-link: facebookresearch/FBGEMM#12 Pull Request resolved: pytorch#2777 - Add shim function to handle MX4 ops API to be used in A2A comm - Register the ops once to avoid error ``` the model failed with error “Tried to register an operator (fbgemm::quantize_mx(Tensor input, int scale_bits, int elem_ebits, int elem_mbits, float elem_max_norm, int mx_group_size) -> Tensor) with the same name and overload name multiple times. Each overload's schema should only be registered with a single call to def(). Duplicate registration: registered at /dev/null:393 ``` Reviewed By: sryap Differential Revision: D58627286 fbshipit-source-id: c659e46f2828ce6f8b9c703c3a72ba1832b2a73a
- Loading branch information
1 parent
3d6e0c7
commit 468673d
Showing
6 changed files
with
200 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import torch | ||
from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx # noqa F401 | ||
|
||
|
||
def op_registeration(
    lib: torch.library.Library,
    op_name: str,
    fn,  # pyre-ignore[2]: implementation may be any callable matching the schema
    dispatch_key: str,
) -> None:
    """
    Registers an op implementation for the given name and dispatch key only once.

    Registering the same (op, dispatch key) pair twice raises a RuntimeError in
    torch, so the dispatcher is consulted first to make this call idempotent
    (safe under repeated module imports).

    Args:
        lib: torch.library handle the op schema was defined on
            (e.g., torch.library.Library("fbgemm", "FRAGMENT"))
        op_name: operator name (without the namespace prefix)
        fn: function that's the operator implementation for the input dispatch key
        dispatch_key: dispatch key that the function should be registered for
            (e.g., "CUDA")
    Returns:
        None
    Example:
        lib = torch.library.Library("fbgemm", "FRAGMENT")
        lib.define(...)
        op_registeration(lib, "quantize_mx", quantize_mx, "CUDA")
    """
    # Derive the namespace from the library handle instead of hard-coding
    # "fbgemm", so the helper works for any torch.library.Library.
    full_op_name = f"{lib.ns}::{op_name}"
    if not torch._C._dispatch_has_kernel_for_dispatch_key(full_op_name, dispatch_key):
        lib.impl(op_name, fn, dispatch_key)
|
||
|
||
# FRAGMENT: attach definitions to the existing "fbgemm" namespace without
# claiming ownership of it (other fbgemm modules also contribute ops).
lib = torch.library.Library("fbgemm", "FRAGMENT")

# Guard each schema definition: torch.library._defs records every schema
# already defined through torch.library, so checking it first avoids the
# "Tried to register an operator ... with the same name and overload name
# multiple times" error when this module is imported more than once.
# NOTE(review): _defs is a private torch.library attribute — confirm it is
# still populated on the torch versions this must support.
if "fbgemm::quantize_mx" not in torch.library._defs:
    lib.define(
        """quantize_mx(
    Tensor input,
    int scale_bits,
    int elem_ebits,
    int elem_mbits,
    float elem_max_norm,
    int mx_group_size
) -> Tensor
"""
    )

if "fbgemm::dequantize_mx" not in torch.library._defs:
    lib.define(
        """dequantize_mx(
    Tensor input,
    int mx_group_size
) -> Tensor
"""
    )

# Register the same Python implementations for both CUDA and CPU dispatch
# keys; op_registeration itself skips any (op, key) pair already registered.
op_registeration(lib, "quantize_mx", quantize_mx, "CUDA")
op_registeration(lib, "quantize_mx", quantize_mx, "CPU")
op_registeration(lib, "dequantize_mx", dequantize_mx, "CUDA")
op_registeration(lib, "dequantize_mx", dequantize_mx, "CPU")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# # pyre-unsafe | ||
|
||
import torch | ||
|
||
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32 | ||
|
||
|
||
def quantize_mx(
    input: torch.Tensor,
    scale_bits: int = 8,
    elem_ebits: int = 2,
    elem_mbits: int = 3,
    elem_max_norm: float = 6.0,
    mx_group_size: int = 32,
) -> torch.Tensor:
    """
    Quantize an FP32 tensor into the packed MX4 (e2m1) format.

    This is the implementation behind the registered ``fbgemm::quantize_mx``
    op used for E2E comm (registration happens in ``__init__.py``); the work
    is delegated to the Triton-based ``fp32_to_mx4`` kernel.

    Args:
        input: FP32 tensor of ``total_elems`` elements to be quantized.
        scale_bits: Bit width of the shared exponent (8 for MX4 e2m1).
        elem_ebits: Bit width of each element's exponent (2 for MX4 e2m1).
        elem_mbits: Bit width of each element's mantissa including sign and
            implicit bits (3 for MX4 e2m1).
        elem_max_norm: Maximum representable float value (6.0 for MX4 e2m1).
        mx_group_size: Number of elements sharing one max shared exponent.

    Returns:
        MX4 tensor packed into int8 values with size
        ``total_elems / 2 + total_elems / mx_group_size``; the shared exponent
        of each group is stored at the last byte of that group's output.
    """
    # scale_bits / elem_* describe the fixed MX4 e2m1 layout; the Triton
    # kernel only needs the group size.
    packed = fp32_to_mx4(input, mx_group_size, use_triton=True)
    return packed
|
||
|
||
def dequantize_mx(
    input: torch.Tensor,
    mx_group_size: int = 32,
) -> torch.Tensor:
    """
    Dequantize a packed MX4 tensor back to FP32.

    This is the implementation behind the registered ``fbgemm::dequantize_mx``
    op used for E2E comm (registration happens in ``__init__.py`` to prevent
    multiple loading); the work is delegated to the Triton-based
    ``mx4_to_fp32`` kernel.

    Args:
        input: MX4 tensor packed into int8 values.
        mx_group_size: Number of elements that share the same max
            shared exponent.

    Returns:
        FP32 tensor with ``total_elems`` elements.
    """
    restored = mx4_to_fp32(input, mx_group_size, use_triton=True)
    return restored
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters