Have softmax handle any dim via permutation
Summary:
Since the backend doesn't yet support reduction over non-last dimensions, I'm
having the frontend add two permute ops to move the desired dimension to the
end of the shape.

I considered modifying the softmax template code instead of creating new op
instances, but the `permute()` op has a bunch of logic for dispatching to more
specialized versions of itself, and it seems the easiest way to take advantage
of that is by creating instances of the permute op.
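
To make the trick concrete, here is a minimal standalone sketch of the same
swap-and-restore approach in NumPy. softmax_any_dim is a hypothetical helper
for illustration only; it is not part of this diff, which instead composes
AITemplate's permute and softmax ops as shown below.

    import numpy as np

    def softmax_any_dim(x: np.ndarray, dim: int) -> np.ndarray:
        # Swap `dim` with the last axis, softmax over the last axis, swap back.
        dim = dim % x.ndim  # wrap negative dims
        perm = list(range(x.ndim))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        y = x.transpose(perm)
        y = np.exp(y - y.max(axis=-1, keepdims=True))
        y = y / y.sum(axis=-1, keepdims=True)
        return y.transpose(perm)  # the swap permutation is its own inverse

    x = np.random.rand(2, 3, 4).astype(np.float32)
    ref = np.exp(x - x.max(axis=1, keepdims=True))
    ref /= ref.sum(axis=1, keepdims=True)
    assert np.allclose(softmax_any_dim(x, dim=1), ref)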

Differential Revision: D47506000

fbshipit-source-id: 10c8685c22dda72be4e2671b4c1369ac80062f1f
int3 authored and facebook-github-bot committed Jul 17, 2023
1 parent 6305588 commit f6ecbed
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/aitemplate/compiler/ops/softmax/softmax.py
@@ -37,6 +37,7 @@
     Tensor,
 )
 from aitemplate.compiler.ops.softmax.cache_entry import NormQueryEntry, NormRecordEntry
+from aitemplate.compiler.ops.tensor.permute import permute

 from aitemplate.testing import detect_target

@@ -205,10 +206,13 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
             )
         dim = wrap_dim(dim, x._rank())
         tail_shapes = x.shape()[dim + 1 :]
+        # The backend only supports reduction over the last non-1 dimension, so if we want
+        # to reduce over other dimensions we have to permute the tensor first.
         if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
-            raise NotImplementedError(
-                f"softmax only supports tensors where all shapes after dim are 1, {dim=}, {x.shape()=}"
-            )
+            perm_shape = list(range(x._rank()))
+            perm_shape[dim] = x._rank() - 1
+            perm_shape[-1] = dim
+            return permute()(softmax()(permute()(x, perm_shape), dim=-1), perm_shape)

         self._attrs["inputs"] = [x]
         self._attrs["dim"] = dim
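A note on the hunk above: the same perm_shape is passed to both permute()
calls. That works because swapping two axes is its own inverse, as this
plain-Python check (illustrative only, not part of the diff) confirms:

    rank, dim = 4, 1
    perm_shape = list(range(rank))  # [0, 1, 2, 3]
    perm_shape[dim] = rank - 1
    perm_shape[-1] = dim            # [0, 3, 2, 1]
    # composing the permutation with itself yields the identity
    assert [perm_shape[p] for p in perm_shape] == list(range(rank))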
2 changes: 2 additions & 0 deletions tests/unittest/ops/test_softmax.py
@@ -69,6 +69,7 @@ def _test_softmax(
         TestEnv.CUDA_LESS_THAN_SM80: [
             ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
             ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
+            ("tail_shapes_not_all_1_fp16", "float16", (1, 2), (6, 1, 2), 1),
             ("odd_small_fp16", "float16", (1, 13), (11,)),
             ("odd_mid_fp16", "float16", (1, 4096), (33,)),
             ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -102,6 +103,7 @@ def _test_softmax(
         TestEnv.CUDA_SM80: [
             ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
             ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
+            ("tail_shapes_not_all_1_bf16", "bfloat16", (1, 2), (6, 1, 2), 1),
             ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
             ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
             ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
