From 06c910d52165710e77423e7bc16829cb499af6d7 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Wed, 16 Aug 2023 08:09:56 -0700 Subject: [PATCH] Avoid input shape mutation in the tile converter (#901) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/901 Currently, the `acc_ops_tile` converter modifies the shape of the input (by `insert`ing values into the mutable list shape). This can causes problems for the downstream fusion passes: e.g., layernorm + sigmoid + mul fusion that relies on the fact that the mul's output is the same rank as the input. This diff fixes the mutable input shape change in the tile converter by adding an extra reshape op instead. Reviewed By: sgrigory Differential Revision: D48389978 fbshipit-source-id: 2d11d722c6b24c7bd91fd4959d713071a17259ca --- fx2ait/fx2ait/converters/ait_converters.py | 8 ++++---- python/aitemplate/backend/cuda/tensor/masked_select.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index c792d26a7..e1446b1ad 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -1644,14 +1644,14 @@ def acc_ops_tile( for _ in range(input_dim_len - len(shape_dims)): shape_dims.insert(0, 1) if input_dim_len < len(shape_dims): - shape = input_val.shape() + new_shape = list(input_val.shape()) for _ in range(len(shape_dims) - input_dim_len): - shape.insert(0, IntImm(1)) - result = expand()(input_val, shape) + new_shape.insert(0, IntImm(1)) + result = reshape()(input_val, new_shape) for i, shape in enumerate(shape_dims): # Avoid operate on batch_size dim - if input_val.shape()[i]._attrs["name"] is not None: + if result.shape()[i]._attrs["name"] is not None: continue cat_groups = [result] * shape result = concatenate()(cat_groups, dim=i) diff --git a/python/aitemplate/backend/cuda/tensor/masked_select.py b/python/aitemplate/backend/cuda/tensor/masked_select.py index 14effddc0..a09cfa460 100644 --- a/python/aitemplate/backend/cuda/tensor/masked_select.py +++ b/python/aitemplate/backend/cuda/tensor/masked_select.py @@ -18,6 +18,7 @@ from typing import List import jinja2 + from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec