Dedup GQA splitk kernel
Summary: We want to keep the use_tensor_cores = False option for the gqa_attn_splitk function for backward compatibility (NVIDIA GPUs before Hopper, and AMD GPUs).

Reviewed By: sryap

Differential Revision: D56687037

fbshipit-source-id: 0c98fe6327fd063b62d59aaaacd238cacbfb20c5
jianyuh authored and facebook-github-bot committed May 1, 2024
1 parent 0da2f0c commit 0aecd17
Showing 2 changed files with 983 additions and 38 deletions.
fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp (11 changes: 6 additions & 5 deletions)
@@ -12,15 +12,15 @@

 namespace fbgemm_gpu::gen_ai::attention {

-std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_cuda(
+std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(
     const at::Tensor& XQ,
     const at::Tensor& cache_K,
     const at::Tensor& cache_V,
     const at::Tensor& seq_positions,
     const double qk_scale,
     const int64_t num_split_ks,
-    const int64_t num_groups);
+    const int64_t num_int4_kv_groups,
+    const bool use_tensor_cores);
 } // namespace fbgemm_gpu::gen_ai::attention

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -32,7 +32,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor seq_positions, "
       " float qk_scale, "
       " int num_split_ks, "
-      " int num_int4_kv_groups=1"
+      " int num_int4_kv_groups=1, "
+      " bool use_tensor_cores=True"
       ") -> (Tensor, Tensor, Tensor)");
 }

@@ -41,5 +42,5 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
       "gqa_attn_splitk",
       torch::dispatch(
           c10::DispatchKey::CUDA,
-          TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk_cuda)));
+          TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk)));
 }
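
Given the schema registered above, a minimal Python usage sketch follows. Only the argument names, their order, the defaults, and the new use_tensor_cores flag come from this diff; the import path, tensor shapes, dtypes, and the names given to the three returned tensors are illustrative assumptions.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  -- registers the ops (import path assumed)

# Illustrative shapes (assumptions): one query step per batch element,
# a KV cache with MAX_T slots, head dimension D.
B, H, D, MAX_T = 2, 8, 128, 4096
XQ = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)
cache_K = torch.randn(B, MAX_T, 1, D, device="cuda", dtype=torch.bfloat16)
cache_V = torch.randn_like(cache_K)
seq_positions = torch.tensor([1000, 2000], device="cuda", dtype=torch.int32)

# Default path: use_tensor_cores=True per the new schema default.
out, split_max, split_sums = torch.ops.fbgemm.gqa_attn_splitk(
    XQ, cache_K, cache_V, seq_positions,
    qk_scale=1.0 / D**0.5,
    num_split_ks=8,
)

# Backward-compatible path for NVIDIA GPUs before Hopper and for AMD GPUs:
out, split_max, split_sums = torch.ops.fbgemm.gqa_attn_splitk(
    XQ, cache_K, cache_V, seq_positions,
    qk_scale=1.0 / D**0.5,
    num_split_ks=8,
    use_tensor_cores=False,
)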
[Diff of the second changed file did not load.]