Skip to content

Commit

Permalink
EMU1.6 CK FP8 Tuning (pytorch#3138)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3138

X-link: facebookresearch/FBGEMM#232

This diff adds a few more tuned shapes to our CK FP8 Kernel dispatch for better EMU1.6 performance. The before / after performance can be seen [here](https://docs.google.com/spreadsheets/d/1SAymyghA8V0ZXD1G7ButMy7GzohYogPISlQ6rHVtVdE/edit?usp=sharing).

The quick summary is that we see small to medium improvements across all EMU shapes. E2E performance should improve a bit but not massively as it seems the heuristics did acceptably in this case.

Reviewed By: mxz297

Differential Revision: D62761000

fbshipit-source-id: 378fabb7cb2e37ffc415b7e464044a266a24ec9c
  • Loading branch information
jwfromm authored and facebook-github-bot committed Sep 16, 2024
1 parent 49fa9a5 commit 9634774
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,33 @@ static const std::unordered_map<
fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
// EMU 1.6 Shapes.
{{1536, 3584, 3584},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{8192, 9728, 3584},
fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 3584, 9728},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
{{8192, 3584, 3584},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 3584, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{768, 3584, 3584},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{4096, 9728, 3584},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{4096, 3584, 9728},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{7200, 3584, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{7200, 9728, 3584},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{7200, 3584, 9728},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{3600, 3584, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{3600, 9728, 3584},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{3600, 3584, 9728},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
// Pro Shapes.
{{32768, 128, 8192},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0);

// This kernel seems optimal in the most purely compute bound tasks.
if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
128,
32,
128,
32,
32,
2,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v2>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
128,
128,
32,
128,
32,
32,
2,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0);

// This kernel seems optimal in the most purely compute bound tasks.
if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
256,
64,
16,
16,
8,
8,
S<4, 64, 1>,
S<4, 64, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
256,
64,
16,
16,
8,
8,
S<4, 64, 1>,
S<4, 64, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,21 @@ fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

// Decent mid-size kernel.
at::Tensor
fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

// Kernel for small but not too small batch sizes.
at::Tensor
fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

0 comments on commit 9634774

Please sign in to comment.