Skip to content

Commit

Permalink
Fix CK Profiler Build and Tune Small CK FP8 Shapes (pytorch#3017)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3017

X-link: facebookresearch/FBGEMM#113

A recent bump to CK broke the profiler build, but excluding the problematic targets resolves the issue.

I also snuck in two improvements to the CK shape dispatch, the most significant of which doubles the performance for [64, 1280, 8192], which may be impactful for Llama70B.

Reviewed By: jianyuh

Differential Revision: D61558684

fbshipit-source-id: c4865c8a04ee14bd9fb9e81188cb69f2989b5da0
  • Loading branch information
jwfromm authored and facebook-github-bot committed Aug 21, 2024
1 parent 2e190b4 commit 1c8ae9d
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static const std::unordered_map<
{{32, 1280, 8192},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{64, 1280, 8192},
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{128, 1280, 8192},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
// Support for decode across batch sizes for [8192, 1024]
Expand All @@ -60,7 +60,7 @@ static const std::unordered_map<
{{32, 8192, 1024},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{64, 8192, 1024},
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{128, 8192, 1024},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
// Support for decode across batch sizes for [7168, 8192]
Expand Down

0 comments on commit 1c8ae9d

Please sign in to comment.