Add CPU weight loading support for block-wise fp8 (pytorch#2676)
Summary:
Pull Request resolved: pytorch#2676

Add CPU block-wise quantization.
Add CPU weight loading for ffn_quantize_mode=fp8_blockwise.

Reviewed By: jiawenliu64

Differential Revision: D57937205

fbshipit-source-id: 48b455ca4e3b7e7d2123bc0ac5d87eeaf3e72dea
choutim authored and facebook-github-bot committed Jun 4, 2024
1 parent 08a0c22 commit 65bcbc4
Showing 2 changed files with 83 additions and 6 deletions.
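The gist of the change: quantize_fp8_block now accepts an output_device argument and falls back to a pure-PyTorch path for CPU tensors. Below is a minimal usage sketch (not part of the commit), assuming the import path mirrors the changed file and that a CUDA device is available for the output_device move:

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

# A CPU-resident weight, e.g. loaded from a checkpoint before being placed on GPU.
w = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cpu")

# Block-wise quantization runs in pure PyTorch on CPU (the Triton kernel is skipped),
# and the fp8 weight plus reciprocal scales are landed on the requested output device.
w_fp8, w_scale = quantize_fp8_block(
    w, block_m=256, block_k=256, output_device=torch.device("cuda")
)
# w_scale holds one reciprocal scale per (256, 256) block:
# shape [cdiv(2048, 256), cdiv(4096, 256)] = [8, 16].
# The fp8 weight and scales can then feed matmul_fp8_block on the GPU,
# which is the pattern the updated test exercises with device="cpu".

This is a sketch under those assumptions, not a definitive recipe; the diffs below are the authoritative change.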
17 changes: 12 additions & 5 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -131,15 +131,20 @@ def _test_matmul_fp8_block(
             shape: Tuple[int, int, int],
             block_shape: Tuple[int, int, int],
             fp8_fast_accum: bool,
+            device: str = "cuda",
         ) -> None:
             M, N, K = shape
             BLOCK_M, BLOCK_N, BLOCK_K = block_shape
-            a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-            b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
 
             # Quantize inputs.
-            a_fp8, a_scale = quantize_fp8_block(a, BLOCK_M, BLOCK_K)
-            b_fp8, b_scale = quantize_fp8_block(b, BLOCK_N, BLOCK_K)
+            a_fp8, a_scale = quantize_fp8_block(
+                a, BLOCK_M, BLOCK_K, output_device=torch.device("cuda")
+            )
+            b_fp8, b_scale = quantize_fp8_block(
+                b, BLOCK_N, BLOCK_K, output_device=torch.device("cuda")
+            )
 
             result = matmul_fp8_block(
                 a_fp8,
@@ -153,7 +158,7 @@ def _test_matmul_fp8_block(
             )
             self.assertTrue(result.shape == (M, N))
 
-            expected_result = a @ b.T
+            expected_result = (a @ b.T).to("cuda")
 
             self.assertTrue(
                 torch.allclose(result, expected_result, atol=1e2, rtol=5e-2)
@@ -163,3 +168,5 @@ def _test_matmul_fp8_block(
         _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True)
         _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), False)
         _test_matmul_fp8_block((3, 4, 5), (256, 256, 256), False)
+        _test_matmul_fp8_block((3, 4, 5), (256, 256, 256), True, "cpu")
+        _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True, "cpu")
72 changes: 71 additions & 1 deletion fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -999,7 +999,7 @@ def _kernel_quantize_fp8_block(
     tl.store(A_fp8 + a_offset, a_fp8, mask=a_mask)
 
 
-def quantize_fp8_block(
+def triton_quantize_fp8_block(
     x: torch.Tensor,
     block_m: int = 256,
     block_k: int = 256,
@@ -1050,3 +1050,73 @@ def quantize_fp8_block(
     )
 
     return x_fp8, x_scale
+
+
+def quantize_fp8_block(
+    x: torch.Tensor,
+    block_m: int = 256,
+    block_k: int = 256,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_triton: bool = True,
+    output_device: Optional[torch.device] = None,
+) -> Tuple[TensorWrapper, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
+    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
+    Args:
+        x (Tensor): [M, K] higher precision input tensor.
+        block_m (int): Block size for M dimension of scale.
+        block_k (int): Block size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+        use_triton (bool): Whether to use triton kernel or pytorch.
+        output_device (torch.device): Device to optionally move the scaled tensors to.
+    Returns:
+        TensorWrapper: [M, K] fp8 scaled tensor.
+        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block.
+    """
+    if x.device == torch.device("cpu"):
+        logger.info("Triton does not support cpu, falling back to torch ops.")
+        use_triton = False
+    if use_triton:
+        return triton_quantize_fp8_block(x, block_m, block_k, scale_ub)
+    # else use pytorch implementation.
+    if not output_device:
+        output_device = x.device
+
+    M, K = x.shape
+    grid_m = triton.cdiv(M, block_m)
+    grid_k = triton.cdiv(K, block_k)
+
+    # Pad x to multiple of block size.
+    padded_m = grid_m * block_m
+    padded_k = grid_k * block_k
+    x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
+    x_padded[:M, :K] = x
+
+    # Blockwise max.
+    block_max = (
+        x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
+    )
+
+    # Apply clamping.
+    if scale_ub is not None:
+        block_max = torch.clamp(block_max, min=EPS, max=scale_ub.item())
+    else:
+        # pyre-ignore[6]: Incompatible parameter type [6]
+        block_max = torch.clamp(block_max, min=EPS)
+    x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
+    x_scale = MAX_FP8 / block_max.to(torch.float32)  # pyre-ignore
+    x_scale[x_scale == float("inf")] = 1.0
+    x_fp8 = (
+        x_padded
+        * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
+    )[:M, :K]
+
+    # Cast and move data to output device (for cpu weight loading).
+    x_fp8 = convert_fp8_type(x_fp8.to(device=output_device, dtype=PT_FP8_DTYPE))
+    x_scale = x_scale.to(output_device)  # pyre-ignore
+    del x, x_padded
+    return x_fp8, 1 / x_scale  # pyre-ignore
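
For reference, a self-contained sketch of the block-wise scale math used by the torch fallback above. It simulates the fp8 cast with a clamp and assumes MAX_FP8 = 448.0 (the e4m3 maximum); it is an illustration of the technique, not the library code:

import torch

MAX_FP8 = 448.0  # assumed e4m3 maximum; the library defines its own constant


def blockwise_quant_reference(x, block_m=256, block_k=256):
    # Pure-PyTorch reference of the block-wise scaling shown in the diff above.
    M, K = x.shape
    grid_m, grid_k = -(-M // block_m), -(-K // block_k)  # ceil division
    x_padded = torch.zeros(grid_m * block_m, grid_k * block_k, dtype=torch.float32)
    x_padded[:M, :K] = x.to(torch.float32)
    # Per-block absolute max, then a scale so each block fits in [-MAX_FP8, MAX_FP8].
    block_max = x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
    scale = MAX_FP8 / block_max.clamp(min=1e-12)
    scale[scale == float("inf")] = 1.0
    x_scaled = (
        x_padded * scale.repeat_interleave(block_m, 0).repeat_interleave(block_k, 1)
    )[:M, :K].clamp(-MAX_FP8, MAX_FP8)  # stand-in for the fp8 cast
    return x_scaled, 1.0 / scale  # reciprocal scale, matching the function's return


# Round trip: expanding the reciprocal scales back over the blocks recovers x.
x = torch.randn(300, 500)
x_q, recip = blockwise_quant_reference(x, 256, 256)
x_rec = x_q * recip.repeat_interleave(256, 0).repeat_interleave(256, 1)[:300, :500]
assert torch.allclose(x_rec, x, atol=1e-5)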
