diff --git a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
index 37b2466486..66b75b8233 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
@@ -69,7 +69,7 @@ class CKMatmul(torch.nn.Module):
     def forward(
         self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor
     ) -> torch.Tensor:
-        out = torch.ops.fbgemm.f8f8bf16(a, b, scale)
+        out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)
         return out
 
 
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
index 367688f2a3..d973d28ace 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
@@ -601,9 +601,200 @@ template <
     int TBS_M,
     int TBS_N,
     int TBS_K,
-    bool PONG,
     bool FAST_ACCUM>
 at::Tensor f8f8bf16_impl(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor scale) {
+  int M = XQ.size(0);
+  int N = WQ.size(0);
+  int K = XQ.size(1);
+
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+
+  auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
+
+  using ElementInputA = cutlass::float_e4m3_t;
+  using LayoutInputA = cutlass::layout::RowMajor;
+  constexpr int AlignmentInputA =
+      128 /
+      cutlass::sizeof_bits<
+          ElementInputA>::value; // Memory access granularity/alignment of A
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ElementInputB = cutlass::float_e4m3_t;
+  using LayoutInputB = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentInputB =
+      128 /
+      cutlass::sizeof_bits<
+          ElementInputB>::value; // Memory access granularity/alignment of B
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ElementOutput = cutlass::bfloat16_t;
+  using LayoutOutput = cutlass::layout::ColumnMajor;
+  constexpr int AlignmentOutput =
+      128 /
+      cutlass::sizeof_bits<
+          ElementOutput>::value; // Memory access granularity/alignment of C
+                                 // matrix in units of elements (up to 16 bytes)
+
+  using ElementAccumulator = float;
+  using ElementComputeEpilogue = float;
+  using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
+                                       // supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using TileShape = cute::Shape<
+      cute::Int<TB_M>,
+      cute::Int<TB_N>,
+      cute::Int<TB_K>>; // Threadblock-level
+                        // tile size
+  using ClusterShape = cute::Shape<
+      cute::Int<TBS_M>,
+      cute::Int<TBS_N>,
+      cute::Int<TBS_K>>; // Shape of the
+                         // threadblocks in a
+                         // cluster
+  using StageCountType =
+      cutlass::gemm::collective::StageCountAuto; // Stage count maximized
+                                                 // based on the tile size
+  using KernelSchedule = cutlass::gemm::collective::
+      KernelScheduleAuto; // Kernel to launch based on the default setting in
+                          // the Collective Builder
+
+  using MainLoopSchedule = cute::conditional_t<
+      FAST_ACCUM,
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
+      cutlass::gemm::KernelTmaWarpSpecialized>;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementInputA,
+          LayoutInputA,
+          AlignmentInputA,
+          ElementInputB,
+          LayoutInputB,
+          AlignmentInputB,
+          ElementAccumulator,
+          TileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAuto,
+          MainLoopSchedule>::CollectiveOp;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          ElementComputeEpilogue,
+          ElementOutput,
+          LayoutOutput,
+          AlignmentOutput,
+          ElementOutput,
+          LayoutOutput,
+          AlignmentOutput,
+          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape<int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideInputA = typename Gemm::GemmKernel::StrideA;
+  using StrideInputB = typename Gemm::GemmKernel::StrideB;
+  using StrideOutput = typename Gemm::GemmKernel::StrideC;
+
+  StrideInputA stride_a = cutlass::make_cute_packed_stride(
+      StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{}));
+  StrideInputB stride_b = cutlass::make_cute_packed_stride(
+      StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{}));
+  StrideOutput stride_output = cutlass::make_cute_packed_stride(
+      StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{}));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {N, M, K},
+      {reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
+       stride_b,
+       reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
+       stride_a},
+      {{scale.data_ptr<float>(), 0},
+       (ElementOutput*)Y.data_ptr(),
+       stride_output,
+       (ElementOutput*)Y.data_ptr(),
+       stride_output}};
+  Gemm gemm;
+
+  // Using the arguments, query for extra workspace required for matrix
+  // multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check the problem size is supported or not
+  cutlass::Status status = gemm.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot implement");
+  }
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm.initialize(arguments, workspace.get());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot initialize");
+  }
+
+  status = gemm(at::cuda::getCurrentCUDAStream());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error(
+        std::string("cutlass cannot run") +
+        cutlass::cutlassGetStatusString(status));
+  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return Y;
+}
+
+at::Tensor f8f8bf16(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor scale,
+    bool use_fast_accum) {
+  auto M = XQ.size(0);
+  // auto K = XQ.size(1);
+  // auto N = WQ.size(0);
+  if (use_fast_accum) {
+    if (M <= 128) {
+      return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale);
+    } else {
+      return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale);
+    }
+  } else {
+    if (M <= 128) {
+      return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale);
+    } else {
+      return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale);
+    }
+  }
+}
+
+template <
+    int TB_M,
+    int TB_N,
+    int TB_K,
+    int TBS_M,
+    int TBS_N,
+    int TBS_K,
+    bool PONG,
+    bool FAST_ACCUM>
+at::Tensor f8f8bf16_tensorwise_impl(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     double scale) {
@@ -787,18 +978,21 @@ at::Tensor f8f8bf16_impl(
   return Y;
 }
 
-at::Tensor f8f8bf16(
+at::Tensor f8f8bf16_tensorwise(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     double scale,
     bool use_fast_accum) {
   KernelMode kernel = get_kernel_mode(XQ, WQ);
   if (kernel == KernelMode::Small) {
-    return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
+    return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
+        XQ, WQ, scale);
   } else if (kernel == KernelMode::Large) {
-    return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale);
+    return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
+        XQ, WQ, scale);
   } else {
-    return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale);
+    return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
+        XQ, WQ, scale);
   }
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 70bd3a388f..7c1d80b3b4 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -34,6 +34,11 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale);
 
 // Cutlass kernel
 at::Tensor f8f8bf16(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor scale,
+    bool use_fast_accum = true);
+at::Tensor f8f8bf16_tensorwise(
     at::Tensor XQ,
     at::Tensor WQ,
     double scale,
@@ -62,12 +67,7 @@ at::Tensor f8i4bf16_rowwise(
 at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale);
 std::tuple<at::Tensor, at::Tensor> per_tensor_dynamic_quantize_i8(at::Tensor X);
 
-std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
-    at::Tensor input,
-    c10::optional<at::Tensor> bs, // batch size
-    c10::optional<at::Tensor> scale_ub); // scale upperbound
-
-std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
+std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     c10::optional<at::Tensor> bs, // batch size
     c10::optional<at::Tensor> scale_ub); // scale upperbound
@@ -102,6 +102,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   // torch.ops.load_library, similar to below for quantize_fp8_per_tensor
   m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor");
 
+  m.def(
+      "f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor");
+
   m.def(
       "f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor");
 
@@ -118,7 +121,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 #endif
 
   m.def(
-      "f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
+      "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
   m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");
   m.impl("per_tensor_quantize_i8", per_tensor_quantize_i8);
   m.def("per_tensor_dynamic_quantize_i8(Tensor X) -> (Tensor, Tensor)");
@@ -135,13 +138,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   // quantize_ops with
   // torch.ops.load_library
   m.def(
-      "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, float)");
+      "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]");
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
-  m.def(
-      "quantize_fp8_per_tensor_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, Tensor)");
-  m.impl(
-      "quantize_fp8_per_tensor_tensor_scale",
-      quantize_fp8_per_tensor_tensor_scale);
   m.def(
       "quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
@@ -165,14 +163,12 @@
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
-  m.impl("f8f8bf16", f8f8bf16);
+  m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
-  m.impl(
-      "quantize_fp8_per_tensor_tensor_scale",
-      quantize_fp8_per_tensor_tensor_scale);
+  m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
 #endif
 }
@@ -201,22 +197,13 @@ at::Tensor f8f8bf16_rowwise_meta(
   return Y;
 }
 
-std::tuple<at::Tensor, double> quantize_fp8_per_tensor_meta(
-    at::Tensor X,
-    c10::optional<at::Tensor> bs,
-    c10::optional<at::Tensor> scale_ub) {
-  auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
-  auto scale = 0.0;
-  return std::tuple<at::Tensor, double>{Y, scale};
-}
-
-std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale_meta(
+std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
     at::Tensor X,
     c10::optional<at::Tensor> bs,
     c10::optional<at::Tensor> scale_ub) {
   auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn));
   auto scale = at::empty({}, X.options().dtype(at::kBFloat16));
-  return std::tuple<at::Tensor, at::Tensor>{Y, scale};
+  return {Y, scale};
 }
 
 at::Tensor f8f8bf16_cublas_meta(
@@ -233,6 +220,17 @@ at::Tensor f8f8bf16_cublas_meta(
 }
 
 at::Tensor f8f8bf16_meta(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor scale,
+    bool use_fast_accum = true) {
+  const at::SymInt M = X.sym_size(0);
+  const at::SymInt N = W.sym_size(0);
+  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
+at::Tensor f8f8bf16_tensorwise_meta(
     at::Tensor X,
     at::Tensor W,
     double scale,
@@ -256,14 +254,12 @@ at::Tensor f8i4bf16_rowwise_meta(
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
-  m.impl("f8f8bf16", f8f8bf16_meta);
+  m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16_meta);
   m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta);
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
-  m.impl(
-      "quantize_fp8_per_tensor_tensor_scale",
-      quantize_fp8_per_tensor_tensor_scale_meta);
+  m.impl("f8f8bf16", f8f8bf16_meta);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
 #endif
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
index aad3d9bb02..d023936986 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
@@ -680,7 +680,7 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale(
 
 // TODO: Extend to support other data types for other
 // usecases/models when needed
-std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
+std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     c10::optional<at::Tensor> bs, // batch size
    c10::optional<at::Tensor> scale_ub) // scale upperbound)
@@ -756,89 +756,7 @@ std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
         input.size(-1),
         stream);
   }
-  float scales_host;
-  C10_CUDA_CHECK(cudaMemcpyAsync(
-      &scales_host, scales.data_ptr(), sizeof(float), cudaMemcpyDeviceToHost));
-  return std::tuple<at::Tensor, double>{quantized_input, scales_host};
-}
-
-std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
-    at::Tensor input,
-    c10::optional<at::Tensor> bs, // batch size
-    c10::optional<at::Tensor> scale_ub) // scale upperbound)
-{
-  CUDA_DEVICE_GUARD(input);
-  TORCH_CHECK(input.numel() != 0, "input should not be empty tensor");
-  TORCH_CHECK(
-      input.dim() >= 2,
-      "Invalid dim. The dim of input should be greater than or equal to 2");
-  auto _st = input.scalar_type();
-  TORCH_CHECK(_st == torch::kBFloat16, "Invalid datatype. input must be BF16");
-  std::vector<int64_t> quantized_input_shape;
-  quantized_input_shape.reserve(input.dim());
-  for (int i = 0; i < input.dim(); i++) {
-    quantized_input_shape.push_back(input.size(i));
-  }
-  std::vector<int64_t> scale_shape = {1};
-  input = input.cuda();
-  at::Tensor quantized_input = torch::empty(
-      quantized_input_shape,
-      torch::dtype(torch::kFloat8_e4m3fn)
-          .device(torch::kCUDA, at::cuda::current_device())
-          .requires_grad(false));
-  at::Tensor scales = torch::empty(
-      scale_shape,
-      torch::dtype(torch::kFloat32)
-          .device(torch::kCUDA, at::cuda::current_device())
-          .requires_grad(false));
-  auto* const quantized_input_ptr =
-      reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr());
-  const auto stream = at::cuda::getCurrentCUDAStream();
-  if (bs.has_value()) {
-    int64_t total_elements_per_slice = quantized_input_shape[0];
-    for (int i = 1; i < input.dim() - 1; i++) {
-      total_elements_per_slice =
-          total_elements_per_slice * quantized_input_shape[i];
-    }
-    invokeComputeScale(
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        total_elements_per_slice,
-        reinterpret_cast<int64_t*>(bs.value().data_ptr()),
-        scale_ub.has_value()
-            ? reinterpret_cast<float*>(scale_ub.value().data_ptr())
-            : nullptr,
-        stream);
-    invokeQuantizeMatrix(
-        quantized_input_ptr,
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        stream);
-  } else {
-    invokeComputeScale(
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        -1,
-        nullptr,
-        scale_ub.has_value()
-            ? reinterpret_cast<float*>(scale_ub.value().data_ptr())
-            : nullptr,
-        stream);
-    invokeQuantizeMatrix(
-        quantized_input_ptr,
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        stream);
-  }
-  return std::tuple<at::Tensor, at::Tensor>{quantized_input, scales};
+  return std::vector<at::Tensor>{quantized_input, scales};
 }
 
 template <typename T>
@@ -1094,15 +1012,7 @@ std::vector<at::Tensor> quantize_fp8_per_col(
 }
 
 #else
-std::tuple<at::Tensor, double> quantize_fp8_per_tensor(
-    at::Tensor input,
-    c10::optional<at::Tensor> bs, // batch size
-    c10::optional<at::Tensor> scale_ub) { // scale upperbound
-  throw std::runtime_error(
-      "CUDA version is older than 12.0"); // requires CUDA>=12
-}
-
-std::tuple<at::Tensor, at::Tensor> quantize_fp8_per_tensor_tensor_scale(
+std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     c10::optional<at::Tensor> bs, // batch size
     c10::optional<at::Tensor> scale_ub) { // scale upperbound
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index be82947f9f..6274543363 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -88,9 +88,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         wq = (w * fp8_max / w_max).to(fp8_e4m3)
 
         if kernel == "cutlass":
-            zq = torch.ops.fbgemm.f8f8bf16(
-                xq, wq, (x_scale * w_scale).item(), use_fast_accum
-            )
+            zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale, use_fast_accum)
         else:
             zq = torch.ops.fbgemm.f8f8bf16_cublas(
                 xq, wq, x_scale, w_scale, use_fast_accum
@@ -112,7 +110,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         B_T=st.sampled_from([2048, 4096]),
         D=st.sampled_from([128, 256]),
         HD_L=st.sampled_from([256, 512]),
-        Mode=st.sampled_from(["tensorwise", "rowwise"]),
+        Mode=st.sampled_from(["tensorwise", "tensorwise_broadcast", "rowwise"]),
         QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]),
         Bias=st.sampled_from([True, False]),
     )
@@ -133,6 +131,14 @@ def test_quantize_fp8_matmul(
             zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
             if bias is not None:
                 zq += bias
+        elif Mode == "tensorwise_broadcast":
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+            zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
+                xq, wq, (x_scale * w_scale).item()
+            )
+            if bias is not None:
+                zq += bias
         elif Mode == "rowwise":
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
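For reference, a minimal usage sketch of the reorganized ops after this patch, based on the schemas and tests above. It assumes the fbgemm quantize ops are already registered (e.g. via torch.ops.load_library, as referenced in quantize.cpp); x and w are hypothetical 2D bf16 CUDA tensors, not names from the patch.

    import torch

    # quantize_fp8_per_tensor now returns a list: [fp8_tensor, fp32_scale_tensor]
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)

    # f8f8bf16 takes the combined scale as a tensor, so no device-to-host copy
    # of the scale is needed before the GEMM.
    y = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale, True)

    # f8f8bf16_tensorwise keeps the old float-scale interface; .item() forces
    # the scale back to the host, as in the "tensorwise_broadcast" test branch.
    y_ref = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, (x_scale * w_scale).item())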