Enable E2E MoE INT4 BMM with loopover (#3170)
Summary:
Pull Request resolved: #3170

X-link: facebookresearch/FBGEMM#264

- Enable E2E MoE INT4 BMM with loopover
- Support 3d tensors in quantize_marlin_int4 (see the sketch after this list)
- Add unit tests
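
The "loopover" approach wraps the existing 2D Marlin INT4 kernel in a loop over the leading batch (expert) dimension. The quantize_marlin_int4 change itself is not part of the diff shown below, so the following is a rough illustration only: the helper name quantize_int4_3d is hypothetical, and the marlin_quantize signature is assumed from the unit test in this commit.

    # Hypothetical sketch, not the actual quantize_marlin_int4 implementation:
    # quantize a (B, N, K) weight tensor by looping over the leading dimension
    # with the 2D marlin_quantize helper and stacking the per-slice results.
    import torch
    from marlin.quantize import marlin_quantize  # Meta-internal dependency

    def quantize_int4_3d(w: torch.Tensor, group_size: int = 128):
        wq, scales = [], []
        for i in range(w.shape[0]):
            # marlin_quantize consumes a 2D weight, hence the per-slice transpose.
            _, wq_i, scale_i = marlin_quantize(w[i].t().contiguous(), group_size)
            wq.append(wq_i)
            scales.append(scale_i)
        return torch.stack(wq), torch.stack(scales)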

Reviewed By: jianyuh

Differential Revision: D63303484
jiawenliu64 authored and facebook-github-bot committed Sep 26, 2024
1 parent e90603b commit 79c45cb
Showing 1 changed file with 57 additions and 0 deletions.
fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -24,6 +24,16 @@

from hypothesis import given, settings, strategies as st

# Marlin is currently only supported internally at Meta.
try:
    if not torch.version.hip:
        from marlin.quantize import marlin_quantize

        torch.ops.load_library("//ai_codesign/gen_ai/marlin:marlin_ops")
        MARLIN_ENABLED = True
except ImportError:
    MARLIN_ENABLED = False

# Supported FP8 format is different on NV and AMD.
if torch.version.hip is not None:
    fp8_e4m3: torch.dtype = torch.float8_e4m3fnuz
@@ -648,6 +658,53 @@ def fp8_loopover_bmm(
        y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale)
        torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet supported.")
    @settings(deadline=None)
    @given(
        B=st.sampled_from([1, 4]),
        M=st.sampled_from([2048, 4096]),
        N=st.sampled_from([256, 512]),
        K=st.sampled_from([256, 512]),
    )
    def test_int4_batched_gemm(
        self,
        B: int,
        M: int,
        N: int,
        K: int,
    ) -> None:
        if not MARLIN_ENABLED:
            return
        x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1
        w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01

        # Quantize each batch slice of the weights to INT4 with Marlin, then
        # stack the per-slice outputs back into batched tensors.
        wq = []
        w_scale = []
        group_size = 128
        for i in range(B):
            _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size)
            wq.append(wq_)
            w_scale.append(w_scale_)
        wq = torch.stack(wq)
        w_scale = torch.stack(w_scale)

        def int4_loopover_bmm(
            x: torch.Tensor,
            wq: torch.Tensor,
            w_scale: torch.Tensor,
        ) -> torch.Tensor:
            # Loop over the batch dimension, issuing one 2D Marlin INT4 GEMM
            # per slice; this is the "loopover" BMM being enabled end to end.
            B = x.shape[0]
            M = x.shape[1]
            N = w_scale.shape[2]
            y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device)
            for i in range(B):
                y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i])
            return y

        # Compare against a bf16 torch.bmm reference within loose tolerances.
        y_ref = torch.bmm(x, w.transpose(1, 2))
        y_int4 = int4_loopover_bmm(x, wq, w_scale)
        torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2)


if __name__ == "__main__":
    unittest.main()
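
Usage note (standard unittest invocation, not part of the commit): since the module calls unittest.main(), the new case can be run directly, e.g. python fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py -k test_int4_batched_gemm.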
