Add Cutlass Blockwise Kernel to Quantize Benchmark (pytorch#2800)
Summary:
Pull Request resolved: pytorch#2800

X-link: facebookresearch/FBGEMM#4

This diff adds the new cutlass blockwise kernel from D57965065 to quantize_bench.py. I also set default block sizes so that the API is conformant with the triton quantize op that is often paired with it.

Reviewed By: choudharydhruv

Differential Revision: D59249763

fbshipit-source-id: 7d845745549111115c3f84b9deaad389a57bd7eb
jwfromm authored and facebook-github-bot committed Jul 2, 2024
1 parent 162a966 commit ece4c24
Showing 4 changed files with 56 additions and 30 deletions.
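
To make the change concrete, here is a minimal sketch of the blockwise path the benchmark now exercises, mirroring the quantize/compute split in the diff below. The import paths are assumptions (quantize_fp8_block is taken from FBGEMM's triton fp8_gemm module, and importing the gen_ai extension is assumed to register torch.ops.fbgemm.f8f8bf16_blockwise); treat it as an illustration, not part of the commit.

```python
import torch

# Assumed import locations; adjust to your FBGEMM build/checkout.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to register torch.ops.fbgemm.*
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

M, N, K = 2048, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Blockwise-quantize both operands with 128x128 scaling blocks, as in the diff.
xq, x_scale = quantize_fp8_block(x, 128, 128)
wq, w_scale = quantize_fp8_block(w, 128, 128)

# Cutlass blockwise GEMM; block sizes are passed explicitly so they match the scales.
y = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, x_scale, w_scale, 128, 128, 128)
```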
49 changes: 40 additions & 9 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -332,23 +332,54 @@ def cuda(self) -> bool:


@register_quantize_op
class FP8BlockwiseGemm(QuantizeOpBase):
class FP8TritonBlockwiseGemm(QuantizeOpBase):
"""
FP8 matmul with block scaling.
"""

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = quantize_fp8_block(x)
wq, w_scale = quantize_fp8_block(w)
xq, x_scale = quantize_fp8_block(x, 128, 128)
wq, w_scale = quantize_fp8_block(w, 128, 128)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
# Dispatch to appropriate function based on device.
if torch.version.cuda:
return matmul_fp8_block(xq, wq, x_scale, w_scale)
else:
return torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, x_scale, w_scale)
return matmul_fp8_block(xq, wq, x_scale, w_scale, 128, 128, 128)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale)

@property
def name(self) -> str:
return "triton_blockwise"

@property
def hip(self) -> bool:
# Currently has some issues.
return False

@property
def cuda(self) -> bool:
return True


@register_quantize_op
class FP8CutlassBlockwiseGemm(QuantizeOpBase):
"""
FP8 matmul with block scaling.
"""

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = quantize_fp8_block(x, 128, 128)
wq, w_scale = quantize_fp8_block(w, 128, 128)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
return torch.ops.fbgemm.f8f8bf16_blockwise(
xq, wq, x_scale, w_scale, 128, 128, 128
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
@@ -357,7 +357,7 @@ def quantize_and_compute(self, x, w):
@property
def name(self) -> str:
if torch.version.cuda:
return "triton_blockwise"
return "cutlass_blockwise"
else:
return "ck_blockwise"

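Every class registered with register_quantize_op exposes the same quantize / compute / quantize_and_compute interface, which is what lets quantize_bench.py time quantization and the GEMM separately or fused. A rough sketch of driving the new op directly (assuming a source checkout where the bench module is importable; quantize_bench.py normally instantiates these classes itself):

```python
import torch

# Assumed importable from an FBGEMM source checkout.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP8CutlassBlockwiseGemm

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

op = FP8CutlassBlockwiseGemm()
xq, wq, x_scale, w_scale = op.quantize(x, w)  # quantize once up front
y = op.compute(xq, wq, x_scale, w_scale)      # time the GEMM in isolation
y_fused = op.quantize_and_compute(x, w)       # or time quantize + GEMM together
print(op.name)  # "cutlass_blockwise" on CUDA, "ck_blockwise" otherwise
```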
@@ -76,11 +76,6 @@ at::Tensor f8f8bf16_blockwise_impl(
int StrideB = K;
int StrideE = N;

// For now hardcode block size.
TORCH_CHECK(block_m == 256);
TORCH_CHECK(block_n == 256);
TORCH_CHECK(block_k == 256);

int ScaleStrideAM = K / block_k;
int ScaleStrideAK = 1;
int ScaleStrideBN = K / block_k;
@@ -243,9 +238,9 @@ at::Tensor f8f8bf16_blockwise(
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k) {
int64_t block_m = 256,
int64_t block_n = 256,
int64_t block_k = 256) {
// Check that input datatypes are valid.
TORCH_CHECK(
(XQ.dtype() == at::kFloat8_e4m3fnuz) &&
@@ -1666,9 +1666,9 @@ at::Tensor f8f8bf16_blockwise(
at::Tensor WQ, // FP8
at::Tensor x_scale, // FP32
at::Tensor w_scale, // FP32
int64_t block_m,
int64_t block_n,
int64_t block_k) {
int64_t block_m = 256,
int64_t block_n = 256,
int64_t block_k = 256) {
// Check datatypes.
TORCH_CHECK(
x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
@@ -2443,9 +2443,9 @@ at::Tensor f8f8bf16_blockwise(
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k) {
int64_t block_m = 256,
int64_t block_n = 256,
int64_t block_k = 256) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
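The two kernel-side hunks above give block_m, block_n, and block_k defaults of 256 and drop the old hard-coded 256-only checks, so callers can omit the block sizes as long as the scales were produced with a matching block size; per the summary, the triton quantize op's defaults are meant to line up. A sketch under that assumption (import paths as in the earlier sketch):

```python
import torch

import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to register the ops
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block  # assumed path

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Rely on the quantizer's own defaults (assumed to match the new 256 GEMM defaults)...
xq, x_scale = quantize_fp8_block(x)
wq, w_scale = quantize_fp8_block(w)

# ...so the block sizes can now be left off entirely.
y = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, x_scale, w_scale)
```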
14 changes: 7 additions & 7 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -56,9 +56,9 @@ at::Tensor f8f8bf16_blockwise(
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
int64_t block_m,
int64_t block_n,
int64_t block_k);
int64_t block_m = 256,
int64_t block_n = 256,
int64_t block_k = 256);
at::Tensor f8f8bf16_cublas(
at::Tensor A,
at::Tensor B,
@@ -138,7 +138,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic);
#endif
m.def(
"f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m, int block_n, int block_k) -> Tensor");
"f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=256, int block_n=256, int block_k=256) -> Tensor");
m.def(
"f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor");
m.def(
@@ -226,9 +226,9 @@ at::Tensor f8f8bf16_blockwise_meta(
at::Tensor WQ, // FP8
at::Tensor /* x_scale */,
at::Tensor /* w_scale */,
int64_t /* block_m */,
int64_t /* block_n */,
int64_t /* block_k */) {
int64_t /* block_m = 256*/,
int64_t /* block_n = 256*/,
int64_t /* block_k = 256*/) {
int M = XQ.size(0);
int N = WQ.size(0);
auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
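The meta kernel above also pins down the op's shape contract: given XQ of shape (M, K) and WQ of shape (N, K), the output is an (M, N) bf16 tensor, and the scales and block sizes are not inspected. A hedged way to check that contract without a GPU, assuming the gen_ai extension registers this *_meta function for the Meta dispatch key as is usual:

```python
import torch

import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to register the ops

M, N, K = 512, 384, 1024
xq = torch.empty(M, K, device="meta", dtype=torch.float8_e4m3fn)
wq = torch.empty(N, K, device="meta", dtype=torch.float8_e4m3fn)
# The meta kernel only reads the shapes of XQ and WQ, so placeholder scales are
# enough here; a real call needs scales laid out for the chosen block size.
x_scale = torch.empty(0, device="meta")
w_scale = torch.empty(0, device="meta")

y = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, x_scale, w_scale)
assert y.shape == (M, N) and y.dtype == torch.bfloat16
```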
