From 79c45cba551c80d564167333760f0c3ff008deb6 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Thu, 26 Sep 2024 09:14:16 -0700
Subject: [PATCH] Enable E2E MoE INT4 BMM with loopover (#3170)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3170

X-link: https://github.com/facebookresearch/FBGEMM/pull/264

- Enable E2E MoE INT4 BMM with loopover
- Support 3d tensors in quantize_marlin_int4
- Add unit tests

Reviewed By: jianyuh

Differential Revision: D63303484
---
 .../gen_ai/test/quantize/quantize_test.py | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 94c6eee00..fdd038e2c 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -24,6 +24,16 @@
 
 from hypothesis import given, settings, strategies as st
 
+# Marlin is currently only supported internally at Meta.
+try:
+    if not torch.version.hip:
+        from marlin.quantize import marlin_quantize
+
+        torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
+        MARLIN_ENABLED = True
+except ImportError:
+    MARLIN_ENABLED = False
+
 # Supported FP8 format is different on NV and AMD.
 if torch.version.hip is not None:
     fp8_e4m3: torch.dtype = torch.float8_e4m3fnuz
@@ -648,6 +658,53 @@ def fp8_loopover_bmm(
         y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale)
         torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)
 
+    @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet supported.")
+    @settings(deadline=None)
+    @given(
+        B=st.sampled_from([1, 4]),
+        M=st.sampled_from([2048, 4096]),
+        N=st.sampled_from([256, 512]),
+        K=st.sampled_from([256, 512]),
+    )
+    def test_int4_batched_gemm(
+        self,
+        B: int,
+        M: int,
+        N: int,
+        K: int,
+    ) -> None:
+        if not MARLIN_ENABLED:
+            return
+        x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1
+        w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01
+
+        wq = []
+        w_scale = []
+        group_size = 128
+        for i in range(B):
+            _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size)
+            wq.append(wq_)
+            w_scale.append(w_scale_)
+        wq = torch.stack(wq)
+        w_scale = torch.stack(w_scale)
+
+        def int4_loopover_bmm(
+            x: torch.Tensor,
+            wq: torch.Tensor,
+            w_scale: torch.Tensor,
+        ) -> torch.Tensor:
+            B = x.shape[0]
+            M = x.shape[1]
+            N = w_scale.shape[2]
+            y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device)
+            for i in range(B):
+                y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i])
+            return y
+
+        y_ref = torch.bmm(x, w.transpose(1, 2))
+        y_int4 = int4_loopover_bmm(x, wq, w_scale)
+        torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2)
+
 
 if __name__ == "__main__":
     unittest.main()
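
For readers without access to the Meta-internal Marlin package, the standalone sketch below reproduces the loopover INT4 BMM pattern that the new test exercises, outside the unittest/hypothesis harness. marlin_quantize, torch.ops.marlin.marlin_gemm, and the group_size of 128 come from the patch above; the defensive import guard, the quantize_batched_weights helper name, and the example shapes in the __main__ block are illustrative assumptions, not part of the patch.

# Sketch only: assumes the Meta-internal marlin package; falls back to a no-op
# message when it is unavailable so the script still runs elsewhere.
import torch

MARLIN_ENABLED = False  # assume unavailable unless the internal package loads
try:
    if not torch.version.hip:  # Marlin kernels target CUDA, not ROCm
        from marlin.quantize import marlin_quantize

        torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
        MARLIN_ENABLED = True
except ImportError:
    pass


def quantize_batched_weights(w: torch.Tensor, group_size: int = 128):
    """Quantize each (N, K) expert weight in a (B, N, K) batch to Marlin INT4.

    Returns stacked packed weights and per-group scales, mirroring the test.
    """
    wq, w_scale = [], []
    for i in range(w.shape[0]):
        # marlin_quantize expects a (K, N) weight, hence the transpose.
        _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size)
        wq.append(wq_)
        w_scale.append(w_scale_)
    return torch.stack(wq), torch.stack(w_scale)


def int4_loopover_bmm(
    x: torch.Tensor, wq: torch.Tensor, w_scale: torch.Tensor
) -> torch.Tensor:
    """Batched INT4 GEMM: loop over the batch, one Marlin GEMM per matrix."""
    B, M = x.shape[0], x.shape[1]
    N = w_scale.shape[2]
    y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x.device)
    for i in range(B):
        y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i])
    return y


if __name__ == "__main__":
    if not MARLIN_ENABLED:
        print("Marlin is not available in this environment; skipping the demo.")
    else:
        # Example shapes drawn from the test's hypothesis search space.
        B, M, N, K = 4, 2048, 256, 256
        x = torch.rand(B, M, K, dtype=torch.bfloat16, device="cuda") * 0.1
        w = torch.rand(B, N, K, dtype=torch.bfloat16, device="cuda") * 0.01
        wq, w_scale = quantize_batched_weights(w)
        y_int4 = int4_loopover_bmm(x, wq, w_scale)
        y_ref = torch.bmm(x, w.transpose(1, 2))
        torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2)
        print("loopover INT4 BMM matches the bf16 reference within tolerance")

The loopover form keeps each per-expert GEMM on the existing 2D Marlin kernel, which is what lets the MoE path run end to end before a fused batched INT4 kernel exists.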