Have softmax handle case where all dims after reduction dim are 1
Summary:
Memory-layout-wise, a tensor with shape [2, 1] is identical to one with
shape [2], so the existing softmax implementation can already handle
reductions where every dim after the reduction dim has size 1.
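
For illustration, a minimal numpy sketch of the equivalence (the softmax helper below is just for demonstration, not AITemplate code):

    import numpy as np

    def softmax(x, axis):
        # Numerically stable softmax along the given axis.
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    x = np.random.rand(6, 1, 1).astype(np.float32)

    # Reference: reduce over dim 0 of the [6, 1, 1] tensor.
    ref = softmax(x, axis=0)

    # The trailing size-1 dims do not change the contiguous memory layout,
    # so running the existing "last dim" softmax on the squeezed [6] view
    # produces the same values.
    flat = softmax(x.reshape(6), axis=-1)

    assert np.allclose(ref.reshape(-1), flat)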

Differential Revision: D47463543

fbshipit-source-id: 5523379ede1d508cf7acb3a037e0e6a36f20cecc
int3 authored and facebook-github-bot committed Jul 14, 2023
1 parent 61646f7 commit 186ef6c
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 0 additions & 4 deletions in python/aitemplate/backend/cuda/softmax/softmax.py
@@ -182,10 +182,6 @@ def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
     shapes = func_attrs["inputs"][0]._attrs["shape"]
     rank = len(shapes)
 
-    assert (
-        dim == rank - 1
-    ), f"softmax only supports dim == rank - 1, dim={dim}, rank={rank}"
-
     assert isinstance(
         shapes[dim], IntImm
     ), "softmax requires reduction dim to be static"
6 changes: 4 additions & 2 deletions in python/aitemplate/compiler/ops/softmax/softmax.py
@@ -31,6 +31,7 @@
 from aitemplate.compiler.base import (
     DynamicProfileStrategy,
     ExecItem,
+    IntImm,
     IntVar,
     Operator,
     Tensor,
@@ -203,9 +204,10 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
"flattening input tensor before normalization is not supported yet"
)
dim = wrap_dim(dim, x._rank())
if dim != x._rank() - 1:
tail_shapes = x.shape()[dim + 1 :]
if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
raise NotImplementedError(
f"softmax currently only supports dim=x._rank() - 1, dim={dim}, x._rank()={x._rank()}"
f"softmax only supports tensors where all shapes after dim are 1, {x.shape()=}"
)

self._attrs["inputs"] = [x]
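In effect, the check is relaxed from requiring dim == rank - 1 to requiring that every dim after the reduction dim is statically 1. A standalone sketch of the new rule (plain ints stand in for IntImm values; the helper name is made up for illustration):

    def check_softmax_dim(shape, dim):
        # Every dim after the reduction dim must have static size 1.
        tail_shapes = shape[dim + 1 :]
        if not all(s == 1 for s in tail_shapes):
            raise NotImplementedError(
                f"softmax only supports tensors where all shapes after dim are 1, shape={shape}"
            )

    check_softmax_dim((1, 6, 1, 1), dim=1)  # accepted: every dim after dim 1 is 1
    check_softmax_dim((2, 1024), dim=1)     # accepted: dim is already the last dim
    # check_softmax_dim((2, 3, 4), dim=1)   # would raise: dim 2 has size 4
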
2 changes: 2 additions & 0 deletions in tests/unittest/ops/test_softmax.py
@@ -68,6 +68,7 @@ def _test_softmax(
             {
                 TestEnv.CUDA_LESS_THAN_SM80: [
                     ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
+                    ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
                     ("odd_small_fp16", "float16", (1, 13), (11,)),
                     ("odd_mid_fp16", "float16", (1, 4096), (33,)),
                     ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -100,6 +101,7 @@ def _test_softmax(
                 ],
                 TestEnv.CUDA_SM80: [
                     ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
+                    ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
                     ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
                     ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
                     ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
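The new rows exercise exactly this case: a reduction dim followed only by size-1 dims. A hedged sketch of the frontend call they correspond to (imports follow AITemplate's usual frontend pattern; the batch size and tensor name here are illustrative):

    from aitemplate.compiler import ops
    from aitemplate.frontend import Tensor

    # A [batch, 6, 1, 1] float16 input; softmax over dim=1 is now accepted
    # because every dim after dim 1 is statically 1.
    x = Tensor(shape=[2, 6, 1, 1], dtype="float16", name="x", is_input=True)
    y = ops.softmax()(x, dim=1)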
