diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
index bd5123c7c..77fa04b3a 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -24,7 +24,8 @@
 
 @click.command()
 @click.option("--cuda-graph", type=bool, default=True)
-def bench(cuda_graph: bool) -> None:
+@click.option("--rowwise-tma", is_flag=True, default=False)
+def bench(cuda_graph: bool, rowwise_tma: bool) -> None:
     """Benchmark bf16 vs scale/cast + fp8."""
 
     def _run_benchmark(
@@ -69,6 +70,7 @@ def _run_benchmark(
     )
 
     shapes = [
+        (8192, 8192, 512),
         (8192, 8192, 8192),
         (65536, 8192, 7168),
         (65536, 3584, 8192),
@@ -94,6 +96,12 @@ def _run_benchmark(
             tag="fp8 row gemm only | max_num_imprecise_acc=32",
         )
         _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
+        if rowwise_tma:
+            _run_benchmark(
+                row_gemm_bench_tma,
+                shape=shape,
+                tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
+            )
 
 
 def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
@@ -148,6 +156,30 @@ def run_gemm() -> Tensor:
     return run_gemm
 
 
+def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
+    # Benchmark only row-wise gemm with the TMA persistent kernel.
+    x_fp8: Tensor
+    w_fp8: Tensor
+    x_scale: Tensor
+    w_scale: Tensor
+    x_fp8, x_scale = quantize_fp8_row(x)
+    w_fp8, w_scale = quantize_fp8_row(w)
+
+    def run_gemm() -> Tensor:
+        return matmul_fp8_row(
+            x_fp8,
+            w_fp8,
+            x_scale,
+            w_scale,
+            dot_out_dtype=torch.float32,
+            allow_tf32=True,
+            fp8_fast_accum=True,
+            tma_persistent=True,
+        )
+
+    return run_gemm
+
+
 def row_gemm_bench_no_fast_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
     # Benchmark only row-wise gemm, caching scaling.
     x_fp8: Tensor
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 4f16fcfab..68a65bbd4 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -646,6 +646,134 @@ def _kernel_matmul_fp8_row_imprecise_acc(
     tl.atomic_add(C, acc, mask=mask)
 
 
+@triton.jit
+def _kernel_matmul_fp8_row_tma_persistent(
+    A_ptr,
+    B_ptr,
+    C_ptr,
+    M,
+    N,
+    K,
+    A_scale,
+    B_scale,
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    stride_cm,
+    stride_cn,
+    dot_out_dtype: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    fp8_fast_accum: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    AB_DTYPE: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+) -> None:
+    """Matmul kernel of [M, K] @ [N, K] with row-wise scales.
+
+    Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
+
+    Args:
+        A_ptr: TMA descriptor for the [M, K] fp8 input tensor A.
+        B_ptr: TMA descriptor for the [N, K] fp8 input tensor B.
+        C_ptr: TMA descriptor for the [M, N] output tensor C.
+        M (int): M dimension of input tensor.
+        N (int): N dimension of input tensor.
+        K (int): K dimension of input tensor.
+        A_scale: TMA descriptor for the [M] reciprocal scale tensor per row. A * A_scale = original A.
+        B_scale: TMA descriptor for the [N] reciprocal scale tensor per row. B * B_scale = original B.
+        stride_am (int): Stride of M dimension of A.
+        stride_ak (int): Stride of K dimension of A.
+        stride_bn (int): Stride of N dimension of B.
+        stride_bk (int): Stride of K dimension of B.
+        stride_cm (int): Stride of M dimension of C.
+        stride_cn (int): Stride of N dimension of C.
+        dot_out_dtype (torch.dtype): Output type of tensor core.
+        allow_tf32 (bool): Whether to use TF32 for tensor core.
+        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+        BLOCK_M (int): Block size for M dimension.
+        BLOCK_N (int): Block size for N dimension.
+        BLOCK_K (int): Block size for K dimension.
+        GROUP_M (int): Number of groups for M dimension swizzle.
+        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+        NUM_SMS (int): Number of SMs; one persistent program is launched per SM.
+    """
+    # Matrix multiplication.
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    k_tiles = tl.cdiv(K, BLOCK_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    pid_m = 0
+    pid_n = 0
+    offs_am = 0
+    offs_bn = 0
+
+    num_pid_in_group = GROUP_M * num_pid_n
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+
+    dtype_fp8 = tl.float8e4nv
+    scale_dtype = tl.float32
+
+    for _ in range(0, k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            group_id = tile_id // num_pid_in_group
+            first_pid_m = group_id * GROUP_M
+            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+            pid_m = first_pid_m + (tile_id % group_size_m)
+            pid_n = (tile_id % num_pid_in_group) // group_size_m
+
+            offs_am = pid_m * BLOCK_M
+            offs_bn = pid_n * BLOCK_N
+            offs_am = tl.multiple_of(offs_am, BLOCK_M)
+            offs_bn = tl.multiple_of(offs_bn, BLOCK_N)
+
+        offs_k = ki * BLOCK_K
+
+        a = tl._experimental_descriptor_load(
+            A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
+        )
+        b = tl._experimental_descriptor_load(
+            B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
+        )
+        acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+
+        if ki == k_tiles - 1:
+            # Rematerialize rm and rn to save registers.
+            rm = pid_m * BLOCK_M
+            rn = pid_n * BLOCK_N
+
+            # Invert scaling.
+            a_scale = tl._experimental_descriptor_load(
+                A_scale, [rm], [BLOCK_M], scale_dtype
+            )
+            b_scale = tl._experimental_descriptor_load(
+                B_scale, [rn], [BLOCK_N], scale_dtype
+            )
+            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
+            scale = a_scale[:, None] * b_scale[None, :]
+            acc *= scale
+            acc = acc.to(C_ptr.dtype.element_ty)
+
+            tl._experimental_descriptor_store(C_ptr, acc, [rm, rn])
+            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+
+
 def matmul_fp8_row(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -655,6 +783,7 @@ def matmul_fp8_row(
     allow_tf32: bool = True,
     fp8_fast_accum: bool = True,
     imprecise_acc: bool = False,
+    tma_persistent: bool = False,
 ) -> torch.Tensor:
     """
     Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -667,6 +796,7 @@ def matmul_fp8_row(
         dot_out_dtype (torch.dtype): Output type of tensor core.
         allow_tf32 (bool): Whether to use TF32 for tensor core.
         fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+        tma_persistent (bool): Whether to use the TMA persistent kernel implementation.
 
     Returns:
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
@@ -705,6 +835,104 @@ def persistent_grid(META):
         ),
     )
 
+    if tma_persistent:
+        # Descriptors used by the TMA persistent kernel.
+        TMA_SIZE = 128
+        import numpy as np
+
+        # Autotune doesn't work with TMA; see:
+        # https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312
+
+        BLOCK_M = 128
+        BLOCK_N = 256
+        BLOCK_K = 128
+        GROUP_M = 8
+        num_stages = 3
+        num_warps = 8
+
+        desc_a = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_b = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_c = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_a_scale = np.empty(TMA_SIZE, dtype=np.int8)
+        desc_b_scale = np.empty(TMA_SIZE, dtype=np.int8)
+
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            a_tl.data_ptr(),
+            M,
+            K,
+            BLOCK_M,
+            BLOCK_K,
+            a_tl.element_size(),
+            desc_a,
+        )
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            b_tl.data_ptr(),
+            N,
+            K,
+            BLOCK_N,
+            BLOCK_K,
+            b_tl.element_size(),
+            desc_b,
+        )
+        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
+            c.data_ptr(),
+            M,
+            N,
+            BLOCK_M,
+            BLOCK_N,
+            c.element_size(),
+            desc_c,
+        )
+        triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
+            a_scale.data_ptr(),
+            M,
+            BLOCK_M,
+            a_scale.element_size(),
+            desc_a_scale,
+        )
+        triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
+            b_scale.data_ptr(),
+            N,
+            BLOCK_N,
+            b_scale.element_size(),
+            desc_b_scale,
+        )
+        desc_a = torch.tensor(desc_a, device="cuda")
+        desc_b = torch.tensor(desc_b, device="cuda")
+        desc_c = torch.tensor(desc_c, device="cuda")
+        desc_a_scale = torch.tensor(desc_a_scale, device="cuda")
+        desc_b_scale = torch.tensor(desc_b_scale, device="cuda")
+
+        # pyre-ignore[28]:
+        _kernel_matmul_fp8_row_tma_persistent[persistent_grid](
+            desc_a,
+            desc_b,
+            desc_c,
+            M,
+            N,
+            K,
+            desc_a_scale,
+            desc_b_scale,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            dot_out_dtype=dot_out_dtype_triton,
+            allow_tf32=allow_tf32,
+            fp8_fast_accum=fp8_fast_accum,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N,
+            BLOCK_K=BLOCK_K,
+            GROUP_M=GROUP_M,
+            AB_DTYPE=False,
+            NUM_SMS=NUM_SMS,
+            num_stages=num_stages,
+            num_warps=num_warps,
+        )
+        return c
+
     if imprecise_acc:
         _kernel_matmul_fp8_row_imprecise_acc[grid](
             a_tl,
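
For reference, below is a minimal sketch of how the new tma_persistent path is exercised end to end, mirroring row_gemm_bench_tma from the benchmark diff above. The import path is assumed from the file locations in this patch, and the snippet presumes a Hopper-class GPU plus a Triton build that exposes the experimental TMA descriptor helpers; the shape and dtypes are illustrative only.

# Sketch, not part of the patch. Import path assumed from the file locations in this diff.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

# Smallest shape added to the benchmark sweep in this PR; any [M, K] x [N, K] pair works.
M, N, K = 8192, 8192, 512
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Row-wise quantize both operands to fp8, keeping per-row reciprocal scales.
x_fp8, x_scale = quantize_fp8_row(x)
w_fp8, w_scale = quantize_fp8_row(w)

# Run the row-wise fp8 GEMM through the new TMA persistent kernel.
out = matmul_fp8_row(
    x_fp8,
    w_fp8,
    x_scale,
    w_scale,
    dot_out_dtype=torch.float32,
    allow_tf32=True,
    fp8_fast_accum=True,
    tma_persistent=True,
)
print(out.shape)  # torch.Size([8192, 8192])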