diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
index 4b9946872..3450d2cad 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
@@ -126,13 +126,33 @@ static const std::unordered_map<
          fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
         // EMU 1.6 Shapes.
         {{1536, 3584, 3584},
-         fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1},
+         fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
         {{8192, 9728, 3584},
-         fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4},
+         fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
         {{8192, 3584, 9728},
          fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
         {{8192, 3584, 3584},
-         fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1},
+         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+        {{4096, 3584, 3584},
+         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+        {{768, 3584, 3584},
+         fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{4096, 9728, 3584},
+         fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+        {{4096, 3584, 9728},
+         fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+        {{7200, 3584, 3584},
+         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+        {{7200, 9728, 3584},
+         fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+        {{7200, 3584, 9728},
+         fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{3600, 3584, 3584},
+         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+        {{3600, 9728, 3584},
+         fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+        {{3600, 3584, 9728},
+         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
         // Pro Shapes.
         {{32768, 128, 8192},
          fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip
new file mode 100644
index 000000000..acf927c05
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0);
+
+  // This kernel works well for small but not too small batch sizes.
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        128,
+        32,
+        128,
+        32,
+        32,
+        2,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v2>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        128,
+        32,
+        128,
+        32,
+        32,
+        2,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
+}
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip
new file mode 100644
index 000000000..b47c082fd
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0);
+
+  // This kernel seems optimal in the most purely compute bound tasks.
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        256,
+        64,
+        16,
+        16,
+        8,
+        8,
+        S<4, 64, 1>,
+        S<4, 64, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        256,
+        64,
+        16,
+        16,
+        8,
+        8,
+        S<4, 64, 1>,
+        S<4, 64, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
+}
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
index db38343dc..146450ec1 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
@@ -231,3 +231,21 @@ fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y);
+
+// Decent mid-size kernel.
+at::Tensor
+fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
+// Kernel for small but not too small batch sizes.
+at::Tensor
+fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);