From 11f320d1c636fa30358ba692dea9ee1c23375456 Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Mon, 17 Jul 2023 07:55:06 -0700
Subject: [PATCH] Have softmax handle any dim via permutation (#832)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/832

Since the backend doesn't yet support reduction over non-last-dimensions, I'm
having the frontend add two permute ops to move the desired dimension to the
end of the shape.

I considered modifying the softmax template code instead of creating new op
instances, but the `permute()` op has a bunch of logic for dispatching to more
specialized versions of itself, and it seems the easiest way to take advantage
of that is by creating instances of the permute op.

Reviewed By: aakhundov

Differential Revision: D47506000

fbshipit-source-id: b3522fdbb0154d9caa5a4865653aec1816de6bc9
---
 python/aitemplate/compiler/ops/softmax/softmax.py | 12 +++++++++---
 tests/unittest/ops/test_softmax.py                |  2 ++
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/aitemplate/compiler/ops/softmax/softmax.py b/python/aitemplate/compiler/ops/softmax/softmax.py
index c6137537f..f38dbf576 100644
--- a/python/aitemplate/compiler/ops/softmax/softmax.py
+++ b/python/aitemplate/compiler/ops/softmax/softmax.py
@@ -37,6 +37,7 @@
     Tensor,
 )
 from aitemplate.compiler.ops.softmax.cache_entry import NormQueryEntry, NormRecordEntry
+from aitemplate.compiler.ops.tensor.permute import permute
 from aitemplate.testing import detect_target
 
 
@@ -205,10 +206,15 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
             )
         dim = wrap_dim(dim, x._rank())
         tail_shapes = x.shape()[dim + 1 :]
+        # The backend only supports reduction over the last non-1 dimension, so if we want
+        # to reduce over other dimensions we have to permute the tensor first.
         if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
-            raise NotImplementedError(
-                f"softmax only supports tensors where all shapes after dim are 1, {dim=}, {x.shape()=}"
-            )
+            perm_shape = list(range(x._rank()))
+            perm_shape[dim] = x._rank() - 1
+            perm_shape[-1] = dim
+            x_perm = permute()(x, perm_shape)
+            x_perm_softmax = softmax()(x_perm, dim=-1)
+            return permute()(x_perm_softmax, perm_shape)
 
         self._attrs["inputs"] = [x]
         self._attrs["dim"] = dim
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index b206bc85c..f99f1d41b 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -69,6 +69,7 @@ def _test_softmax(
                 TestEnv.CUDA_LESS_THAN_SM80: [
                     ("dim_1_fp16", "float16", (1, 1024), (6,), 1),
                     ("tail_shapes_all_1_fp16", "float16", (1, 2), (6, 1, 1), 1),
+                    ("tail_shapes_not_all_1_fp16", "float16", (1, 2), (6, 1, 2), 1),
                     ("odd_small_fp16", "float16", (1, 13), (11,)),
                     ("odd_mid_fp16", "float16", (1, 4096), (33,)),
                     ("odd_large_fp16", "float16", (2, 31), (1409,)),
@@ -102,6 +103,7 @@
                 TestEnv.CUDA_SM80: [
                     ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
                     ("tail_shapes_all_1_bf16", "bfloat16", (1, 2), (6, 1, 1), 1),
+                    ("tail_shapes_not_all_1_bf16", "bfloat16", (1, 2), (6, 1, 2), 1),
                     ("odd_small_bf16", "bfloat16", (1, 2), (11,)),
                     ("odd_mid_bf16", "bfloat16", (1, 2), (33,)),
                     ("odd_large_bf16", "bfloat16", (1, 2), (1409,)),
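
For readers who want to sanity-check the permutation trick outside of AITemplate, here is a
minimal standalone sketch in plain PyTorch (the function name softmax_any_dim and the test
shape are illustrative, not part of this patch). It builds the same swap permutation as
perm_shape above, runs softmax over the last dimension, and then reapplies the permutation;
because the swap is a transposition, applying it twice restores the original layout.

    import torch

    def softmax_any_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
        # Wrap negative dims, analogous to wrap_dim() in the patch.
        dim = dim % x.dim()
        # Build a permutation that swaps `dim` with the last axis,
        # mirroring how perm_shape is constructed above.
        perm = list(range(x.dim()))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        # Softmax over the last dim, then apply the same permutation again;
        # a transposition is its own inverse, so this restores the layout.
        return torch.softmax(x.permute(perm), dim=-1).permute(perm)

    x = torch.randn(1, 2, 6, 1, 2)
    assert torch.allclose(softmax_any_dim(x, dim=2), torch.softmax(x, dim=2))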