MoE FP8 BMM with loopover (pytorch#3147)
Summary:
Pull Request resolved: pytorch#3147

X-link: facebookresearch/FBGEMM#239

Enable MoE FP8 rowwise BMM with loopover, and add benchmarks:

- MoE FP8 rowwise BMM using loopover with quant ops achieves a **1.8x speedup over BF16 BMM with max_autotune**
- BF16 BMM with torch.compile max_autotune (enabled in D62278399) brings up to a 2x speedup over torch.bmm (cuBLAS); a minimal baseline sketch follows this list
- Replacing the loopover Triton FP8 quantization with a single 3D quantization brings up to a 30x speedup for the quantization op ([data sheet](https://l.facebook.com/l.php?u=https%3A%2F%2Fdocs.google.com%2Fspreadsheets%2Fd%2F1n-nUuus-XXmKykBvvXs3u0AfTJMoChVHQdG_u3pSPrk%2Fedit%3Fusp%3Dsharing&h=AT2OzkEQcXVJth6Q2-2VfhCNugmuPVWHXjrA2nJor8c2O54xyGQu-9kDB_dE9X2dVC6i8QY97QJp2Ojlb3cAvkmxNvMiajUs-jZ6oZl4gMmPKPNkOgWScdYNtP7geoIy1aTYr21rAszjznNFEYVgf9dPonw))
- For MoE inference/training with expert parallelism, the number of local experts is normally 2 (and 4 at most), so the performance of the loopover MoE FP8 BMM is acceptable
- A customized MoE FP8 BMM kernel is in progress and could further improve performance
- More results are in the [data sheet](https://docs.google.com/spreadsheets/d/1S-XqBh10G8sZqw97AJq37uy-JV6fxlZIVb2C41eeX38/edit?usp=sharing)
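
For reference, here is a minimal sketch (not part of this diff) of the BF16 baseline mentioned above: the same batched matmul compiled with torch.compile in max-autotune mode, compared against plain torch.bmm. The shapes and the timing helper are illustrative assumptions, not the benchmark harness behind the numbers above.

```python
# Hedged sketch of the BF16 BMM baseline: torch.compile(mode="max-autotune")
# lets Inductor autotune the batched matmul instead of dispatching straight
# to cuBLAS via torch.bmm. Shapes are illustrative only.
import torch

B, M, N, K = 4, 2048, 256, 512
x = torch.rand(B, M, K, dtype=torch.bfloat16, device="cuda")
w = torch.rand(B, N, K, dtype=torch.bfloat16, device="cuda")

def bf16_bmm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Reference path: plain cuBLAS-backed batched matmul.
    return torch.bmm(x, w.transpose(1, 2))

# Autotuned path: Inductor searches kernel configs for the fastest variant.
bf16_bmm_autotuned = torch.compile(bf16_bmm, mode="max-autotune")

def bench(fn, *args, iters: int = 20) -> float:
    # Simple CUDA-event timing helper (milliseconds per call).
    fn(*args)  # warm-up / trigger compilation
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("torch.bmm (cuBLAS):", bench(bf16_bmm, x, w), "ms")
print("max-autotune:      ", bench(bf16_bmm_autotuned, x, w), "ms")
```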


Reviewed By: jianyuh

Differential Revision: D62889315

fbshipit-source-id: de87d9757f314974af1b12b9b97a47eeca6e2acc
jiawenliu64 authored and facebook-github-bot committed Sep 19, 2024
1 parent 55f4c78 commit 2d3e9dd
Showing 1 changed file with 45 additions and 1 deletion.
fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

@@ -9,7 +9,7 @@

import unittest

from typing import Tuple
from typing import List, Tuple

import fbgemm_gpu.experimental.gen_ai # noqa: F401

@@ -604,6 +604,50 @@ def test_quantize_fp8_per_tensor_with_ub(
        zq_ref = (x @ w.T).to(torch.bfloat16)
        torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)

    @settings(deadline=None)
    @given(
        B=st.sampled_from([1, 4]),
        M=st.sampled_from([2048, 4096]),
        N=st.sampled_from([128, 256]),
        K=st.sampled_from([256, 512]),
    )
    def test_fp8_batched_gemm(
        self,
        B: int,
        M: int,
        N: int,
        K: int,
    ) -> None:
        x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1
        w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01

        # Rowwise-quantize the full 3D tensors; reshape the per-row scales to (B, rows).
        xq, x_scale = quantize_fp8_row(x)
        x_scale = x_scale.view(B, -1)
        assert x_scale.shape == (B, M)
        wq, w_scale = quantize_fp8_row(w)
        w_scale = w_scale.view(B, -1)
        assert w_scale.shape == (B, N)

        def fp8_loopover_bmm(
            xq: List[torch.Tensor],
            wq: List[torch.Tensor],
            x_scale: List[torch.Tensor],
            w_scale: List[torch.Tensor],
        ) -> torch.Tensor:
            # Loop over the batch dimension, issuing one rowwise FP8 GEMM per batch.
            B = len(xq)
            M = xq[0].shape[0]
            N = wq[0].shape[0]
            y = torch.empty((B, M, N), dtype=torch.bfloat16, device=xq[0].device)
            for i in range(B):
                y[i] = torch.ops.fbgemm.f8f8bf16_rowwise(
                    xq[i], wq[i], x_scale[i], w_scale[i]
                )
            return y

        # Compare the loopover FP8 result against a BF16 batched matmul reference.
        y_ref = torch.bmm(x, w.transpose(1, 2))
        y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale)
        torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)


if __name__ == "__main__":
    unittest.main()
