diff --git a/python/aitemplate/backend/cuda/softmax/softmax.py b/python/aitemplate/backend/cuda/softmax/softmax.py
index ad8d493ab..168cb21a7 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.py
+++ b/python/aitemplate/backend/cuda/softmax/softmax.py
@@ -182,10 +182,6 @@ def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
     shapes = func_attrs["inputs"][0]._attrs["shape"]
     rank = len(shapes)
 
-    assert (
-        dim == rank - 1
-    ), f"softmax only supports dim == rank - 1, dim={dim}, rank={rank}"
-
     assert isinstance(
         shapes[dim], IntImm
     ), "softmax requires reduction dim to be static"
diff --git a/python/aitemplate/compiler/ops/softmax/softmax.py b/python/aitemplate/compiler/ops/softmax/softmax.py
index 5fb63112d..c6137537f 100644
--- a/python/aitemplate/compiler/ops/softmax/softmax.py
+++ b/python/aitemplate/compiler/ops/softmax/softmax.py
@@ -31,6 +31,7 @@ from aitemplate.compiler.base import (
     DynamicProfileStrategy,
     ExecItem,
+    IntImm,
     IntVar,
     Operator,
     Tensor,
@@ -203,9 +204,10 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
                 "flattening input tensor before normalization is not supported yet"
             )
         dim = wrap_dim(dim, x._rank())
-        if dim != x._rank() - 1:
+        tail_shapes = x.shape()[dim + 1 :]
+        if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
             raise NotImplementedError(
-                f"softmax currently only supports dim=x._rank() - 1, dim={dim}, x._rank()={x._rank()}"
+                f"softmax only supports tensors where all shapes after dim are 1, {dim=}, {x.shape()=}"
             )
 
         self._attrs["inputs"] = [x]
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index 7f31604e9..b206bc85c 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -68,6 +68,7 @@ def _test_softmax(
         {
             TestEnv.CUDA_LESS_THAN_SM80: [
                 ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
+                ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
                 ("odd_small_fp16", "float16", (1, 13), (11,)),
                 ("odd_mid_fp16", "float16", (1, 4096), (33,)),
                 ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -100,6 +101,7 @@ def _test_softmax(
             ],
             TestEnv.CUDA_SM80: [
                 ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
+                ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
                 ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
                 ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
                 ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
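
For context, the diff relaxes the old `dim == rank - 1` restriction: softmax on `dim` is now accepted whenever every dimension after `dim` has static size 1, since the reduction is then still over the innermost non-trivial axis. A minimal standalone sketch of the relaxed check, using plain int tuples instead of AITemplate's `IntImm` shape objects (the helper name `is_valid_softmax_dim` is hypothetical, for illustration only):

```python
# Sketch of the relaxed validity check from softmax.__call__, assuming a
# fully static shape given as a tuple of ints. `is_valid_softmax_dim` is
# a hypothetical helper, not an AITemplate API.
def is_valid_softmax_dim(shape: tuple, dim: int) -> bool:
    # Accept dim if every trailing dimension after it has size 1.
    tail_shapes = shape[dim + 1 :]
    return all(s == 1 for s in tail_shapes)


assert is_valid_softmax_dim((1, 2, 6, 1, 1), dim=2)  # newly accepted: trailing 1s
assert is_valid_softmax_dim((1, 1024, 6), dim=2)     # old behavior: dim == rank - 1
assert not is_valid_softmax_dim((2, 6, 3), dim=1)    # still rejected: tail dim 3 != 1
```

The new test cases (`tail_shapes_all_1_fp16` / `tail_shapes_all_1_bf16`, input shape `(6, 1, 1)` with `dim=1`) exercise exactly this newly accepted case.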