From 04d3c589f1899368d4ccd194aff5ed178cfb537e Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 11 Sep 2024 06:25:56 -0700 Subject: [PATCH] Fp8 gemm tweak for 405B Decoding (#3104) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/192 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3104 Improve fp8 gemm for memory-bound cases from 405B decoding - improved 405B decoding GEMMs: - [1, 13312, 16384] from 74us to 71us - [1, 16384, 6656] from 34us to 31us - Adjust tuning parameters to improve memory load performance - Adjust pipeline for overlapping scaling and gemm Reviewed By: jianyuh, jwfromm Differential Revision: D62363038 fbshipit-source-id: 4beea7eb8c66605539f9dd675a61ba3fe8870a3f --- .../ck_extensions/fp8_rowwise_gemm.hip | 4 +- ...16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip | 3 +- ...32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 39 +++++++++++++++++++ .../kernels/fp8_rowwise_kernel_manifest.h | 8 ++++ 4 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index f96cafb05..afb306432 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -95,7 +95,7 @@ static const std::unordered_map< fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, // Support for decode across batch sizes for [13312, 16384]. 
{{16, 13312, 16384}, - fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2}, + fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2}, {{32, 13312, 16384}, fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2}, {{64, 13312, 16384}, @@ -117,7 +117,7 @@ static const std::unordered_map< fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, // Support for decode across batch sizes for [16384, 16384]. {{16, 16384, 16384}, - fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2}, + fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2}, {{32, 16384, 16384}, fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2}, {{64, 16384, 16384}, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip index be8e814ec..26ac8ab4e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip @@ -32,7 +32,8 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1 1, 1, ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1>; + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..a0fb48f16 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<32, 2, 1>, + S<32, 2, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h index 2b0753017..db38343dc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h @@ -36,6 +36,14 @@ fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( at::Tensor w_scale, at::Tensor Y); +at::Tensor +fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + // Alternate tiny kernel that seems to do well when M and K are all small. at::Tensor fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(