From f6ecbed00c991d7132378ced271ffbcc0c373d9d Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Mon, 17 Jul 2023 07:49:13 -0700
Subject: [PATCH] Have softmax handle any dim via permutation

Summary: Since the backend doesn't yet support reduction over non-last
dimensions, I'm having the frontend add two permute ops to move the desired
dimension to the end of the shape.

I considered modifying the softmax template code instead of creating new op
instances, but the `permute()` op has a bunch of logic for dispatching to more
specialized versions of itself, and it seems the easiest way to take advantage
of that is to create instances of the permute op.

Differential Revision: D47506000

fbshipit-source-id: 10c8685c22dda72be4e2671b4c1369ac80062f1f
---
 python/aitemplate/compiler/ops/softmax/softmax.py | 10 +++++++---
 tests/unittest/ops/test_softmax.py                |  2 ++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/aitemplate/compiler/ops/softmax/softmax.py b/python/aitemplate/compiler/ops/softmax/softmax.py
index c6137537f..4c82dd1a9 100644
--- a/python/aitemplate/compiler/ops/softmax/softmax.py
+++ b/python/aitemplate/compiler/ops/softmax/softmax.py
@@ -37,6 +37,7 @@
     Tensor,
 )
 from aitemplate.compiler.ops.softmax.cache_entry import NormQueryEntry, NormRecordEntry
+from aitemplate.compiler.ops.tensor.permute import permute
 
 from aitemplate.testing import detect_target
 
@@ -205,10 +206,13 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
             )
         dim = wrap_dim(dim, x._rank())
         tail_shapes = x.shape()[dim + 1 :]
+        # The backend only supports reduction over the last non-1 dimension, so if we want
+        # to reduce over other dimensions we have to permute the tensor first.
         if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
-            raise NotImplementedError(
-                f"softmax only supports tensors where all shapes after dim are 1, {dim=}, {x.shape()=}"
-            )
+            perm_shape = list(range(x._rank()))
+            perm_shape[dim] = x._rank() - 1
+            perm_shape[-1] = dim
+            return permute()(softmax()(permute()(x, perm_shape), dim=-1), perm_shape)
 
         self._attrs["inputs"] = [x]
         self._attrs["dim"] = dim
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index b206bc85c..f99f1d41b 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -69,6 +69,7 @@ def _test_softmax(
             TestEnv.CUDA_LESS_THAN_SM80: [
                 ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
                 ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
+                ("tail_shapes_not_all_1_fp16", "float16", (1, 2), (6, 1, 2), 1),
                 ("odd_small_fp16", "float16", (1, 13), (11,)),
                 ("odd_mid_fp16", "float16", (1, 4096), (33,)),
                 ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -102,6 +103,7 @@
             TestEnv.CUDA_SM80: [
                 ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
                 ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
+                ("tail_shapes_not_all_1_bf16", "bfloat16", (1, 2), (6, 1, 2), 1),
                 ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
                 ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
                 ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
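
For readers unfamiliar with the trick used in `__call__` above, here is a minimal
standalone sketch (plain NumPy, not part of the patch; the helper names
`softmax_lastdim` and `softmax_any_dim` are made up for illustration) of why a
single swap permutation applied before and after a last-dimension softmax
reproduces softmax over an arbitrary `dim`: swapping `dim` with the last axis is
its own inverse, so reusing the identical permutation restores the original layout.

    import numpy as np

    def softmax_lastdim(x: np.ndarray) -> np.ndarray:
        # Numerically stable softmax over the last axis only, mirroring the
        # case the backend kernel handles directly.
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def softmax_any_dim(x: np.ndarray, dim: int) -> np.ndarray:
        # Build the same permutation the patch builds: identity except that
        # `dim` and the last axis are swapped.
        perm = list(range(x.ndim))
        perm[dim] = x.ndim - 1
        perm[-1] = dim
        # permute -> softmax over the last axis -> permute back with the same
        # perm (a swap is its own inverse).
        return np.transpose(softmax_lastdim(np.transpose(x, perm)), perm)

    # Shape chosen to mimic the new tail_shapes_not_all_1 test case.
    x = np.random.rand(2, 6, 1, 2).astype(np.float32)
    ref = np.exp(x - x.max(axis=1, keepdims=True))
    ref = ref / ref.sum(axis=1, keepdims=True)
    assert np.allclose(softmax_any_dim(x, dim=1), ref, atol=1e-6)

The same reasoning explains why the patch can pass `perm_shape` to both permute
calls instead of computing a separate inverse permutation.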