Skip to content

Commit

Permalink
Add pass to remove no-op splits. (facebookincubator#813)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#813

## This diff
We add a pass to remove no-op splits in remove_no_ops.py. We ignore splits that are also model outputs.

A split is removed if it meets the following conditions:
1. it's a no-op -- the split has a single output, and its output's and input's shapes are the same along the split dimension
2. it's an intermediate op -- the split is not the final model output

Reviewed By: muchulee8

Differential Revision: D46462236

fbshipit-source-id: 3fa00da35994bca7f694c7126dd1fad1f76b49ce
  • Loading branch information
ColinPeppler authored and facebook-github-bot committed Jul 6, 2023
1 parent 5e8fa7a commit 8653901
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 0 deletions.
56 changes: 56 additions & 0 deletions python/aitemplate/compiler/transform/remove_no_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,61 @@
from aitemplate.utils.shape_utils import is_singleton_dimension


def _remove_no_op_splits(sorted_graph: List[Tensor]) -> List[Tensor]:
    """
    Remove any no-op split from the graph where the input tensor is non-jagged.

    A no-op split is where the input tensor isn't divided into multiple parts.
    This happens when the split_size_or_sections argument is:
    1. an integer representing the length of the dimension indicated by dim
    2. a singleton list containing the length of the dimension indicated by dim.

    x = Tensor([1, 2, 3])
    y1 = split(x, split_size_or_sections=3, dim=0)  # Case 1
    y2 = split(x, split_size_or_sections=[3], dim=0)  # Case 2

    xx = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    yy1 = split(xx, split_size_or_sections=2, dim=0)  # Case 1
    yy2 = split(xx, split_size_or_sections=4, dim=1)  # Case 1
    yy3 = split(xx, split_size_or_sections=[2], dim=0)  # Case 2
    yy4 = split(xx, split_size_or_sections=[4], dim=1)  # Case 2

    Parameters
    ----------
    sorted_graph : List[Tensor]
        Topologically sorted graph tensors.

    Returns
    -------
    List[Tensor]
        The sanitized graph with no-op splits removed.
    """

    ops = graph_utils.get_sorted_ops(sorted_graph)
    for op in ops:
        if op._attrs["op"] != "split":
            continue

        inputs = op._attrs["inputs"]
        assert len(inputs) == 1, "split must only have 1 input"

        outputs = op._attrs["outputs"]
        # BUGFIX: the original asserted len(inputs) >= 1 here, which was
        # already guaranteed above and never checked the outputs at all.
        assert len(outputs) >= 1, "split must have at least 1 output"

        split_dim = op._attrs["split_dim"]
        split_input, split_output = inputs[0], outputs[0]
        input_split_dim_len, output_split_dim_len = (
            split_input._attrs["shape"][split_dim],
            split_output._attrs["shape"][split_dim],
        )

        # No-op splits must have one output, and the input and output shapes
        # must match along split_dim. We ignore no-op splits that are outputs,
        # since removing them would change the model's external interface.
        if (
            len(outputs) > 1
            or input_split_dim_len != output_split_dim_len
            or outputs[0]._attrs["is_output"]
        ):
            continue

        # Rewire every consumer of the split output to read the split input
        # directly, then drop the now-dangling output tensor from the graph.
        for dst_op in list(split_output.dst_ops()):
            transform_utils.replace_tensor_for_op(dst_op, split_output, split_input)

        transform_utils.remove_tensor_from_sorted_graph(split_output)

    return transform_utils.sanitize_sorted_graph(sorted_graph)


def _remove_no_op_expands(sorted_graph: List[Tensor]) -> List[Tensor]:
"""
Remove no-op expands from the graph. A no-op expand is one
Expand Down Expand Up @@ -181,6 +236,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
Graph after remove no-ops
"""
passes = [
_remove_no_op_splits,
_remove_no_op_expands,
_fuse_expand_elementwise,
]
Expand Down
168 changes: 168 additions & 0 deletions tests/unittest/compiler/test_remove_no_op_splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from typing import List, Sequence, Union

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import (
gen_input_tensor,
get_random_torch_tensor,
graph_has_op,
)


class TestRemoveNoOpSplits(unittest.TestCase):
    """
    Tests _remove_no_op_splits() in remove_no_ops.py
    """

    def _test_remove_no_op_split_impl(
        self,
        split_input_shape: Sequence[int],
        split_size_or_sections: Union[int, List[int]],
        split_dim: int,
        split_is_output: bool,
        should_remove_no_op_split: bool,
        test_name: str,
    ):
        # Build the AIT graph: split the input, then (unless the split
        # itself is the model output) add a scalar to each piece.
        input_tensor = gen_input_tensor(shape=split_input_shape, name="input_0")
        scalar = gen_input_tensor(shape=(1,), name="input_1")
        split_results = ops.split()(input_tensor, split_size_or_sections, split_dim)

        model_outputs = []
        for idx, piece in enumerate(split_results):
            graph_out = piece if split_is_output else piece + scalar
            graph_out._attrs["name"] = f"output_{idx}"
            graph_out._attrs["is_output"] = True
            model_outputs.append(graph_out)

        # Compute the reference results with PyTorch.
        input_pt = get_random_torch_tensor(shape=split_input_shape)
        scalar_pt = get_random_torch_tensor(shape=(1,))
        split_results_pt = torch.split(input_pt, split_size_or_sections, split_dim)
        if split_is_output:
            expected = split_results_pt
        else:
            expected = [piece_pt + scalar_pt for piece_pt in split_results_pt]

        # Compile and run the AIT model.
        with compile_model(
            model_outputs, detect_target(), "./tmp", test_name
        ) as module:
            if split_is_output:
                feed_dict = {"input_0": input_pt}
            else:
                feed_dict = {"input_0": input_pt, "input_1": scalar_pt}
            actual = {
                f"output_{idx}": torch.empty_like(ref)
                for idx, ref in enumerate(expected)
            }
            module.run_with_tensors(feed_dict, actual)

            # The split op survives the pass exactly when it should NOT
            # have been removed.
            self.assertNotEqual(
                graph_has_op(module.debug_sorted_graph, "split"),
                should_remove_no_op_split,
            )
            for ref, out in zip(expected, actual.values()):
                self.assertTrue(torch.allclose(ref, out, atol=1e-2, rtol=1e-3))

    def test_remove_no_op_split(self):
        """
        Test cases:
        0. No-op split with split_size_or_sections as integer
        1. No-op split with split_size_or_sections as a singleton list
        2. No-op split with split_size > length along split_dim
        3. No-op split with split_dim = -1
        4. Meaningful split
        5. Meaningful split with split_dim = -1
        6. No-op split is a model output
        7. Meaningful split is a model output
        """

        test_cases = (
            # Split is a no-op.
            {
                "split_input_shape": (5,),
                "split_size_or_sections": 5,
                "split_dim": 0,
                "split_is_output": False,
                "should_remove_no_op_split": True,
                "test_name": "test_remove_no_op_split_no_op_0",
            },
            {
                "split_input_shape": (5,),
                "split_size_or_sections": [5],
                "split_dim": -1,
                "split_is_output": False,
                "should_remove_no_op_split": True,
                "test_name": "test_remove_no_op_split_no_op_1",
            },
            {
                "split_input_shape": (2, 3, 4),
                "split_size_or_sections": 10,  # split_size > length along dim=1
                "split_dim": 1,
                "split_is_output": False,
                "should_remove_no_op_split": True,
                "test_name": "test_remove_no_op_split_no_op_2",
            },
            {
                "split_input_shape": (2, 3, 4, 5),
                "split_size_or_sections": [5],
                "split_dim": -1,
                "split_is_output": False,
                "should_remove_no_op_split": True,
                "test_name": "test_remove_no_op_split_no_op_3",
            },
            # Split is meaningful.
            {
                "split_input_shape": (7,),
                "split_size_or_sections": 2,
                "split_dim": 0,
                "split_is_output": False,
                "should_remove_no_op_split": False,
                "test_name": "test_remove_no_op_split_meaningful_4",
            },
            {
                "split_input_shape": (2, 3, 4, 5),
                "split_size_or_sections": [2, 1, 2],
                "split_dim": -1,
                "split_is_output": False,
                "should_remove_no_op_split": False,
                "test_name": "test_remove_no_op_split_meaningful_5",
            },
            # Split is a model output.
            {
                "split_input_shape": (9,),
                "split_size_or_sections": [9],
                "split_dim": 0,
                "split_is_output": True,
                "should_remove_no_op_split": False,
                "test_name": "test_remove_no_op_split_output_6",
            },
            {
                "split_input_shape": (1, 9),
                "split_size_or_sections": [4, 5],
                "split_dim": -1,
                "split_is_output": True,
                "should_remove_no_op_split": False,
                "test_name": "test_remove_no_op_split_output_7",
            },
        )

        for case_idx, case_kwargs in enumerate(test_cases):
            with self.subTest(test_no=case_idx):
                self._test_remove_no_op_split_impl(**case_kwargs)

0 comments on commit 8653901

Please sign in to comment.