Update scaled_mm signature in quantize benchmark. (pytorch#2779)
Summary:
Pull Request resolved: pytorch#2779

D58699985 changed the return type of `torch._scaled_mm`. We need to update our implementation to match in order to avoid errors on main. We should consider adding a test to prevent future API breakages.

Differential Revision: D58981269

fbshipit-source-id: 24891caafbe71bc918f909fae0171b9cfdbda0eb
jwfromm authored and facebook-github-bot committed Jun 25, 2024
1 parent 1210b6f commit 9c4b799
Showing 1 changed file with 1 addition and 1 deletion.
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (1 addition, 1 deletion)

@@ -165,7 +165,7 @@ def quantize(self, x, w):
         return xq, wq, x_scale, w_scale
 
     def compute(self, xq, wq, x_scale, w_scale):
-        output, _ = torch._scaled_mm(
+        output = torch._scaled_mm(
             xq,
             wq,
             bias=None,
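
For context, a minimal sketch (not part of the commit) of the calling convention the updated benchmark relies on: after D58699985, `torch._scaled_mm` returns a single output tensor rather than a tuple, so the old tuple-unpacking form breaks. The shapes, scale values, and the `scale_a`/`scale_b`/`out_dtype` keywords below are illustrative assumptions; only `bias=None` is visible in the diff, and the real call site in quantize_ops.py may differ.

# Minimal sketch of the post-D58699985 convention for torch._scaled_mm.
# Assumes a CUDA device with FP8 support; shapes and scales are placeholders.
import torch

def scaled_mm_returns_single_tensor() -> None:
    M, K, N = 16, 32, 16
    xq = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # The second operand must be column-major, hence the transpose.
    wq = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
    x_scale = torch.tensor(1.0, device="cuda")
    w_scale = torch.tensor(1.0, device="cuda")

    # Old API: output, _ = torch._scaled_mm(...)  # tuple unpacking now fails
    output = torch._scaled_mm(
        xq,
        wq,
        bias=None,
        out_dtype=torch.bfloat16,
        scale_a=x_scale,
        scale_b=w_scale,
    )
    assert isinstance(output, torch.Tensor)  # single tensor, not a tuple

if __name__ == "__main__" and torch.cuda.is_available():
    scaled_mm_returns_single_tensor()

A smoke test along these lines is one way to catch future signature changes in `torch._scaled_mm` early, as the commit summary suggests.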
