diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
index 0bcd38af7..1e1d5ec1b 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -131,15 +131,20 @@ def _test_matmul_fp8_block(
             shape: Tuple[int, int, int],
             block_shape: Tuple[int, int, int],
             fp8_fast_accum: bool,
+            device: str = "cuda",
         ) -> None:
             M, N, K = shape
             BLOCK_M, BLOCK_N, BLOCK_K = block_shape
-            a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-            b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+            b = torch.randn(N, K, dtype=torch.bfloat16, device=device)
 
             # Quantize inputs.
-            a_fp8, a_scale = quantize_fp8_block(a, BLOCK_M, BLOCK_K)
-            b_fp8, b_scale = quantize_fp8_block(b, BLOCK_N, BLOCK_K)
+            a_fp8, a_scale = quantize_fp8_block(
+                a, BLOCK_M, BLOCK_K, output_device=torch.device("cuda")
+            )
+            b_fp8, b_scale = quantize_fp8_block(
+                b, BLOCK_N, BLOCK_K, output_device=torch.device("cuda")
+            )
 
             result = matmul_fp8_block(
                 a_fp8,
@@ -153,7 +158,7 @@ def _test_matmul_fp8_block(
             )
             self.assertTrue(result.shape == (M, N))
 
-            expected_result = a @ b.T
+            expected_result = (a @ b.T).to("cuda")
 
             self.assertTrue(
                 torch.allclose(result, expected_result, atol=1e2, rtol=5e-2)
@@ -163,3 +168,5 @@ def _test_matmul_fp8_block(
         _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True)
         _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), False)
         _test_matmul_fp8_block((3, 4, 5), (256, 256, 256), False)
+        _test_matmul_fp8_block((3, 4, 5), (256, 256, 256), True, "cpu")
+        _test_matmul_fp8_block((1024, 2048, 4096), (256, 512, 1024), True, "cpu")
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 3c6a1f8e3..e2a41b50f 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -999,7 +999,7 @@ def _kernel_quantize_fp8_block(
     tl.store(A_fp8 + a_offset, a_fp8, mask=a_mask)
 
 
-def quantize_fp8_block(
+def triton_quantize_fp8_block(
     x: torch.Tensor,
     block_m: int = 256,
     block_k: int = 256,
@@ -1050,3 +1050,73 @@ def quantize_fp8_block(
     )
 
     return x_fp8, x_scale
+
+
+def quantize_fp8_block(
+    x: torch.Tensor,
+    block_m: int = 256,
+    block_k: int = 256,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_triton: bool = True,
+    output_device: Optional[torch.device] = None,
+) -> Tuple[TensorWrapper, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
+
+    Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
+
+    Args:
+        x (Tensor): [M, K] higher precision input tensor.
+        block_m (int): Block size for M dimension of scale.
+        block_k (int): Block size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+        use_triton (bool): Whether to use triton kernel or pytorch.
+        output_device (torch.device): Device to optionally move the scaled tensors to.
+
+    Returns:
+        TensorWrapper: [M, K] fp8 scaled tensor.
+        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block.
+    """
+    if x.device == torch.device("cpu"):
+        logger.info("Triton does not support cpu, falling back to torch ops.")
+        use_triton = False
+    if use_triton:
+        return triton_quantize_fp8_block(x, block_m, block_k, scale_ub)
+    # else use pytorch implementation.
+    if not output_device:
+        output_device = x.device
+
+    M, K = x.shape
+    grid_m = triton.cdiv(M, block_m)
+    grid_k = triton.cdiv(K, block_k)
+
+    # Pad x to multiple of block size.
+    padded_m = grid_m * block_m
+    padded_k = grid_k * block_k
+    x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
+    x_padded[:M, :K] = x
+
+    # Blockwise max.
+    block_max = (
+        x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
+    )
+
+    # Apply clamping.
+    if scale_ub is not None:
+        block_max = torch.clamp(block_max, min=EPS, max=scale_ub.item())
+    else:
+        # pyre-ignore[6]: Incompatible parameter type [6]
+        block_max = torch.clamp(block_max, min=EPS)
+    x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
+    x_scale = MAX_FP8 / block_max.to(torch.float32)  # pyre-ignore
+    x_scale[x_scale == float("inf")] = 1.0
+    x_fp8 = (
+        x_padded
+        * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
+    )[:M, :K]
+
+    # Cast and move data to output device (for cpu weight loading).
+    x_fp8 = convert_fp8_type(x_fp8.to(device=output_device, dtype=PT_FP8_DTYPE))
+    x_scale = x_scale.to(output_device)  # pyre-ignore
+    del x, x_padded
+    return x_fp8, 1 / x_scale  # pyre-ignore
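
For context, a minimal usage sketch (not part of the diff) of the new quantize_fp8_block wrapper on a CPU-resident tensor. The shapes and block sizes below are illustrative only, and the import path assumes the module layout shown in the diff:

    import torch
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

    # Weight loaded on CPU (e.g. from a checkpoint). Because the input lives on
    # "cpu", quantize_fp8_block disables the Triton path and runs the pure-torch
    # block quantization instead.
    w = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cpu")

    # output_device moves both the fp8 tensor and the reciprocal block scales to
    # the GPU after quantization (the cpu-weight-loading case this PR targets).
    w_fp8, w_scale = quantize_fp8_block(
        w, block_m=256, block_k=512, output_device=torch.device("cuda")
    )

    # w_fp8 wraps a [2048, 4096] fp8 tensor on the GPU; w_scale is a
    # [cdiv(2048, 256), cdiv(4096, 512)] = [8, 8] float32 tensor of
    # reciprocal per-block scales.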