From 6305588af76eeec987762c5b5ee373a61f8a7fb3 Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Sat, 15 Jul 2023 20:34:30 -0700
Subject: [PATCH] Have softmax handle case where all dims after reduction dim are 1 (#831)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/831

Memory-layout-wise, a tensor with shape [2, 1] is identical to that with shape [2],
so the current implementation of softmax can handle it.

Reviewed By: aakhundov

Differential Revision: D47463543

fbshipit-source-id: 7799825c6c3b20f6412f2022e08d4ae82ead457e
---
 python/aitemplate/backend/cuda/softmax/softmax.py | 4 ----
 python/aitemplate/compiler/ops/softmax/softmax.py | 6 ++++--
 tests/unittest/ops/test_softmax.py                | 2 ++
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/aitemplate/backend/cuda/softmax/softmax.py b/python/aitemplate/backend/cuda/softmax/softmax.py
index ad8d493ab..168cb21a7 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.py
+++ b/python/aitemplate/backend/cuda/softmax/softmax.py
@@ -182,10 +182,6 @@ def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
     shapes = func_attrs["inputs"][0]._attrs["shape"]
     rank = len(shapes)
 
-    assert (
-        dim == rank - 1
-    ), f"softmax only supports dim == rank - 1, dim={dim}, rank={rank}"
-
     assert isinstance(
         shapes[dim], IntImm
     ), "softmax requires reduction dim to be static"
diff --git a/python/aitemplate/compiler/ops/softmax/softmax.py b/python/aitemplate/compiler/ops/softmax/softmax.py
index 5fb63112d..c6137537f 100644
--- a/python/aitemplate/compiler/ops/softmax/softmax.py
+++ b/python/aitemplate/compiler/ops/softmax/softmax.py
@@ -31,6 +31,7 @@ from aitemplate.compiler.base import (
     DynamicProfileStrategy,
     ExecItem,
+    IntImm,
     IntVar,
     Operator,
     Tensor,
@@ -203,9 +204,10 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
                 "flattening input tensor before normalization is not supported yet"
             )
         dim = wrap_dim(dim, x._rank())
-        if dim != x._rank() - 1:
+        tail_shapes = x.shape()[dim + 1 :]
+        if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
             raise NotImplementedError(
-                f"softmax currently only supports dim=x._rank() - 1, dim={dim}, x._rank()={x._rank()}"
+                f"softmax only supports tensors where all shapes after dim are 1, {dim=}, {x.shape()=}"
             )
 
         self._attrs["inputs"] = [x]
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index 7f31604e9..b206bc85c 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -68,6 +68,7 @@ def _test_softmax(
             {
                 TestEnv.CUDA_LESS_THAN_SM80: [
                     ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
+                    ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
                     ("odd_small_fp16", "float16", (1, 13), (11,)),
                     ("odd_mid_fp16", "float16", (1, 4096), (33,)),
                     ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -100,6 +101,7 @@
                 ],
                 TestEnv.CUDA_SM80: [
                     ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
+                    ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
                     ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
                     ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
                     ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
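
Illustrative note (not part of the patch): the memory-layout equivalence the summary
relies on can be sanity-checked with a short sketch. PyTorch and the concrete shapes
below are assumptions chosen to mirror the new "tail_shapes_all_1" test cases, not
taken from the change itself.

    import torch

    # Shape (2, 6, 1, 1): every dim after the reduction dim (dim=1) has size 1.
    x = torch.randn(2, 6, 1, 1)

    # Softmax over dim=1, where all trailing dims are 1 ...
    out = torch.nn.functional.softmax(x, dim=1)

    # ... matches softmax over the last dim of the squeezed view, because the
    # two views share the same contiguous memory layout.
    ref = torch.nn.functional.softmax(x.reshape(2, 6), dim=-1).reshape(2, 6, 1, 1)

    assert torch.allclose(out, ref)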