Make some fbgemm fp8 triton ops pt2 friendly (#3188)
Summary:
X-link: facebookresearch/FBGEMM#283

Pull Request resolved: #3188

Make some fbgemm fp8 triton ops pt2 (torch.compile) friendly.

# What this diff tries to do
* Stop using TensorWrapper and tl.reinterpret, passing torch fp8 tensors to the kernels directly (a minimal sketch of the resulting launch pattern follows this list).
* Remove the use of triton.heuristics for _kernel_matmul_fp8_row.
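
To make the new call pattern concrete, here is a minimal, hypothetical sketch of the pt2-friendly launch style used in this diff: the wrapper is registered with torch._library.triton_op and the kernel is launched through torch._library.capture_triton, with plain torch fp8 tensors instead of TensorWrapper / tl.reinterpret. The toy copy kernel, the example::copy_fp8 op name, and the block size are illustrative only and are not part of this diff.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Toy elementwise copy standing in for the real fp8 matmul kernels.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


@torch._library.triton_op("example::copy_fp8", mutates_args=())
def copy_fp8(x: torch.Tensor) -> torch.Tensor:
    # Inputs arrive as torch fp8 tensors (e.g. torch.float8_e4m3fn); recent
    # triton understands the torch fp8 dtypes, so no TensorWrapper /
    # tl.reinterpret step is needed.
    y = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    torch._library.capture_triton(_copy_kernel)[grid](x, y, n_elements, BLOCK=1024)
    return y
```

The idea is that torch.compile can then trace the triton launch inside the custom op rather than graph-breaking on an opaque wrapper.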

# What this diff won't help with:
* The triton.heuristics uses of EVEN_K. One option is to just merge that into the autotuning configs (a hypothetical sketch follows).
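
As a rough illustration of that option (not part of this diff), EVEN_K could become an explicit key of each autotuning Config instead of being derived by @triton.heuristics. The with_even_k helper below is hypothetical:

```python
from typing import List

from triton import Config


def with_even_k(configs: List[Config]) -> List[Config]:
    # Duplicate each config with EVEN_K=True and EVEN_K=False so the kernel no
    # longer needs an @triton.heuristics decorator; a prune_configs_by hook
    # could then drop the variants whose EVEN_K does not match
    # K % (BLOCK_K * SPLIT_K) == 0 for the actual problem shape.
    out = []
    for cfg in configs:
        for even_k in (True, False):
            out.append(
                Config(
                    {**cfg.kwargs, "EVEN_K": even_k},
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                )
            )
    return out
```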

# Need to do in the future:
* Update other ops, like quantize_fp8_row.
* Update the documentation. It feels pretty outdated, and some of it still references TensorWrapper.

Differential Revision: D63560103
jwfromm authored and facebook-github-bot committed Sep 30, 2024
1 parent 93dcc07 commit 0385aa4
1 changed file: fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (52 additions, 47 deletions)
@@ -6,7 +6,7 @@

# pyre-unsafe
import logging
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import triton # @manual
@@ -15,10 +15,7 @@
from torch._tensor import Tensor

from triton import Config # @manual
from triton.ops.matmul_perf_model import ( # @manual
early_config_prune,
estimate_matmul_time,
)
from triton.ops.matmul_perf_model import early_config_prune # @manual
from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual

logger: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +40,7 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
"""
Converts tensor to triton fp8 type.
@@ -213,11 +210,6 @@ def get_configs_io_bound() -> List[Config]:
"k_key",
],
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def _kernel_matmul_fp8_row(
A_ptr,
@@ -246,7 +238,6 @@ def _kernel_matmul_fp8_row(
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
USE_BIAS: tl.constexpr,
AB_DTYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
@@ -964,7 +955,7 @@ def get_tma_descriptor_kernel_param(self, name):
return self.cuda_descriptors[name]


@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
a: torch.Tensor,
b: torch.Tensor,
@@ -995,15 +986,15 @@ def matmul_fp8_row(
torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
"""
# Get datatypes and constants to use.
_, tl_dtype, _, _ = get_fp8_constants()
pt_fp8_dtype, _, _, _ = get_fp8_constants()
# Handle 3D+ a shape
a_shape = a.shape
a = a.view(-1, a.size(-1))
# Reinterpret inputs into proper triton fp8 dtype.
a_tl = convert_fp8_type(a, tl_dtype)
b_tl = convert_fp8_type(b, tl_dtype)
# View inputs into proper torch fp8 dtype.
assert a.dtype == pt_fp8_dtype
assert b.dtype == pt_fp8_dtype
M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
prep_matmul(a_tl, b_tl, dot_out_dtype)
prep_matmul(a, b, dot_out_dtype)
)

output_shape = a_shape[:-1] + (N,)
@@ -1049,22 +1040,22 @@ def persistent_grid_tma(META):
nonlocal desc_helper
desc_helper.fill_2d_tma_descriptor(
"a",
a_tl.data_ptr(),
a.data_ptr(),
M,
K,
META["BLOCK_M"],
META["BLOCK_K"],
a_tl.element_size(),
a.element_size(),
)

desc_helper.fill_2d_tma_descriptor(
"b",
b_tl.data_ptr(),
b.data_ptr(),
N,
K,
META["BLOCK_N"],
META["BLOCK_K"],
b_tl.element_size(),
b.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"c",
@@ -1111,8 +1102,10 @@ def persistent_grid_tma(META):
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

# pyre-ignore[28]:
_kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
# pyre-ignore
torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
persistent_grid_tma
](
desc_a,
desc_b,
desc_c,
Expand Down Expand Up @@ -1141,9 +1134,9 @@ def persistent_grid_tma(META):
USE_BIAS=bias is not None,
)
elif imprecise_acc:
_kernel_matmul_fp8_row_imprecise_acc[grid](
a_tl,
b_tl,
torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
a,
b,
c,
M,
N,
@@ -1168,9 +1161,9 @@ def persistent_grid_tma(META):
AB_DTYPE=False,
)
elif fp8_fast_accum:
_kernel_matmul_fp8_row[persistent_grid](
a_tl,
b_tl,
torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
a,
b,
c,
M,
N,
@@ -1196,9 +1189,11 @@ def persistent_grid_tma(META):
NUM_SMS=NUM_SMS,
)
else:
_kernel_matmul_fp8_row_no_fast_acc[persistent_grid](
a_tl,
b_tl,
torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
persistent_grid
](
a,
b,
c,
M,
N,
@@ -1269,8 +1264,6 @@ def prune_configs_block(configs, named_args, **kwargs):
], # TODO caller side bin keys so similar shapes can use same triton.autotune.
prune_configs_by={
"early_config_prune": prune_configs_block,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
)
@triton.heuristics(
@@ -1465,8 +1458,6 @@ def _kernel_matmul_fp8_block_fastacc(
], # TODO caller side bin keys so similar shapes can use same triton.autotune.
prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
},
)
@triton.heuristics(
@@ -1659,13 +1650,13 @@ def matmul_fp8_block(
Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale)
"""
# Get datatypes and constants to use.
_, tl_dtype, _, _ = get_fp8_constants()
_, tl_fp8_dtype, _, _ = get_fp8_constants()
# Handle 3D+ a shape
a_shape = a.shape
a = a.view(-1, a.size(-1))
# Reinterpret inputs into proper triton fp8 dtype.
a_tl = convert_fp8_type(a, tl_dtype)
b_tl = convert_fp8_type(b, tl_dtype)
# View inputs into proper triton fp8 dtype.
a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)

M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
a_tl, b_tl, dot_out_dtype
@@ -1794,14 +1785,18 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:


def prep_matmul(
a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype]
) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]:
a: Union[TensorWrapper, torch.Tensor],
b: Union[TensorWrapper, torch.Tensor],
dot_out_dtype: Optional[torch.dtype],
) -> Tuple[
int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
]:
"""
Shared bookkeeping for a @ b.T matmul.
Args:
a (TensorWrapper): [M, K] input tensor.
b (TensorWrapper): [N, K] input tensor.
a (torch.Tensor): [M, K] input tensor.
b (torch.Tensor): [N, K] input tensor.
dot_out_dtype (tl.dtype): Output type of tensor core.
Returns:
@@ -1812,7 +1807,8 @@ def prep_matmul(
n_key (int): Autotuning key for N dim.
k_key (int): Autotuning key for K dim.
c (Tensor): [M, N] output tensor.
dot_out_dtype (torch.dtype): Output type of tensor core.
c_dtype_triton (tl.dtype): Type of output tensor.
dot_out_dtype (tl.dtype): Output type of tensor core.
device (torch.device): Device of output tensor.
"""
device = a.device
@@ -1827,11 +1823,20 @@

# allocates output
assert a.dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,
tl.float8e4b8,
] and b.dtype in [
]
assert b.dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,
