diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu
index b79c468fd..8e0769ab0 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu
@@ -182,6 +182,81 @@ __global__ void set_kernel_args_kernel(
   }
 }
 
+__global__ void set_dynamic_kernel_args_kernel(
+    int64_t xq_ptr,
+    int64_t wq_ptr,
+    int64_t scale_ptr,
+    int64_t* input_args_ptr,
+    int64_t* output_args_ptr,
+    at::BFloat16* output_data,
+    int output_offset,
+    int xq_ptr_offset,
+    int wq_ptr_offset,
+    int scale_ptr_offset,
+    int problem_shape_buf_offset,
+    int stride_buf_offset,
+    int stride_size,
+    int problem_count,
+    int problem_shape_size,
+    int group_index,
+    int* zero_start_index_M,
+    int N,
+    int K) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each kernel annoyingly can only set the kernel args for one group.
+  // This could only be avoided with complicated memory management.
+  if (idx == 0) {
+    int64_t* xq_ptr_ = input_args_ptr + xq_ptr_offset;
+    int64_t* wq_ptr_ = input_args_ptr + wq_ptr_offset;
+    int64_t* scale_ptr_ = input_args_ptr + scale_ptr_offset;
+    uint8_t* problem_shape_buf =
+        reinterpret_cast<uint8_t*>(input_args_ptr + problem_shape_buf_offset);
+    uint8_t* stride_buf =
+        reinterpret_cast<uint8_t*>(input_args_ptr + stride_buf_offset);
+
+    GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
+        reinterpret_cast<
+            GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
+            problem_shape_buf);
+    // Pass dummy configs to get Stride structure
+    GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+        StrideInputA* stride_input_A_ptr = reinterpret_cast<
+            GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+                StrideInputA*>(stride_buf);
+    GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+        StrideInputB* stride_input_B_ptr = reinterpret_cast<
+            GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+                StrideInputB*>(stride_buf + stride_size);
+    GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+        StrideOutput* stride_output_ptr = reinterpret_cast<
+            GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+                StrideOutput*>(stride_buf + (stride_size * 2));
+
+    output_args_ptr[group_index] =
+        reinterpret_cast<int64_t>(output_data + output_offset);
+
+    // Write kernel arguments directly to memory.
+    xq_ptr_[group_index] = xq_ptr;
+    wq_ptr_[group_index] = wq_ptr;
+    scale_ptr_[group_index] = scale_ptr;
+    problem_shape_ptr[group_index] =
+        GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(
+            zero_start_index_M[group_index], N, K);
+    stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
+        typename GroupedGemmArgs::
+            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+        {zero_start_index_M[group_index], K, 1});
+    stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
+        typename GroupedGemmArgs::
+            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+        {N, K, 1});
+    stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
+        typename GroupedGemmArgs::
+            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+        {zero_start_index_M[group_index], N, 1});
+  }
+}
+
 template <
     int TB_M,
     int TB_N,
@@ -194,7 +269,8 @@ template <
 std::vector<at::Tensor> f8f8bf16_grouped_impl(
     const std::vector<at::Tensor>& XQ, // FP8
     const std::vector<at::Tensor>& WQ, // FP8
-    const std::vector<at::Tensor>& scale) {
+    const std::vector<at::Tensor>& scale,
+    std::optional<at::Tensor> zero_start_index_M) {
   int problem_count = XQ.size();
   TORCH_CHECK(WQ.size() == problem_count);
   if (problem_count == 0) {
@@ -253,7 +329,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
   }
 
   at::Tensor output_tensor =
-      at::empty(total_output_size, XQ[0].options().dtype(at::kBFloat16));
+      at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
 
   int blockSize = 256;
   int numBlocks = 1;
@@ -262,32 +338,55 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
 
   // Set arguments
   for (int i = 0; i < problem_count; ++i) {
-    int M = XQ[i].size(0);
     int N = WQ[i].size(0);
     int K = XQ[i].size(1);
     TORCH_CHECK_EQ(WQ[i].size(1), K);
-    set_kernel_args_kernel<<<numBlocks, blockSize>>>(
-        reinterpret_cast<int64_t>(XQ[i].data_ptr()),
-        reinterpret_cast<int64_t>(WQ[i].data_ptr()),
-        reinterpret_cast<int64_t>(
-            scale[i].data_ptr()),
-        input_args.data_ptr<int64_t>(),
-        output_args.data_ptr<int64_t>(),
-        output_tensor.data_ptr<at::BFloat16>(),
-        output_offset,
-        xq_ptr_offset,
-        wq_ptr_offset,
-        scale_ptr_offset,
-        problem_shape_buf_offset,
-        stride_buf_offset,
-        stride_size,
-        problem_count,
-        problem_shape_size,
-        i,
-        M,
-        N,
-        K);
-
+    if (zero_start_index_M.has_value() == true) {
+      set_dynamic_kernel_args_kernel<<<numBlocks, blockSize>>>(
+          reinterpret_cast<int64_t>(XQ[i].data_ptr()),
+          reinterpret_cast<int64_t>(WQ[i].data_ptr()),
+          reinterpret_cast<int64_t>(
+              scale[i].data_ptr()),
+          input_args.data_ptr<int64_t>(),
+          output_args.data_ptr<int64_t>(),
+          output_tensor.data_ptr<at::BFloat16>(),
+          output_offset,
+          xq_ptr_offset,
+          wq_ptr_offset,
+          scale_ptr_offset,
+          problem_shape_buf_offset,
+          stride_buf_offset,
+          stride_size,
+          problem_count,
+          problem_shape_size,
+          i,
+          reinterpret_cast<int*>(zero_start_index_M.value().data_ptr()),
+          N,
+          K);
+    } else {
+      int M = XQ[i].size(0);
+      set_kernel_args_kernel<<<numBlocks, blockSize>>>(
+          reinterpret_cast<int64_t>(XQ[i].data_ptr()),
+          reinterpret_cast<int64_t>(WQ[i].data_ptr()),
+          reinterpret_cast<int64_t>(
+              scale[i].data_ptr()),
+          input_args.data_ptr<int64_t>(),
+          output_args.data_ptr<int64_t>(),
+          output_tensor.data_ptr<at::BFloat16>(),
+          output_offset,
+          xq_ptr_offset,
+          wq_ptr_offset,
+          scale_ptr_offset,
+          problem_shape_buf_offset,
+          stride_buf_offset,
+          stride_size,
+          problem_count,
+          problem_shape_size,
+          i,
+          M,
+          N,
+          K);
+    }
     output_offset += output_sizes[i];
   }
 
@@ -376,17 +475,18 @@ template <bool FastAccum>
 std::vector<at::Tensor> dispatch_fp8_grouped_kernel(
     const std::vector<at::Tensor>& xq_group, // FP8
     const std::vector<at::Tensor>& wq_group, // FP8
-    const std::vector<at::Tensor>& scale) {
+    const std::vector<at::Tensor>& scale,
+    std::optional<at::Tensor> zero_start_index_M) {
   KernelMode kernel = get_grouped_kernel_mode(xq_group, wq_group);
   if (kernel == KernelMode::Small) {
     return f8f8bf16_grouped_impl<64, 128, 128, 2, 1, 1, true, FastAccum>(
-        xq_group, wq_group, scale);
+        xq_group, wq_group, scale, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
     return f8f8bf16_grouped_impl<128, 128, 128, 2, 1, 1, true, FastAccum>(
-        xq_group, wq_group, scale);
+        xq_group, wq_group, scale, zero_start_index_M);
   } else {
     return f8f8bf16_grouped_impl<128, 128, 128, 1, 2, 1, true, FastAccum>(
-        xq_group, wq_group, scale);
+        xq_group, wq_group, scale, zero_start_index_M);
   }
 }
 
@@ -394,11 +494,14 @@ std::vector<at::Tensor> f8f8bf16_grouped(
     const std::vector<at::Tensor>& xq_group, // FP8
     const std::vector<at::Tensor>& wq_group, // FP8
     const std::vector<at::Tensor>& scale,
+    std::optional<at::Tensor> zero_start_index_M,
     bool use_fast_accum) {
   if (use_fast_accum) {
-    return dispatch_fp8_grouped_kernel<true>(xq_group, wq_group, scale);
+    return dispatch_fp8_grouped_kernel<true>(
+        xq_group, wq_group, scale, zero_start_index_M);
   } else {
-    return dispatch_fp8_grouped_kernel<false>(xq_group, wq_group, scale);
+    return dispatch_fp8_grouped_kernel<false>(
+        xq_group, wq_group, scale, zero_start_index_M);
   }
 }
 
@@ -408,6 +511,7 @@ std::vector<at::Tensor> f8f8bf16_grouped(
     const std::vector<at::Tensor>& xq_group, // FP8
     const std::vector<at::Tensor>& wq_group, // FP8
     const std::vector<at::Tensor>& scale,
+    std::optional<at::Tensor> zero_start_index_M,
     bool use_fast_accum) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 3bb5f4580..8f94e4953 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -59,6 +59,7 @@ std::vector<at::Tensor> f8f8bf16_grouped(
     const std::vector<at::Tensor>& XQ,
     const std::vector<at::Tensor>& WQ,
     const std::vector<at::Tensor>& scale,
+    std::optional<at::Tensor> zero_start_index_M,
     bool use_fast_accum = true);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
@@ -163,7 +164,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
-      "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, bool use_fast_accum=True) -> Tensor[]");
+      "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
@@ -429,6 +430,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_meta(
     const std::vector<at::Tensor>& XQ,
     const std::vector<at::Tensor>& WQ,
     const std::vector<at::Tensor>& /* scale */,
+    std::optional<at::Tensor> /* zero_start_index_M = c10::nullopt */,
     bool /* use_fast_accum = true */) {
   std::vector<at::Tensor> Y;
   for (int i = 0; i < XQ.size(); i++) {
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index cb0ea938a..400356066 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -748,6 +748,7 @@ def fp8_loopover_bmm(
         N=st.sampled_from([1024, 6144]),
         K=st.sampled_from([512, 3584]),
         use_cudagraph=st.sampled_from([True, False]),
+        use_padding_zeros=st.sampled_from([True, False]),
     )
     def test_fp8_grouped_gemm(
         self,
@@ -756,8 +757,17 @@ def test_fp8_grouped_gemm(
         N: int,
         K: int,
         use_cudagraph: bool,
+        use_padding_zeros: bool,
     ) -> None:
-        ms = torch.randint(1, (M // 64) + 1, (G,), dtype=torch.int) * 64
+        ms = (
+            torch.randint(
+                (258 // 64) + 1 if use_padding_zeros else 1,
+                (M // 64) + 1,
+                (G,),
+                dtype=torch.int,
+            )
+            * 64
+        )
         ns = torch.randint(1, (N // 64) + 1, (G,), dtype=torch.int) * 64
         ks = torch.randint(1, (K // 64) + 1, (G,), dtype=torch.int) * 64
 
@@ -766,10 +776,26 @@ def test_fp8_grouped_gemm(
         x_group = []
         xq_group = []
         wq_group = []
         scale_group = []
+        zero_start_index_M = None
+        zero_start_index_M_value = M
+
+        if use_padding_zeros:
+            zero_start_index_M_value = 256
+            zero_start_index_M = torch.full(
+                size=(G,),
+                fill_value=zero_start_index_M_value,
+                dtype=torch.int,
+                device="cuda",
+            )
-        for m, n, k in zip(ms, ns, ks):
+        for i, (m, n, k) in enumerate(zip(ms, ns, ks)):
             x = torch.rand(size=(m, k), dtype=torch.bfloat16, device="cuda")
             w = torch.rand(size=(n, k), dtype=torch.bfloat16, device="cuda")
+
+            if use_padding_zeros:
+                # Zero out dim M from index zero_start_index_M_value
+                x[zero_start_index_M_value:, :] = 0
+
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
             x_group.append(x)
@@ -781,16 +807,29 @@ def test_fp8_grouped_gemm(
         # FP8 grouped gemm kernel
         if use_cudagraph:
             # warmup
-            torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)
+            torch.ops.fbgemm.f8f8bf16_grouped(
+                xq_group,
+                wq_group,
+                scale_group,
+                zero_start_index_M if use_padding_zeros else None,
+            )
             # With cudagraph
             g = torch.cuda.CUDAGraph()
             with torch.cuda.graph(g):
                 y_group = torch.ops.fbgemm.f8f8bf16_grouped(
-                    xq_group, wq_group, scale_group
+                    xq_group,
+                    wq_group,
+                    scale_group,
+                    zero_start_index_M if use_padding_zeros else None,
                 )
             g.replay()
         else:
-            y_group = torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)
+            y_group = torch.ops.fbgemm.f8f8bf16_grouped(
+                xq_group,
+                wq_group,
+                scale_group,
+                zero_start_index_M if use_padding_zeros else None,
+            )
 
         # BF16 loopover gemm reference
         y_group_ref = []
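
Illustrative usage (not part of the patch above): a minimal Python sketch of calling the updated operator with the new optional zero_start_index_M argument, modeled on the test changes in this diff. The fbgemm_gpu import path and the combined per-group scale are assumptions for illustration; shapes and sizes are arbitrary.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401, assumed path that registers torch.ops.fbgemm.*

G, M, N, K = 4, 512, 1024, 512
# Every group's activation is padded to M rows; rows at or beyond index 256
# are padding, so the kernel only needs to compute the first 256 rows.
zero_start_index_M = torch.full((G,), 256, dtype=torch.int, device="cuda")

xq_group, wq_group, scale_group = [], [], []
for _ in range(G):
    x = torch.rand(M, K, dtype=torch.bfloat16, device="cuda")
    w = torch.rand(N, K, dtype=torch.bfloat16, device="cuda")
    x[256:, :] = 0  # zero the padded rows, as the test above does
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
    xq_group.append(xq)
    wq_group.append(wq)
    scale_group.append(x_scale * w_scale)  # assumed combined per-tensor scale

# Pass zero_start_index_M to take the dynamic-M path added by this change;
# pass None to keep the original static-shape behavior.
y_group = torch.ops.fbgemm.f8f8bf16_grouped(
    xq_group, wq_group, scale_group, zero_start_index_M, True
)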