Implement SDPA kernel wrapper to use run_kernel flow for perf (pytorch#2820)

Summary:
X-link: facebookresearch/FBGEMM#21

Pull Request resolved: pytorch#2820

Implements a custom kernel wrapper around scaled_dot_product_attention to directly use the scaled_dot_product_attention athena kernel implementation.

Reviewed By: jiyuanzFB

Differential Revision: D59565188

fbshipit-source-id: df79e84576e0d3163d8b97b9eaaf5f50eb1d4071
Adhitha Dias authored and facebook-github-bot committed Jul 11, 2024
1 parent eb97848 commit 8cf6133
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp
@@ -11,6 +11,8 @@
#include <torch/library.h>
#include "torch/types.h"

// TODO move these ops to fb directory

static at::Tensor qlinear_channelwise(
at::Tensor x,
at::Tensor weight,
@@ -169,6 +171,14 @@ static at::Tensor qlinear_dynamic(
x, weight, bias, input_scale, weight_scale, weight_zero_point, relu);
}

static at::Tensor fused_scaled_dot_product_attention(
at::Tensor query,
at::Tensor key,
at::Tensor value) {
// Call the existing CPU op here to produce CPU results.
return at::scaled_dot_product_attention(query, key, value);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"qlinear_channelwise(Tensor x, Tensor weight, Tensor "
@@ -204,4 +214,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.impl(
"qlinear_dynamic",
torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(qlinear_dynamic)));

m.def(
"fused_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value) -> Tensor");

m.impl(
"fused_scaled_dot_product_attention",
torch::dispatch(
c10::DispatchKey::CPU, TORCH_FN(fused_scaled_dot_product_attention)));
}
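
For context, once this library fragment is loaded, the registered fbgemm::fused_scaled_dot_product_attention op is reachable through the PyTorch dispatcher. Below is a minimal C++ caller sketch (not part of this commit; the helper name call_fused_sdpa is illustrative):

// Hypothetical caller sketch: look up the op registered above via the
// PyTorch dispatcher and invoke it with the schema declared in m.def.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

static at::Tensor call_fused_sdpa(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::fused_scaled_dot_product_attention", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, at::Tensor)>();
  // CPU tensors hit the CPU impl registered in this file; other backends
  // would need their own m.impl registrations.
  return op.call(query, key, value);
}

From Python, the same op would be callable as torch.ops.fbgemm.fused_scaled_dot_product_attention(query, key, value) once the extension is loaded.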
