From 6fcdd4b2e085ff06c3357d5cfd37d0e41913e08d Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Fri, 17 May 2024 22:20:57 -0700 Subject: [PATCH] Back out "Support FP8 scale calculation with scalar and cleanup" Summary: There is garbage output issue because the new introduced Sm90ScalarBroadcast requires scalar scale that causes an issue under cudagraph. Reverting the changes in D57367680 resolves the issue. Will follow up on adding more unittests to cover similar issues and will test E2E before make those changes. Differential Revision: D57521470 --- .../experimental/gen_ai/bench/ck_fp8_bench.py | 2 +- .../gen_ai/src/quantize/cutlass_extensions.cu | 204 +++++++++++++++++- .../gen_ai/src/quantize/quantize.cpp | 60 +++--- .../gen_ai/src/quantize/quantize.cu | 96 +-------- .../gen_ai/test/quantize/quantize_test.py | 14 +- 5 files changed, 241 insertions(+), 135 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py index 37b2466486..66b75b8233 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py @@ -69,7 +69,7 @@ class CKMatmul(torch.nn.Module): def forward( self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor ) -> torch.Tensor: - out = torch.ops.fbgemm.f8f8bf16(a, b, scale) + out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale) return out diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu index 367688f2a3..d973d28ace 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu @@ -601,9 +601,200 @@ template < int TBS_M, int TBS_N, int TBS_K, - bool PONG, bool FAST_ACCUM> at::Tensor f8f8bf16_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor scale) { + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = cutlass::float_e4m3_t; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = + 128 / + cutlass::sizeof_bits< + ElementInputA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = + 128 / + cutlass::sizeof_bits< + ElementInputB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = + 128 / + cutlass::sizeof_bits< + ElementOutput>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + using MainLoopSchedule = cute::conditional_t< + FAST_ACCUM, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAuto, + MainLoopSchedule>::CollectiveOp; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{})); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{})); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{})); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K}, + {reinterpret_cast(WQ.data_ptr()), + stride_b, + reinterpret_cast(XQ.data_ptr()), + stride_a}, + {{scale.data_ptr(), 0}, + (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +at::Tensor f8f8bf16( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor scale, + bool use_fast_accum) { + auto M = XQ.size(0); + // auto K = XQ.size(1); + // auto N = WQ.size(0); + if (use_fast_accum) { + if (M <= 128) { + return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true>(XQ, WQ, scale); + } else { + return f8f8bf16_impl<128, 128, 128, 1, 2, 1, true>(XQ, WQ, scale); + } + } else { + if (M <= 128) { + return f8f8bf16_impl<64, 128, 128, 2, 1, 1, false>(XQ, WQ, scale); + } else { + return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false>(XQ, WQ, scale); + } + } +} + +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM> +at::Tensor f8f8bf16_tensorwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 double scale) { @@ -787,18 +978,21 @@ at::Tensor f8f8bf16_impl( return Y; } -at::Tensor f8f8bf16( +at::Tensor f8f8bf16_tensorwise( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 double scale, bool use_fast_accum) { KernelMode kernel = get_kernel_mode(XQ, WQ); if (kernel == KernelMode::Small) { - return f8f8bf16_impl<64, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>( + XQ, WQ, scale); } else if (kernel == KernelMode::Large) { - return f8f8bf16_impl<128, 128, 128, 2, 1, 1, true, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>( + XQ, WQ, scale); } else { - return f8f8bf16_impl<128, 128, 128, 1, 2, 1, false, true>(XQ, WQ, scale); + return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>( + XQ, WQ, scale); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 70bd3a388f..7c1d80b3b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -34,6 +34,11 @@ at::Tensor silu_mul_quantize_i8(at::Tensor X1, at::Tensor X2, double scale); // Cutlass kernel at::Tensor f8f8bf16( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor scale, + bool use_fast_accum = true); +at::Tensor f8f8bf16_tensorwise( at::Tensor XQ, at::Tensor WQ, double scale, @@ -62,12 +67,7 @@ at::Tensor f8i4bf16_rowwise( at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale); std::tuple per_tensor_dynamic_quantize_i8(at::Tensor X); -std::tuple quantize_fp8_per_tensor( - at::Tensor input, - c10::optional bs, // batch size - c10::optional scale_ub); // scale upperbound - -std::tuple quantize_fp8_per_tensor_tensor_scale( +std::vector quantize_fp8_per_tensor( at::Tensor input, c10::optional bs, // batch size c10::optional scale_ub); // scale upperbound @@ -102,6 +102,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // torch.ops.load_library, similar to below for quantize_fp8_per_tensor m.def("i8i8bf16(Tensor XQ, Tensor WQ, float scale, int split_k=1) -> Tensor"); + m.def( + "f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor"); + m.def( "f8f8bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True) -> Tensor"); @@ -118,7 +121,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { #endif m.def( - "f8f8bf16(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor"); + "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor"); m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor"); m.impl("per_tensor_quantize_i8", per_tensor_quantize_i8); m.def("per_tensor_dynamic_quantize_i8(Tensor X) -> (Tensor, Tensor)"); @@ -135,13 +138,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // quantize_ops with // torch.ops.load_library m.def( - "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, float)"); + "quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor[]"); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); - m.def( - "quantize_fp8_per_tensor_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> (Tensor, Tensor)"); - m.impl( - "quantize_fp8_per_tensor_tensor_scale", - quantize_fp8_per_tensor_tensor_scale); m.def( "quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]"); m.impl("quantize_fp8_per_row", quantize_fp8_per_row); @@ -165,14 +163,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { } TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { - m.impl("f8f8bf16", f8f8bf16); + m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise); #ifndef USE_ROCM m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor); - m.impl( - "quantize_fp8_per_tensor_tensor_scale", - quantize_fp8_per_tensor_tensor_scale); + m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); #endif } @@ -201,22 +197,13 @@ at::Tensor f8f8bf16_rowwise_meta( return Y; } -std::tuple quantize_fp8_per_tensor_meta( - at::Tensor X, - c10::optional bs, - c10::optional scale_ub) { - auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn)); - auto scale = 0.0; - return std::tuple{Y, scale}; -} - -std::tuple quantize_fp8_per_tensor_tensor_scale_meta( +std::vector quantize_fp8_per_tensor_meta( at::Tensor X, c10::optional bs, c10::optional scale_ub) { auto Y = at::empty_like(X, X.options().dtype(at::kFloat8_e4m3fn)); auto scale = at::empty({}, X.options().dtype(at::kBFloat16)); - return std::tuple{Y, scale}; + return {Y, scale}; } at::Tensor f8f8bf16_cublas_meta( @@ -233,6 +220,17 @@ at::Tensor f8f8bf16_cublas_meta( } at::Tensor f8f8bf16_meta( + at::Tensor X, + at::Tensor W, + at::Tensor scale, + bool use_fast_accum = true) { + const at::SymInt M = X.sym_size(0); + const at::SymInt N = W.sym_size(0); + auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16)); + return Y; +} + +at::Tensor f8f8bf16_tensorwise_meta( at::Tensor X, at::Tensor W, double scale, @@ -256,14 +254,12 @@ at::Tensor f8i4bf16_rowwise_meta( } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { - m.impl("f8f8bf16", f8f8bf16_meta); + m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta); #ifndef USE_ROCM m.impl("i8i8bf16", i8i8bf16_meta); m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta); m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta); - m.impl( - "quantize_fp8_per_tensor_tensor_scale", - quantize_fp8_per_tensor_tensor_scale_meta); + m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu index aad3d9bb02..d023936986 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu @@ -680,7 +680,7 @@ at::Tensor quantize_fp8_per_tensor_fixed_scale( // TODO: Extend to support other data types for other // usecases/models when needed -std::tuple quantize_fp8_per_tensor( +std::vector quantize_fp8_per_tensor( at::Tensor input, c10::optional bs, // batch size c10::optional scale_ub) // scale upperbound) @@ -756,89 +756,7 @@ std::tuple quantize_fp8_per_tensor( input.size(-1), stream); } - float scales_host; - C10_CUDA_CHECK(cudaMemcpyAsync( - &scales_host, scales.data_ptr(), sizeof(float), cudaMemcpyDeviceToHost)); - return std::tuple{quantized_input, scales_host}; -} - -std::tuple quantize_fp8_per_tensor_tensor_scale( - at::Tensor input, - c10::optional bs, // batch size - c10::optional scale_ub) // scale upperbound) -{ - CUDA_DEVICE_GUARD(input); - TORCH_CHECK(input.numel() != 0, "input should not be empty tensor"); - TORCH_CHECK( - input.dim() >= 2, - "Invalid dim. The dim of input should be greater than or equal to 2"); - auto _st = input.scalar_type(); - TORCH_CHECK(_st == torch::kBFloat16, "Invalid datatype. input must be BF16"); - std::vector quantized_input_shape; - quantized_input_shape.reserve(input.dim()); - for (int i = 0; i < input.dim(); i++) { - quantized_input_shape.push_back(input.size(i)); - } - std::vector scale_shape = {1}; - input = input.cuda(); - at::Tensor quantized_input = torch::empty( - quantized_input_shape, - torch::dtype(torch::kFloat8_e4m3fn) - .device(torch::kCUDA, at::cuda::current_device()) - .requires_grad(false)); - at::Tensor scales = torch::empty( - scale_shape, - torch::dtype(torch::kFloat32) - .device(torch::kCUDA, at::cuda::current_device()) - .requires_grad(false)); - auto* const quantized_input_ptr = - reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr()); - const auto stream = at::cuda::getCurrentCUDAStream(); - if (bs.has_value()) { - int64_t total_elements_per_slice = quantized_input_shape[0]; - for (int i = 1; i < input.dim() - 1; i++) { - total_elements_per_slice = - total_elements_per_slice * quantized_input_shape[i]; - } - invokeComputeScale( - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(input.data_ptr()), - input.numel(), - input.size(-1), - total_elements_per_slice, - reinterpret_cast(bs.value().data_ptr()), - scale_ub.has_value() - ? reinterpret_cast(scale_ub.value().data_ptr()) - : nullptr, - stream); - invokeQuantizeMatrix( - quantized_input_ptr, - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(input.data_ptr()), - input.numel(), - input.size(-1), - stream); - } else { - invokeComputeScale( - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(input.data_ptr()), - input.numel(), - input.size(-1), - -1, - nullptr, - scale_ub.has_value() - ? reinterpret_cast(scale_ub.value().data_ptr()) - : nullptr, - stream); - invokeQuantizeMatrix( - quantized_input_ptr, - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(input.data_ptr()), - input.numel(), - input.size(-1), - stream); - } - return std::tuple{quantized_input, scales}; + return std::vector{quantized_input, scales}; } template @@ -1094,15 +1012,7 @@ std::vector quantize_fp8_per_col( } #else -std::tuple quantize_fp8_per_tensor( - at::Tensor input, - c10::optional bs, // batch size - c10::optional scale_ub) { // scale upperbound - throw std::runtime_error( - "CUDA version is older than 12.0"); // requires CUDA>=12 -} - -std::tuple quantize_fp8_per_tensor_tensor_scale( +std::vector quantize_fp8_per_tensor( at::Tensor input, c10::optional bs, // batch size c10::optional scale_ub) { // scale upperbound diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index be82947f9f..6274543363 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -88,9 +88,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None: wq = (w * fp8_max / w_max).to(fp8_e4m3) if kernel == "cutlass": - zq = torch.ops.fbgemm.f8f8bf16( - xq, wq, (x_scale * w_scale).item(), use_fast_accum - ) + zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale, use_fast_accum) else: zq = torch.ops.fbgemm.f8f8bf16_cublas( xq, wq, x_scale, w_scale, use_fast_accum @@ -112,7 +110,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None: B_T=st.sampled_from([2048, 4096]), D=st.sampled_from([128, 256]), HD_L=st.sampled_from([256, 512]), - Mode=st.sampled_from(["tensorwise", "rowwise"]), + Mode=st.sampled_from(["tensorwise", "tensorwise_broadcast", "rowwise"]), QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]), Bias=st.sampled_from([True, False]), ) @@ -133,6 +131,14 @@ def test_quantize_fp8_matmul( zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale) if bias is not None: zq += bias + elif Mode == "tensorwise_broadcast": + xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x) + wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w) + zq = torch.ops.fbgemm.f8f8bf16_tensorwise( + xq, wq, (x_scale * w_scale).item() + ) + if bias is not None: + zq += bias elif Mode == "rowwise": xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType) wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)