Customize FP8 grouped GEMM for non-zero calculation for token choice MoE (#3383)

Summary:
Pull Request resolved: #3383

X-link: facebookresearch/FBGEMM#471

Customize FP8 grouped GEMM for non-zero calculation for token choice MoE with dynamic dim M

Reviewed By: jianyuh, xw285cornell

Differential Revision: D65989604

fbshipit-source-id: 440f567d9d867592f205505a888d99ce9bc8221f
jiawenliu64 authored and facebook-github-bot committed Nov 16, 2024
1 parent d5b938b commit 8812a95
Showing 3 changed files with 182 additions and 37 deletions.
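
For orientation, a minimal usage sketch of the updated operator with the new optional zero_start_index_M argument. The call pattern follows the test changes below; the group count, shapes, valid-row count, and the combined per-group scale are illustrative assumptions, and the fbgemm gen_ai ops are assumed to be built and available.

import torch

G, M, N, K = 4, 512, 1024, 512  # illustrative group count and per-group shapes
valid_m = 256                   # rows of real tokens per group; the rest is zero padding

xq_group, wq_group, scale_group = [], [], []
for _ in range(G):
    x = torch.rand(M, K, dtype=torch.bfloat16, device="cuda")
    w = torch.rand(N, K, dtype=torch.bfloat16, device="cuda")
    x[valid_m:, :] = 0  # zero-pad the unused rows of dim M
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
    xq_group.append(xq)
    wq_group.append(wq)
    scale_group.append(x_scale * w_scale)  # combined per-group scale (assumed convention)

# Per-group index where the zero padding starts, i.e. the dynamic M actually computed.
zero_start_index_M = torch.full((G,), valid_m, dtype=torch.int, device="cuda")

y_group = torch.ops.fbgemm.f8f8bf16_grouped(
    xq_group, wq_group, scale_group, zero_start_index_M
)
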
@@ -182,6 +182,81 @@ __global__ void set_kernel_args_kernel(
}
}

__global__ void set_dynamic_kernel_args_kernel(
int64_t xq_ptr,
int64_t wq_ptr,
int64_t scale_ptr,
int64_t* input_args_ptr,
int64_t* output_args_ptr,
at::BFloat16* output_data,
int output_offset,
int xq_ptr_offset,
int wq_ptr_offset,
int scale_ptr_offset,
int problem_shape_buf_offset,
int stride_buf_offset,
int stride_size,
int problem_count,
int problem_shape_size,
int group_index,
int* zero_start_index_M,
int N,
int K) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each kernel annoyingly can only set the kernel args for one group.
// This could only be avoided with complicated memory management.
if (idx == 0) {
int64_t* xq_ptr_ = input_args_ptr + xq_ptr_offset;
int64_t* wq_ptr_ = input_args_ptr + wq_ptr_offset;
int64_t* scale_ptr_ = input_args_ptr + scale_ptr_offset;
uint8_t* problem_shape_buf =
reinterpret_cast<uint8_t*>(input_args_ptr + problem_shape_buf_offset);
uint8_t* stride_buf =
reinterpret_cast<uint8_t*>(input_args_ptr + stride_buf_offset);

GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
reinterpret_cast<int64_t>(output_data + output_offset);

// Write kernel arguments directly to memory.
xq_ptr_[group_index] = xq_ptr;
wq_ptr_[group_index] = wq_ptr;
scale_ptr_[group_index] = scale_ptr;
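// The per-group valid-row count read from zero_start_index_M becomes the
// dynamic M of this group's problem shape and of its A/output strides below.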
problem_shape_ptr[group_index] =
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(
zero_start_index_M[group_index], N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
{zero_start_index_M[group_index], K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
{zero_start_index_M[group_index], N, 1});
}
}
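
The new set_dynamic_kernel_args_kernel fills each group's GEMM arguments on the device, taking that group's M dimension from zero_start_index_M[group_index] instead of a fixed host-side M; this is also why the output buffer below switches from at::empty to at::zeros, since rows at or beyond the valid count are never written by the GEMM and should read back as zeros. As a hypothetical host-side sketch (not part of this commit), zero_start_index_M for token-choice MoE could be derived from the router's expert assignments, assuming activations are grouped per expert and zero-padded to a common capacity:

import torch

def valid_rows_per_expert(expert_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # expert_ids: [num_tokens] expert chosen for each token (token-choice routing).
    # Returns, for each expert/group, the number of valid rows in its padded
    # activation block; rows past this index are zero padding, matching the
    # per-group dynamic M used by the kernel above.
    counts = torch.bincount(expert_ids, minlength=num_experts)
    return counts.to(torch.int32).to("cuda")

# Example: 8 tokens routed across 3 experts (hypothetical assignment).
zero_start_index_M = valid_rows_per_expert(
    torch.tensor([0, 0, 1, 2, 2, 2, 0, 1]), num_experts=3
)  # -> tensor([3, 2, 3], dtype=torch.int32, device='cuda:0')
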

template <
int TB_M,
int TB_N,
@@ -194,7 +269,8 @@ template <
std::vector<at::Tensor> f8f8bf16_grouped_impl(
const std::vector<at::Tensor>& XQ, // FP8
const std::vector<at::Tensor>& WQ, // FP8
const std::vector<at::Tensor>& scale) {
const std::vector<at::Tensor>& scale,
std::optional<at::Tensor> zero_start_index_M) {
int problem_count = XQ.size();
TORCH_CHECK(WQ.size() == problem_count);
if (problem_count == 0) {
@@ -253,7 +329,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
}

at::Tensor output_tensor =
at::empty(total_output_size, XQ[0].options().dtype(at::kBFloat16));
at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));

int blockSize = 256;
int numBlocks = 1;
@@ -262,32 +338,55 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

// Set arguments
for (int i = 0; i < problem_count; ++i) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int K = XQ[i].size(1);
TORCH_CHECK_EQ(WQ[i].size(1), K);
set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
scale_ptr_offset,
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
problem_shape_size,
i,
M,
N,
K);

if (zero_start_index_M.has_value()) {
set_dynamic_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
scale_ptr_offset,
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
problem_shape_size,
i,
reinterpret_cast<int*>(zero_start_index_M.value().data_ptr()),
N,
K);
} else {
int M = XQ[i].size(0);
set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
scale_ptr_offset,
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
problem_shape_size,
i,
M,
N,
K);
}
output_offset += output_sizes[i];
}

@@ -376,29 +475,33 @@ template <bool FastAccum>
std::vector<at::Tensor> dispatch_fp8_grouped_kernel(
const std::vector<at::Tensor>& xq_group, // FP8
const std::vector<at::Tensor>& wq_group, // FP8
const std::vector<at::Tensor>& scale) {
const std::vector<at::Tensor>& scale,
std::optional<at::Tensor> zero_start_index_M) {
KernelMode kernel = get_grouped_kernel_mode(xq_group, wq_group);
if (kernel == KernelMode::Small) {
return f8f8bf16_grouped_impl<64, 128, 128, 2, 1, 1, true, FastAccum>(
xq_group, wq_group, scale);
xq_group, wq_group, scale, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_grouped_impl<128, 128, 128, 2, 1, 1, true, FastAccum>(
xq_group, wq_group, scale);
xq_group, wq_group, scale, zero_start_index_M);
} else {
return f8f8bf16_grouped_impl<128, 128, 128, 1, 2, 1, true, FastAccum>(
xq_group, wq_group, scale);
xq_group, wq_group, scale, zero_start_index_M);
}
}

std::vector<at::Tensor> f8f8bf16_grouped(
const std::vector<at::Tensor>& xq_group, // FP8
const std::vector<at::Tensor>& wq_group, // FP8
const std::vector<at::Tensor>& scale,
std::optional<at::Tensor> zero_start_index_M,
bool use_fast_accum) {
if (use_fast_accum) {
return dispatch_fp8_grouped_kernel<true>(xq_group, wq_group, scale);
return dispatch_fp8_grouped_kernel<true>(
xq_group, wq_group, scale, zero_start_index_M);
} else {
return dispatch_fp8_grouped_kernel<false>(xq_group, wq_group, scale);
return dispatch_fp8_grouped_kernel<false>(
xq_group, wq_group, scale, zero_start_index_M);
}
}

@@ -408,6 +511,7 @@ std::vector<at::Tensor> f8f8bf16_grouped(
const std::vector<at::Tensor>& xq_group, // FP8
const std::vector<at::Tensor>& wq_group, // FP8
const std::vector<at::Tensor>& scale,
std::optional<at::Tensor> zero_start_index_M,
bool use_fast_accum) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
4 changes: 3 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -59,6 +59,7 @@ std::vector<at::Tensor> f8f8bf16_grouped(
const std::vector<at::Tensor>& XQ,
const std::vector<at::Tensor>& WQ,
const std::vector<at::Tensor>& scale,
std::optional<at::Tensor> zero_start_index_M,
bool use_fast_accum = true);
at::Tensor f8f8bf16_rowwise(
at::Tensor XQ,
@@ -163,7 +164,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
"f8f8bf16_grouped(Tensor[] XQ, Tensor[] XQ, Tensor[] scale, bool use_fast_accum=True) -> Tensor[]");
"f8f8bf16_grouped(Tensor[] XQ, Tensor[] XQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
m.def(
"bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
@@ -429,6 +430,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_meta(
const std::vector<at::Tensor>& XQ,
const std::vector<at::Tensor>& WQ,
const std::vector<at::Tensor>& /* scale */,
std::optional<at::Tensor> /* zero_start_index_M = c10::nullopt */,
bool /* use_fast_accum = true */) {
std::vector<at::Tensor> Y;
for (int i = 0; i < XQ.size(); i++) {
49 changes: 44 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -748,6 +748,7 @@ def fp8_loopover_bmm(
N=st.sampled_from([1024, 6144]),
K=st.sampled_from([512, 3584]),
use_cudagraph=st.sampled_from([True, False]),
use_padding_zeros=st.sampled_from([True, False]),
)
def test_fp8_grouped_gemm(
self,
@@ -756,8 +757,17 @@ def test_fp8_grouped_gemm(
N: int,
K: int,
use_cudagraph: bool,
use_padding_zeros: bool,
) -> None:
ms = torch.randint(1, (M // 64) + 1, (G,), dtype=torch.int) * 64
ms = (
torch.randint(
(258 // 64) + 1 if use_padding_zeros else 1,
(M // 64) + 1,
(G,),
dtype=torch.int,
)
* 64
)
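# With use_padding_zeros, the lower bound (258 // 64) + 1 = 5 makes every sampled
# m a multiple of 64 that is at least 320, so each group has room for the 256
# valid rows (zero_start_index_M_value below) plus at least one fully
# zero-padded 64-row block.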
ns = torch.randint(1, (N // 64) + 1, (G,), dtype=torch.int) * 64
ks = torch.randint(1, (K // 64) + 1, (G,), dtype=torch.int) * 64

@@ -766,10 +776,26 @@ def test_fp8_grouped_gemm(
xq_group = []
wq_group = []
scale_group = []
zero_start_index_M = None
zero_start_index_M_value = M

if use_padding_zeros:
zero_start_index_M_value = 256
zero_start_index_M = torch.full(
size=(G,),
fill_value=zero_start_index_M_value,
dtype=torch.int,
device="cuda",
)

for m, n, k in zip(ms, ns, ks):
for i, (m, n, k) in enumerate(zip(ms, ns, ks)):
x = torch.rand(size=(m, k), dtype=torch.bfloat16, device="cuda")
w = torch.rand(size=(n, k), dtype=torch.bfloat16, device="cuda")

if use_padding_zeros:
# Zero out dim M from index zero_start_index_M_value
x[zero_start_index_M_value:, :] = 0

xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
x_group.append(x)
@@ -781,16 +807,29 @@ def test_fp8_grouped_gemm(
# FP8 grouped gemm kernel
if use_cudagraph:
# warmup
torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)
torch.ops.fbgemm.f8f8bf16_grouped(
xq_group,
wq_group,
scale_group,
zero_start_index_M if use_padding_zeros else None,
)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_group = torch.ops.fbgemm.f8f8bf16_grouped(
xq_group, wq_group, scale_group
xq_group,
wq_group,
scale_group,
zero_start_index_M if use_padding_zeros else None,
)
g.replay()
else:
y_group = torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)
y_group = torch.ops.fbgemm.f8f8bf16_grouped(
xq_group,
wq_group,
scale_group,
zero_start_index_M if use_padding_zeros else None,
)

# BF16 loopover gemm reference
y_group_ref = []
