From 012a658e91cdeeb994a07fd8f5309756726e0ebd Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 23 Sep 2024 17:18:10 -0700 Subject: [PATCH] Hardcoded tuning for FBGEMM/fp8 for 70B/405B Prefill with T=1..8K (#3164) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3164 X-link: https://github.com/facebookresearch/FBGEMM/pull/258 Explicitly add instances for 405/70B Prefill with T=1024,2048,4096,8192 Reviewed By: jwfromm Differential Revision: D63234720 fbshipit-source-id: 5602cbb5428e7515dbc64da58dafca7fed2a26c4 --- .../ck_extensions/fp8_rowwise_gemm.hip | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index 3450d2cad..b076e281e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -72,8 +72,14 @@ static const std::unordered_map< fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2}, {{128, 7168, 8192}, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{1024, 7168, 8192}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5}, {{2048, 7168, 8192}, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{4096, 7168, 8192}, + fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{8192, 7168, 8192}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, // Support for decode across batch sizes for [8192, 3584] {{16, 8192, 3584}, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2}, @@ -83,6 +89,14 @@ static const std::unordered_map< fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2}, {{128, 8192, 3584}, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{1024, 8192, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{2048, 8192, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{4096, 8192, 3584}, + fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{8192, 8192, 3584}, + fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, // Llama 405B Decode Shapes. // Support for decode across batch sizes for [13312, 6656]. {{16, 13312, 6656}, @@ -102,8 +116,14 @@ static const std::unordered_map< fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{128, 13312, 16384}, fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, - {{2048, 13312, 16384}, + {{1024, 13312, 16384}, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{2048, 13312, 16384}, + fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{4096, 13312, 16384}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{8192, 13312, 16384}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, // Support for decode across batch sizes for [16384, 6656]. {{16, 16384, 6656}, fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1}, @@ -113,8 +133,14 @@ static const std::unordered_map< fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{128, 16384, 6656}, fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{1024, 16384, 6656}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, {{2048, 16384, 6656}, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{4096, 16384, 6656}, + fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{8192, 16384, 6656}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, // Support for decode across batch sizes for [16384, 16384]. {{16, 16384, 16384}, fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},