Skip to content

Commit

Permalink
Switch default MX4 rounding mode to Even (pytorch#3111)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3111

X-link: facebookresearch/FBGEMM#198

Even rounding is a bit faster and more accurate than ceil rounding. This diff switches the default to it.

Reviewed By: jspark1105

Differential Revision: D62466094

fbshipit-source-id: 9e80c49f536332ae65c665df7b325cecdbfef92b
  • Loading branch information
jwfromm authored and facebook-github-bot committed Sep 19, 2024
1 parent ebbebd4 commit 55f4c78
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def quantize_mx(
elem_mbits: int = 3,
elem_max_norm: float = 6.0,
mx_group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
) -> torch.Tensor:
"""
Registered quantize_mx ops for E2E comm.
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class QuantizationContext:
row_dim: int = ROW_DIM_DEFAULT
row_dim_quant: int = -1
mx_group_size: int = MX_GROUP_SIZE_DEFAULT
rounding_mode: RoundingMode = RoundingMode.ceil
rounding_mode: RoundingMode = RoundingMode.even
padded_dim_sum_per_rank: Optional[List[int]] = None


Expand Down Expand Up @@ -110,7 +110,7 @@ def _quantize_tensor(
return input_quant_all2all
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.ceil
rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.even
return fp32_to_mx4(
input_tensor, mx_group_size, rounding_mode=rounding_mode
).view(-1)
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fp32_to_mx4(
group_size: int = 32,
ebits: int = 2,
mbits: int = 1,
rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.ceil,
rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
stochastic_casting: bool = False,
use_triton: bool = True,
) -> torch.Tensor:
Expand All @@ -58,7 +58,7 @@ def fp32_to_mx4(
# Accelerated MX4 is only available on cuda, if input is on cpu, use python.
# Operate on flattened input.
if rounding_mode is None:
rounding_mode = RoundingMode.ceil
rounding_mode = RoundingMode.even

if not tensor.is_cuda:
return py_quantize_mx4(
Expand Down

0 comments on commit 55f4c78

Please sign in to comment.