diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 1ab839af1..94c6eee00 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -9,7 +9,7 @@
 
 import unittest
 
-from typing import Tuple
+from typing import List, Tuple
 
 import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
@@ -604,6 +604,50 @@ def test_quantize_fp8_per_tensor_with_ub(
         zq_ref = (x @ w.T).to(torch.bfloat16)
         torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)
 
+    @settings(deadline=None)
+    @given(
+        B=st.sampled_from([1, 4]),
+        M=st.sampled_from([2048, 4096]),
+        N=st.sampled_from([128, 256]),
+        K=st.sampled_from([256, 512]),
+    )
+    def test_fp8_batched_gemm(
+        self,
+        B: int,
+        M: int,
+        N: int,
+        K: int,
+    ) -> None:
+        x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1
+        w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01
+
+        xq, x_scale = quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        assert x_scale.shape == (B, M)
+        wq, w_scale = quantize_fp8_row(w)
+        w_scale = w_scale.view(B, -1)
+        assert w_scale.shape == (B, N)
+
+        def fp8_loopover_bmm(
+            xq: List[torch.Tensor],
+            wq: List[torch.Tensor],
+            x_scale: List[torch.Tensor],
+            w_scale: List[torch.Tensor],
+        ) -> torch.Tensor:
+            B = len(xq)
+            M = xq[0].shape[0]
+            N = wq[0].shape[0]
+            y = torch.empty((B, M, N), dtype=torch.bfloat16, device=xq[0].device)
+            for i in range(B):
+                y[i] = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    xq[i], wq[i], x_scale[i], w_scale[i]
+                )
+            return y
+
+        y_ref = torch.bmm(x, w.transpose(1, 2))
+        y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale)
+        torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)
+
 
 if __name__ == "__main__":
     unittest.main()