Support quantize_fp8_row for up to 4D non-contiguous tensors (#3508)
Summary:

X-link: facebookresearch/FBGEMM#589

Reland D66990975 with a fix for the NaN issue observed during the LLaMa4 17B model run with fp8_rowwise FFN.

Specifically, the offset was not properly updated when loading/storing data.
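For intuition, the per-row base-offset arithmetic the kernel now uses looks roughly like the following (an illustrative Python sketch of the a_offset_base / a_fp8_offset_base computation in the diff below; the helper name is made up):

def row_base_offset(pid, M, N, stride_b, stride_m, stride_n):
    # Decompose the flat program id into (b, m, n) indices of a [B, M, N, K]
    # tensor and combine them with that tensor's own strides, so permuted
    # (non-contiguous) layouts are addressed correctly.
    b = pid // (M * N)
    m = (pid % (M * N)) // N
    n = (pid % (M * N)) % N
    return b * stride_b + m * stride_m + n * stride_n

Element i of a row then lives at base + i * stride_k, and the load from A and the store to A_fp8 each compute their own base from their own strides.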

Differential Revision: D67303282
Jingyuan Fan authored and facebook-github-bot committed Dec 16, 2024
1 parent b8796b3 commit 0b9537a
Showing 2 changed files with 97 additions and 41 deletions.
53 changes: 39 additions & 14 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -6,6 +6,7 @@

# pyre-strict

import itertools
import unittest
from typing import Optional, Tuple

@@ -37,38 +38,62 @@ def _test_quantize_fp8_row(
device: torch.device,
output_device: Optional[torch.device] = None,
use_scale_ub: bool = False,
transpose_inputs: bool = False,
) -> None:
a = torch.randn(shape, dtype=torch.bfloat16, device=device)

inputs = [a]
# If transpose_inputs is true, generate every pair of dimensions of the
# input tensor and append a view with that pair transposed.
if transpose_inputs:
dims = range(a.ndim)
for dim1, dim2 in itertools.combinations(dims, 2):
dims_list = list(dims)
dims_list[dim1], dims_list[dim2] = dims_list[dim2], dims_list[dim1]
inputs.append(a.permute(dims_list))
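# For a 3D input, for example, the permuted views appended above use the
# dimension orders (1, 0, 2), (2, 1, 0) and (0, 2, 1).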
scale_ub = (
torch.tensor([1200], dtype=torch.float, device=device)
if use_scale_ub
else None
)
for input_a in inputs:
a_fp8, a_scale = quantize_fp8_row(
input_a,
scale_ub=scale_ub,
use_triton=use_triton,
output_device=output_device,
)

a_fp8, a_scale = quantize_fp8_row(
a, scale_ub=scale_ub, use_triton=use_triton, output_device=output_device
)
# Undo scaling.
a_torch = a_fp8.to(torch.bfloat16)
broadcast_shape = list(a_torch.shape[:-1]) + [-1]
a_torch *= a_scale.view(broadcast_shape)

# Undo scaling.
a_torch = a_fp8.to(torch.bfloat16)
broadcast_shape = list(a_torch.shape[:-1]) + [-1]
a_torch *= a_scale.view(broadcast_shape)

self.assertTrue(
torch.allclose(
a.to(device=output_device), a_torch, atol=2e-1, rtol=1e-1
self.assertTrue(
torch.allclose(
input_a.to(device=output_device), a_torch, atol=2e-1, rtol=1e-1
)
)
)

_test_quantize_fp8_row((2, 3), True, torch.device("cuda"))
for n_col in range(1, 9000, 100):
_test_quantize_fp8_row((2, n_col), True, torch.device("cuda"))
# Test with batched input.
_test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
_test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cuda"))
# Test with non-contiguous input
_test_quantize_fp8_row(
(4, 2, 3), True, torch.device("cuda"), transpose_inputs=True
)
_test_quantize_fp8_row(
(6, 4, 2, 3), True, torch.device("cuda"), transpose_inputs=True
)
_test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True)
# Test with cpu
_test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda"))
_test_quantize_fp8_row(
(2, 3), False, torch.device("cpu"), torch.device("cuda"), use_scale_ub=True
)
_test_quantize_fp8_row((4, 2, 3), True, torch.device("cpu"))
_test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu"))
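For reference, a minimal standalone sketch of the non-contiguous path these tests exercise (assuming quantize_fp8_row is importable from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm and a CUDA device is available):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

a = torch.randn(6, 4, 2, 3, dtype=torch.bfloat16, device="cuda")
a_t = a.permute(0, 1, 3, 2)  # non-contiguous 4D view, shape (6, 4, 3, 2)
a_fp8, a_scale = quantize_fp8_row(a_t, scale_ub=None, use_triton=True)
# Undo the scaling and check the round trip, broadcasting the per-row scale.
a_back = a_fp8.to(torch.bfloat16)
a_back *= a_scale.view(6, 4, 3, 1)
assert torch.allclose(a_t, a_back, atol=2e-1, rtol=1e-1)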

def test_scale_fp8_row(self) -> None:
def _test_scale_fp8_row(
85 changes: 58 additions & 27 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -1945,20 +1945,26 @@ def prep_matmul(
Config({"BLOCK_SIZE": 4096}),
Config({"BLOCK_SIZE": 8192}),
],
key=["N"],
key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
A,
A_scale,
A_fp8,
scale_ub,
B,
M,
N,
K,
stride_ab,
stride_am,
stride_an,
stride_ak,
stride_ob,
stride_om,
stride_on,
stride_ok,
TL_FP8_DTYPE: tl.constexpr,
MAX_FP8: tl.constexpr,
EPS: tl.constexpr,
@@ -1977,16 +1983,22 @@ def _kernel_quantize_fp8_row(
* Better tiling schemes.
Args:
A (Tensor): [m, n] higher precision input tensor.
A_scale (Tensor): [m] reciprocal scale tensor per row.
A_fp8 (Tensor): [m, n] fp8 scaled tensor. A_fp8 = A / a_scale
A (Tensor): Higher precision 4D input tensor.
A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
scale_ub (Tensor): [1] Maximum value allowed for scale.
M (int): Number of rows.
N (int): Number of columns.
B (int): Size of dimension 0.
M (int): Size of dimension 1.
N (int): Size of dimension 2.
K (int): Size of dimension 3.
stride_ab (int): Stride of b dimension of A.
stride_am (int): Stride of m dimension of A.
stride_an (int): Stride of n dimension of A.
stride_ak (int): Stride of k dimension of A.
stride_ob (int): Stride of b dimension of output.
stride_om (int): Stride of m dimension of output.
stride_on (int): Stride of n dimension of output.
stride_ok (int): Stride of k dimension of output.
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
MAX_FP8 (float): Maximum expressible value for FP8.
EPS (float): Epsilon value for numerical stability.
@@ -2000,16 +2012,25 @@ def _kernel_quantize_fp8_row(
if USE_INT64:
pid = pid.to(tl.int64)
n_offset = tl.arange(0, BLOCK_SIZE)
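# Decompose the flat program id into (b, m, n) row indices and combine them
# with each tensor's own strides so that non-contiguous (e.g. permuted) inputs
# and outputs are addressed correctly when loading and storing.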
a_offset_base = (
pid // (M * N) * stride_ab
+ (pid % (M * N)) // N * stride_am
+ (pid % (M * N)) % N * stride_an
)
a_fp8_offset_base = (
pid // (M * N) * stride_ob
+ (pid % (M * N)) // N * stride_om
+ (pid % (M * N)) % N * stride_on
)

# Calculate max.
cur_max = 0.0
for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
a = tl.load(
A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0
A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0
)
tile_max = tl.max(tl.abs(a))
cur_max = tl.maximum(tile_max, cur_max)

n_offset += BLOCK_SIZE

# Clamp max value appropriately.
@@ -2022,9 +2043,10 @@ def _kernel_quantize_fp8_row(
a_scale = MAX_FP8 / cur_max
tl.store(A_scale + pid, 1.0 / a_scale)
n_offset = tl.arange(0, BLOCK_SIZE)
for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):

for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
a = tl.load(
A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0
A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0
)
a_fp8 = a * a_scale
# Clamp A to fp8 range to make sure there's no overflow.
@@ -2033,7 +2055,7 @@ def _kernel_quantize_fp8_row(
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
a_fp8.to(TL_FP8_DTYPE)
tl.store(
A_fp8 + pid * stride_om + n_offset * stride_on, a_fp8, mask=n_offset < N
A_fp8 + a_fp8_offset_base + n_offset * stride_ok, a_fp8, mask=n_offset < K
)
n_offset += BLOCK_SIZE

@@ -2045,20 +2067,18 @@ def triton_quantize_fp8_row(
Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
Args:
a (Tensor): [m, n] higher precision input tensor.
a (Tensor): Higher precision 4D input tensor.
scale_ub (Tensor): Maximum allowed value for scale.
Returns:
torch.Tensor: fp8 scaled tensor.
torch.Tensor: reciprocal scale tensor per row.
"""
a_shape = a.shape
a = a.view(-1, a.size(-1))
# Get constant values.
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
num_rows = a.shape[0]
num_rows = a.numel() // a.shape[-1]
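# One reciprocal scale per row, where a row spans the last dimension.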
a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
a_fp8 = torch.empty((a.shape[0], a.shape[1]), device=a.device, dtype=pt_dtype)
a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

# If input tensor is sufficiently large, we need to use int64 indexing.
use_int64 = a.numel() > (2**31 - 1)
@@ -2070,18 +2090,24 @@ def triton_quantize_fp8_row(
scale_ub,
a.shape[0],
a.shape[1],
a.shape[2],
a.shape[3],
a.stride(0),
a.stride(1),
a.stride(2),
a.stride(3),
a_fp8.stride(0),
a_fp8.stride(1),
a_fp8.stride(2),
a_fp8.stride(3),
TL_FP8_DTYPE=tl_dtype,
MAX_FP8=max_fp8,
EPS=eps,
CLAMP_MAX=scale_ub is not None,
USE_INT64=use_int64,
)

return a_fp8.view(a_shape), a_scale
return a_fp8, a_scale


@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2095,7 +2121,7 @@ def quantize_fp8_row(
Quantize a to fp8 with row-wise scalings and optionally move to output device.
Args:
a (Tensor): Input high precision tensor.
a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions.
scale_ub (Tensor): Maximum allowed value for scale.
use_triton (bool): Whether to use triton kernel or pytorch.
output_device (torch.device): Device to optionally move the scaled tensors to.
@@ -2104,36 +2130,41 @@ def quantize_fp8_row(
torch.Tensor: fp8 scaled tensor.
torch.Tensor: The reciprocal scale tensor per row.
"""
a_shape = a.shape
a = a.view(-1, a.size(-1))

if a.device == torch.device("cpu"):
logger.info("Triton does not support cpu, falling back to torch ops.")
use_triton = False
if use_triton:
aq, a_scale = triton_quantize_fp8_row(a, scale_ub)
return aq.view(a_shape), a_scale
assert (
a.dim() <= 4
), "Only up to 4 dimension input tensor is supported if use_triton is True"
a_shape = a.shape
while a.dim() < 4:
a = a.unsqueeze(0)
a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub)
return a_fp8.view(a_shape), a_scale
# else use pytorch implementation.
if not output_device:
output_device = a.device

# Get constants.
pt_dtype, _, max_fp8, eps = get_fp8_constants()
row_max: torch.Tensor = torch.max(torch.abs(a), dim=1)[0]
row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
# Apply clamping.
if scale_ub is not None:
row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
else:
# pyre-ignore[6]: Incompatible parameter type [6]
row_max = torch.clamp(row_max, min=eps)
a_scale = torch.empty((a.shape[0]), dtype=torch.float32, device=output_device)
a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device)
a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
a_fp8 = a * a_scale[:, None] # pyre-ignore
a_fp8 = a * a_scale[..., None] # pyre-ignore
# Cast and move data to output device (for cpu weight loading).
a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
a_scale = a_scale.to(output_device) # pyre-ignore
del a
return a_fp8.view(a_shape), 1 / a_scale # pyre-ignore
return a_fp8, 1 / a_scale # pyre-ignore
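For intuition, the per-row math in the fallback above amounts to the following (illustrative sketch; 448.0 stands in for the e4m3 maximum that get_fp8_constants() would return, and no real fp8 cast is performed):

import torch

row = torch.tensor([0.5, -3.0, 1.5])
MAX_FP8 = 448.0                     # assumed e4m3 max; the real value comes from get_fp8_constants()
scale = MAX_FP8 / row.abs().max()   # chosen so the largest |value| maps to MAX_FP8
q = row * scale                     # the real code casts this to the fp8 dtype
recovered = q * (1.0 / scale)       # dequantize with the stored reciprocal scale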


@quantize_fp8_row.register_fake
