Fp8 gemm tweak for 405B Decoding (#3104)
Fp8 gemm tweak for 405B Decoding (#3104)
Summary:
X-link: facebookresearch/FBGEMM#192

Pull Request resolved: #3104

Improve fp8 GEMM for memory-bound cases in 405B decoding:
- Speeds up 405B decoding GEMMs:
  - [1, 13312, 16384] from 74us to 71us
  - [1, 16384, 6656] from 34us to 31us
- Adjusts tuning parameters to improve memory-load performance
- Adjusts the pipeline to overlap scaling and GEMM
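For context on why decode shapes like [1, 13312, 16384] are memory bound, here is a back-of-the-envelope arithmetic-intensity estimate. This is a hypothetical helper, not part of FBGEMM; it assumes 1-byte fp8 inputs and a 2-byte bf16 output, and ignores the rowwise scale vectors and caches.

```cpp
#include <cassert>
#include <cstdint>

// Rough FLOP-per-byte estimate for C[M,N] = A[M,K] * B[N,K] with fp8 inputs
// (1 byte/elem) and a bf16 output (2 bytes/elem). Scale vectors and cache
// effects are ignored; this only illustrates the memory-bound regime at M = 1.
double gemm_arithmetic_intensity(int64_t M, int64_t N, int64_t K) {
  double flops = 2.0 * M * N * K;  // one multiply + one add per MAC
  double bytes = 1.0 * M * K       // read A (fp8)
               + 1.0 * N * K       // read B (fp8)
               + 2.0 * M * N;      // write C (bf16)
  return flops / bytes;
}
```

At M = 1 the intensity is roughly 2 FLOPs/byte regardless of N and K, far below the roofline crossover of modern accelerators, so runtime is dominated by streaming the weight matrix. That is consistent with this diff targeting load parameters rather than math throughput.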

Reviewed By: jianyuh, jwfromm

Differential Revision: D62363038

fbshipit-source-id: 4beea7eb8c66605539f9dd675a61ba3fe8870a3f
zjing14 authored and facebook-github-bot committed Sep 11, 2024
1 parent a2b62a0 commit 04d3c58
Showing 4 changed files with 51 additions and 3 deletions.
@@ -95,7 +95,7 @@ static const std::unordered_map<
         fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
     // Support for decode across batch sizes for [13312, 16384].
     {{16, 13312, 16384},
-     fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+     fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},
     {{32, 13312, 16384},
      fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
     {{64, 13312, 16384},
@@ -117,7 +117,7 @@ static const std::unordered_map<
         fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
     // Support for decode across batch sizes for [16384, 16384].
     {{16, 16384, 16384},
-     fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+     fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},
     {{32, 16384, 16384},
      fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
     {{64, 16384, 16384},
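The heuristic table above is keyed on the full (M, N, K) problem shape. A minimal sketch of that dispatch pattern follows; it is illustrative only — the helper name and string payload are hypothetical, FBGEMM's real table maps shapes to kernel function pointers via `std::unordered_map` with a custom hash.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Shape-keyed kernel dispatch, mimicking the heuristic table above.
// std::map over a tuple key is used here for brevity; the real code uses
// std::unordered_map with a custom hash over {M, N, K}.
using Shape = std::tuple<int, int, int>;

std::string pick_kernel(int M, int N, int K) {
  static const std::map<Shape, std::string> kTable = {
      {{16, 13312, 16384},
       "fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2"},
      {{32, 13312, 16384},
       "fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2"},
  };
  auto it = kTable.find({M, N, K});
  return it != kTable.end() ? it->second : "fallback_heuristic";
}
```

Exact-shape lookup like this is why the diff can retune two entries without touching any other shape's kernel choice.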
@@ -32,7 +32,8 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1
       1,
       1,
       ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v1>;
+      ck::BlockGemmPipelineVersion::v1,
+      ck::tensor_operation::device::GemmSpecialization::Default>;
   // Run kernel instance.
   return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
 }
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  // The smallest kernel we have available. Works well for memory bound shapes.
+  using DeviceGemmInstance = DeviceGemmHelper<
+      64,
+      16,
+      16,
+      512,
+      16,
+      16,
+      1,
+      1,
+      S<32, 2, 1>,
+      S<32, 2, 1>,
+      S<1, 16, 1, 4>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Interwave,
+      ck::BlockGemmPipelineVersion::v2,
+      ck::tensor_operation::device::GemmSpecialization::Default>;
+  // Run kernel instance.
+  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+}
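The long kernel names encode the tile and thread-cluster configuration, and the substantive change in this diff is swapping the A/B block-transfer thread clusters from 8x8x1 to 32x2x1. The sketch below shows what such a remapping means for each thread's copy slice. The concrete numbers are my assumptions, not from the source: a 64-thread block with the K-tile split as K0 x K1 and K1 = 8, giving a block copy tile of {K0 = 64, M = 16, K1 = 8}.

```cpp
#include <array>
#include <cassert>

// Each thread in the block copies tile[i] / cluster[i] elements along dim i
// (dims ordered {K0, M, K1}). Both clusters here use 8*8*1 = 32*2*1 = 64
// threads and move the same per-thread volume; only the mapping of threads
// to data changes.
std::array<int, 3> per_thread_tile(std::array<int, 3> tile,
                                   std::array<int, 3> cluster) {
  return {tile[0] / cluster[0], tile[1] / cluster[1], tile[2] / cluster[2]};
}
```

Under these assumptions, the old 8x8x1 cluster gives each thread an {8, 2, 8} slice while 32x2x1 gives {2, 8, 8}: more threads are spread along K0 for the same per-thread volume. The commit message itself only states that the new parameters "improve memory load performance"; the K-spread interpretation is a plausible reading, not a claim from the source.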
@@ -36,6 +36,14 @@ fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2(
     at::Tensor w_scale,
     at::Tensor Y);
 
+at::Tensor
+fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
 // Alternate tiny kernel that seems to do well when M and K are all small.
 at::Tensor
 fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
