Skip to content

Commit

Permalink
Avoid input shape mutation in the tile converter (#901)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #901

Currently, the `acc_ops_tile` converter modifies the shape of the input (by `insert`ing values into the mutable list shape). This can causes problems for the downstream fusion passes: e.g., layernorm + sigmoid + mul fusion that relies on the fact that the mul's output is the same rank as the input.

This diff fixes the mutable input shape change in the tile converter by adding an extra reshape op instead.

Reviewed By: sgrigory

Differential Revision: D48389978

fbshipit-source-id: 2d11d722c6b24c7bd91fd4959d713071a17259ca
  • Loading branch information
aakhundov authored and facebook-github-bot committed Aug 16, 2023
1 parent 7ca9488 commit 06c910d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,14 +1644,14 @@ def acc_ops_tile(
for _ in range(input_dim_len - len(shape_dims)):
shape_dims.insert(0, 1)
if input_dim_len < len(shape_dims):
shape = input_val.shape()
new_shape = list(input_val.shape())
for _ in range(len(shape_dims) - input_dim_len):
shape.insert(0, IntImm(1))
result = expand()(input_val, shape)
new_shape.insert(0, IntImm(1))
result = reshape()(input_val, new_shape)

for i, shape in enumerate(shape_dims):
# Avoid operate on batch_size dim
if input_val.shape()[i]._attrs["name"] is not None:
if result.shape()[i]._attrs["name"] is not None:
continue
cat_groups = [result] * shape
result = concatenate()(cat_groups, dim=i)
Expand Down
1 change: 1 addition & 0 deletions python/aitemplate/backend/cuda/tensor/masked_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import List

import jinja2

from aitemplate.backend import registry

from aitemplate.backend.backend_spec import CUDASpec
Expand Down

0 comments on commit 06c910d

Please sign in to comment.