Add tma_persistent impl for FP8 rowwise gemm (pytorch#2742)
Summary:
Add a TMA persistent kernel implementation for the FP8 rowwise GEMM with `fp8_fast_accum=True`, based on the upstream Triton persistent matmul implementation (triton-lang/triton#4099).
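A minimal usage sketch of the new path (not part of the commit), mirroring the `row_gemm_bench_tma` benchmark added below; the import path is assumed from the file locations in this commit and the shapes are illustrative:

```python
import torch

# Import path assumed from the file layout in this commit.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

x = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)

# Row-wise FP8 quantization: returns FP8 tensors plus per-row scales.
x_fp8, x_scale = quantize_fp8_row(x)
w_fp8, w_scale = quantize_fp8_row(w)

# New in this commit: tma_persistent=True selects the TMA persistent kernel
# (used together with fp8_fast_accum=True, as in the benchmark).
out = matmul_fp8_row(
    x_fp8,
    w_fp8,
    x_scale,
    w_scale,
    dot_out_dtype=torch.float32,
    allow_tf32=True,
    fp8_fast_accum=True,
    tma_persistent=True,
)
```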

Pull Request resolved: pytorch#2742

Reviewed By: chenyang78, htyu

Differential Revision: D58656793

Pulled By: sijiac

fbshipit-source-id: 692091eb367cc2fd1ef821384bb5e49347f08929
sijiac authored and facebook-github-bot committed Jun 24, 2024
1 parent 9caef86 commit cdad003
Showing 2 changed files with 261 additions and 1 deletion.
34 changes: 33 additions & 1 deletion fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -24,7 +24,8 @@

@click.command()
@click.option("--cuda-graph", type=bool, default=True)
def bench(cuda_graph: bool) -> None:
@click.option("--rowwise-tma", is_flag=True, default=False)
def bench(cuda_graph: bool, rowwise_tma: bool) -> None:
"""Benchmark bf16 vs scale/cast + fp8."""

def _run_benchmark(
@@ -69,6 +70,7 @@ def _run_benchmark(
)

shapes = [
(8192, 8192, 512),
(8192, 8192, 8192),
(65536, 8192, 7168),
(65536, 3584, 8192),
@@ -94,6 +96,12 @@ def _run_benchmark(
tag="fp8 row gemm only | max_num_imprecise_acc=32",
)
_run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
if rowwise_tma:
_run_benchmark(
row_gemm_bench_tma,
shape=shape,
tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
)


def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
@@ -148,6 +156,30 @@ def run_gemm() -> Tensor:
return run_gemm


def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
# Benchmark only row-wise gemm with the TMA persistent kernel.
x_fp8: Tensor
w_fp8: Tensor
x_scale: Tensor
w_scale: Tensor
x_fp8, x_scale = quantize_fp8_row(x)
w_fp8, w_scale = quantize_fp8_row(w)

def run_gemm() -> Tensor:
return matmul_fp8_row(
x_fp8,
w_fp8,
x_scale,
w_scale,
dot_out_dtype=torch.float32,
allow_tf32=True,
fp8_fast_accum=True,
tma_persistent=True,
)

return run_gemm


def row_gemm_bench_no_fast_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
# Benchmark only row-wise gemm, caching scaling.
x_fp8: Tensor
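For context on what the row-wise variants benchmarked above compute: up to FP8 rounding and the output dtype, `matmul_fp8_row` is a per-row-scaled GEMM. Below is a rough eager-mode reference (not part of the commit), mirroring the `acc * (a_scale[:, None] * b_scale[None, :])` epilogue of the kernel added in the next file:

```python
import torch

def matmul_fp8_row_reference(
    x_fp8: torch.Tensor,    # [M, K] FP8 activations
    w_fp8: torch.Tensor,    # [N, K] FP8 weights
    x_scale: torch.Tensor,  # [M] per-row reciprocal scales
    w_scale: torch.Tensor,  # [N] per-row reciprocal scales
) -> torch.Tensor:
    # Accumulate in fp32, then undo the per-row quantization by multiplying
    # with the outer product of the row scales, as the Triton kernel does.
    acc = torch.matmul(x_fp8.to(torch.float32), w_fp8.to(torch.float32).T)
    return acc * (x_scale[:, None] * w_scale[None, :])
```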
228 changes: 228 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -646,6 +646,134 @@ def _kernel_matmul_fp8_row_imprecise_acc(
tl.atomic_add(C, acc, mask=mask)


@triton.jit
def _kernel_matmul_fp8_row_tma_persistent(
A_ptr,
B_ptr,
C_ptr,
M,
N,
K,
A_scale,
B_scale,
stride_am,
stride_ak,
stride_bn,
stride_bk,
stride_cm,
stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
AB_DTYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
) -> None:
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
Args:
A (TensorWrapper): [M, K] input tensor.
B (TensorWrapper): [N, K] input tensor.
C (TensorWrapper): [M, N] output tensor.
M (int): M dimension of input tensor.
N (int): N dimension of input tensor.
K (int): K dimension of input tensor.
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
stride_am (int): Stride of M dimension of A.
stride_ak (int): Stride of K dimension of A.
stride_bn (int): Stride of N dimension of B.
stride_bk (int): Stride of K dimension of B.
stride_cm (int): Stride of M dimension of C.
stride_cn (int): Stride of N dimension of C.
dot_out_dtype (torch.dtype): Output type of tensor core.
allow_tf32 (bool): Whether to use TF32 for tensor core.
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
BLOCK_M (int): Block size for M dimension.
BLOCK_N (int): Block size for N dimension.
BLOCK_K (int): Block size for K dimension.
GROUP_M (int): Number of groups for M dimension swizzle.
NUM_SMS (int): Number of persistent programs to launch (one per SM).
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
"""
# Matrix multiplication.
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n

tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1

tile_id = start_pid - NUM_SMS
ki = -1

pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0

num_pid_in_group = GROUP_M * num_pid_n

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

dtype_fp8 = tl.float8e4nv
scale_dtype = tl.float32

for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m

offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_am = tl.multiple_of(offs_am, BLOCK_M)
offs_bn = tl.multiple_of(offs_bn, BLOCK_N)

offs_k = ki * BLOCK_K

a = tl._experimental_descriptor_load(
A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
)
b = tl._experimental_descriptor_load(
B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
)
acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N

# Invert scaling.
a_scale = tl._experimental_descriptor_load(
A_scale, [rm], [BLOCK_M], scale_dtype
)
b_scale = tl._experimental_descriptor_load(
B_scale, [rn], [BLOCK_N], scale_dtype
)
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale
acc = acc.to(C_ptr.dtype.element_ty)

tl._experimental_descriptor_store(C_ptr, acc, [rm, rn])
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)


def matmul_fp8_row(
a: torch.Tensor,
b: torch.Tensor,
@@ -655,6 +783,7 @@ def matmul_fp8_row(
allow_tf32: bool = True,
fp8_fast_accum: bool = True,
imprecise_acc: bool = False,
tma_persistent: bool = False,
) -> torch.Tensor:
"""
Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -667,6 +796,7 @@ def matmul_fp8_row(
dot_out_dtype (torch.dtype): Output type of tensor core.
allow_tf32 (bool): Whether to use TF32 for tensor core.
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
tma_persistent (bool): Whether to use TMA persistent kernel impl.
Returns:
torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
@@ -705,6 +835,104 @@ def persistent_grid(META):
),
)

if tma_persistent:
# used by TMA persistent kernel
TMA_SIZE = 128
import numpy as np

# autotune doesn't work with TMA
# https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312

BLOCK_M = 128
BLOCK_N = 256
BLOCK_K = 128
GROUP_M = 8
num_stages = 3
num_warps = 8

desc_a = np.empty(TMA_SIZE, dtype=np.int8)
desc_b = np.empty(TMA_SIZE, dtype=np.int8)
desc_c = np.empty(TMA_SIZE, dtype=np.int8)
desc_a_scale = np.empty(TMA_SIZE, dtype=np.int8)
desc_b_scale = np.empty(TMA_SIZE, dtype=np.int8)

triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
a_tl.data_ptr(),
M,
K,
BLOCK_M,
BLOCK_K,
a_tl.element_size(),
desc_a,
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
b_tl.data_ptr(),
N,
K,
BLOCK_N,
BLOCK_K,
b_tl.element_size(),
desc_b,
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
c.data_ptr(),
M,
N,
BLOCK_M,
BLOCK_N,
c.element_size(),
desc_c,
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
a_scale.data_ptr(),
M,
BLOCK_M,
a_scale.element_size(),
desc_a_scale,
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
b_scale.data_ptr(),
N,
BLOCK_N,
b_scale.element_size(),
desc_b_scale,
)
desc_a = torch.tensor(desc_a, device="cuda")
desc_b = torch.tensor(desc_b, device="cuda")
desc_c = torch.tensor(desc_c, device="cuda")
desc_a_scale = torch.tensor(desc_a_scale, device="cuda")
desc_b_scale = torch.tensor(desc_b_scale, device="cuda")

# pyre-ignore[28]:
_kernel_matmul_fp8_row_tma_persistent[persistent_grid](
desc_a,
desc_b,
desc_c,
M,
N,
K,
desc_a_scale,
desc_b_scale,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
dot_out_dtype=dot_out_dtype_triton,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
GROUP_M=GROUP_M,
AB_DTYPE=False,
NUM_SMS=NUM_SMS,
num_stages=num_stages,
num_warps=num_warps,
)
return c

if imprecise_acc:
_kernel_matmul_fp8_row_imprecise_acc[grid](
a_tl,
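As a reading aid for the persistent scheduling in `_kernel_matmul_fp8_row_tma_persistent` above, here is an illustrative host-side Python rendering (not part of the commit) of how output tiles are assigned to the NUM_SMS persistent programs; each program then runs `k_tiles` inner iterations per assigned tile. The example SM count is an assumption (e.g. an H100 SXM exposes 132 SMs):

```python
def persistent_tile_schedule(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, NUM_SMS):
    """Host-side sketch of the tile schedule used by the persistent kernel."""

    def cdiv(a, b):
        return -(-a // b)

    num_pid_m, num_pid_n = cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)
    k_tiles = cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n
    num_pid_in_group = GROUP_M * num_pid_n

    schedule = {pid: [] for pid in range(NUM_SMS)}
    for start_pid in range(NUM_SMS):
        # Each persistent program walks tile ids start_pid, start_pid + NUM_SMS, ...
        for tile_id in range(start_pid, num_tiles, NUM_SMS):
            # Grouped (swizzled) mapping from a linear tile id to (pid_m, pid_n),
            # matching the index math in the kernel's `if ki == 0:` branch.
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m
            schedule[start_pid].append((pid_m, pid_n, k_tiles))
    return schedule

# Example with the block sizes hard-coded in matmul_fp8_row's TMA path.
sched = persistent_tile_schedule(
    M=8192, N=8192, K=8192,
    BLOCK_M=128, BLOCK_N=256, BLOCK_K=128,
    GROUP_M=8, NUM_SMS=132,
)
```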
