From 2d3e9ddf4fa0436b39922a8be87ce8a096c33c37 Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Thu, 19 Sep 2024 16:35:45 -0700 Subject: [PATCH] MoE FP8 BMM with loopover (#3147) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3147 X-link: https://github.com/facebookresearch/FBGEMM/pull/239 Enable MoE FP8 rowwise BMM with loopover and benchmarks - MoE FP8 rowwise BMM using loopover with quant ops achieves **1.8x speedup over BF16 BMM with max_autotune** - BF16 BMM with torch.compile max_autotune (enabled in D62278399) can bring up to 2x speedup over torch.bmm (cublas) - Replacing Triton FP8 quantization loopover with 3d brings up to 30x speedup for the quantization op ([data sheet](https://l.facebook.com/l.php?u=https%3A%2F%2Fdocs.google.com%2Fspreadsheets%2Fd%2F1n-nUuus-XXmKykBvvXs3u0AfTJMoChVHQdG_u3pSPrk%2Fedit%3Fusp%3Dsharing&h=AT2OzkEQcXVJth6Q2-2VfhCNugmuPVWHXjrA2nJor8c2O54xyGQu-9kDB_dE9X2dVC6i8QY97QJp2Ojlb3cAvkmxNvMiajUs-jZ6oZl4gMmPKPNkOgWScdYNtP7geoIy1aTYr21rAszjznNFEYVgf9dPonw)) - For MoE in inference/training with expert parallelism, number of local experts are normally 2, and 4 is the max, such that performance of loopover MoE FP8 BMM is acceptable - Working on enabling customized MoE FP8 BMM kernel which could further improve performance - More results are in the [data sheet](https://docs.google.com/spreadsheets/d/1S-XqBh10G8sZqw97AJq37uy-JV6fxlZIVb2C41eeX38/edit?usp=sharing) {F1873899701} Reviewed By: jianyuh Differential Revision: D62889315 fbshipit-source-id: de87d9757f314974af1b12b9b97a47eeca6e2acc --- .../gen_ai/test/quantize/quantize_test.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 1ab839af1..94c6eee00 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -9,7 +9,7 @@ import unittest -from typing import Tuple +from typing import List, Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 @@ -604,6 +604,50 @@ def test_quantize_fp8_per_tensor_with_ub( zq_ref = (x @ w.T).to(torch.bfloat16) torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @settings(deadline=None) + @given( + B=st.sampled_from([1, 4]), + M=st.sampled_from([2048, 4096]), + N=st.sampled_from([128, 256]), + K=st.sampled_from([256, 512]), + ) + def test_fp8_batched_gemm( + self, + B: int, + M: int, + N: int, + K: int, + ) -> None: + x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1 + w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01 + + xq, x_scale = quantize_fp8_row(x) + x_scale = x_scale.view(B, -1) + assert x_scale.shape == (B, M) + wq, w_scale = quantize_fp8_row(w) + w_scale = w_scale.view(B, -1) + assert w_scale.shape == (B, N) + + def fp8_loopover_bmm( + xq: List[torch.Tensor], + wq: List[torch.Tensor], + x_scale: List[torch.Tensor], + w_scale: List[torch.Tensor], + ) -> torch.Tensor: + B = len(xq) + M = xq[0].shape[0] + N = wq[0].shape[0] + y = torch.empty((B, M, N), dtype=torch.bfloat16, device=xq[0].device) + for i in range(B): + y[i] = torch.ops.fbgemm.f8f8bf16_rowwise( + xq[i], wq[i], x_scale[i], w_scale[i] + ) + return y + + y_ref = torch.bmm(x, w.transpose(1, 2)) + y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) + if __name__ == "__main__": unittest.main()