Skip to content

Commit

Permalink
Add pass to remove no-op concats (#865)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #865

## This diff
We add a pass to remove no-op concats in remove_no_ops.py. We ignore any concats that are also marked as outputs. A concat is removed in these two cases:
 1. There is a single input tensor.
 2. There is a single *non-empty* input tensor and the remaining input tensors are empty.

Reviewed By: muchulee8

Differential Revision: D47763445

fbshipit-source-id: 2c9a24cffa3862670983438f0de2280e24a72d41
  • Loading branch information
ColinPeppler authored and facebook-github-bot committed Aug 7, 2023
1 parent 838e5b9 commit e3a89be
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 1 deletion.
57 changes: 56 additions & 1 deletion python/aitemplate/compiler/transform/remove_no_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""
from typing import List

from aitemplate.compiler.base import IntVar, JaggedIntVar, Operator, Tensor
from aitemplate.compiler.base import IntImm, IntVar, JaggedIntVar, Operator, Tensor
from aitemplate.compiler.ops.tensor.expand import ExpandDimensionType

from aitemplate.compiler.transform import transform_utils
Expand All @@ -40,6 +40,60 @@
from aitemplate.utils.shape_utils import is_singleton_dimension


def _remove_no_op_concats(sorted_graph: List[Tensor]) -> List[Tensor]:
"""
Remove no-op concats from the graph. A no-op concat is where the output
tensor is exactly the same as the input tensor(s) and it isn't the model output.
This is the case when:
1. There is a single input tensor.
2. There is a single non-empty input tensor and the remaining input tensors
are empty.
x = Tensor(shape=[7])
empty1 = Tensor(shape=[0], value=[])
empty2 = Tensor(shape=[0], value=[])
y1 = ops.concatenate([x]) # Case 1
y2 = ops.concatenate([empty1]) # Case 1
y2 = ops.concatenate([empty1, x, empty2]) # Case 2
"""

def is_dim_gt_zero(dim):
if isinstance(dim, IntImm):
return dim.value() > 0
elif isinstance(dim, IntVar):
return dim.lower_bound() > 0

ops = graph_utils.get_sorted_ops(sorted_graph)
for op in ops:
if op._attrs["op"] != "concatenate":
continue

inputs = op._attrs["inputs"]
assert len(inputs) >= 1, "concat must have at least 1 input"

outputs = op._attrs["outputs"]
concat_output = outputs[0]
assert len(outputs) == 1, "concat must have a single output"

# Assumes non-empty tensors have non-zero dimensions.
# And empty tensors have dimensions of size 0.
is_input_non_empty = [
all(is_dim_gt_zero(dim) for dim in tensor.shape()) for tensor in inputs
]
n_non_empty = sum(is_input_non_empty)
if len(inputs) > 1 and n_non_empty > 1 or outputs[0]._attrs["is_output"]:
continue

idx = is_input_non_empty.index(True) if n_non_empty == 1 else 0
concat_input = inputs[idx]
for dst_op in concat_output.dst_ops():
transform_utils.replace_tensor_for_op(dst_op, concat_output, concat_input)
transform_utils.remove_tensor_from_sorted_graph(concat_output)

return transform_utils.sanitize_sorted_graph(sorted_graph)


def _remove_no_op_dynamic_slices(sorted_graph: List[Tensor]) -> List[Tensor]:
"""
Remove any no-op slices from the graph. A no-op slice is when the input tensor
Expand Down Expand Up @@ -274,6 +328,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
Graph after remove no-ops
"""
passes = [
_remove_no_op_concats,
_remove_no_op_dynamic_slices,
_remove_no_op_splits,
_remove_no_op_expands,
Expand Down
147 changes: 147 additions & 0 deletions tests/unittest/compiler/test_remove_no_op_concats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from typing import Sequence

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import get_random_torch_tensor, graph_has_op


class TestRemoveNoOpConcats(unittest.TestCase):
"""
Tests the compiler's behavior of removing no-op concats.
NOTE: Whenever we include an empty input tensor, the non-empty input tensor
must be rank 1. That's because AIT's concat expects all its inputs to have
the same rank and have matching dimension sizes except along the
concatenating dimension.
We run the following tests:
# These are no-ops
1. inputs=[non-empty]
2. inputs=[rank-1 empty, rank-1 non-empty, rank-1 empty]
3. inputs=[empty]
4. inputs=[empty, empty]
# These are meaningful
5. inputs=[non-empty, non-empty]
6. inputs=[non-empty, empty, non-empty]
# These should have exceptions
7. inputs=[rank-2 non-empty, rank-1 empty]
8. inputs=[rank-2 non-empty, rank-2 empty]
"""

def test_remove_no_op_concats_no_ops(self):
self._test_remove_no_op_concats_impl(
input_shapes=[[2, 4, 6]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_single_non_empty",
)

self._test_remove_no_op_concats_impl(
input_shapes=[[0], [3], [0]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_single_non_empty_and_double_empty",
)

def test_remove_no_op_concats_no_ops_all_empty(self):
"""Below we test when all the input tensors are empty. fx2ait will fail
in these cases. However, it's possible to create it directly in AIT.
Therefore, we test this case and treat it as a no-op.
"""
self._test_remove_no_op_concats_impl(
input_shapes=[[0, 0, 0]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_single_empty",
)

self._test_remove_no_op_concats_impl(
input_shapes=[[0, 0, 0], [0, 0, 0]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_double_empty",
)

def test_remove_no_op_concats_meaningful(self):
self._test_remove_no_op_concats_impl(
input_shapes=[[3, 5], [3, 5]],
should_keep_concat=True,
test_name="test_remove_no_op_concats_double_non_empty",
)

self._test_remove_no_op_concats_impl(
input_shapes=[[3], [0], [5]],
should_keep_concat=True,
test_name="test_remove_no_op_concats_two_non_empty_and_empty",
)

def test_remove_no_op_concats_exceptions(self):
"""We expect this to raise an exception in these test cases."""

# AIT expects all concat inputs to have the same rank.
with self.assertRaises(RuntimeError):
self._test_remove_no_op_concats_impl(
input_shapes=[[2, 4], [0]],
should_keep_concat=False,
test_name="test_remove_no_op_concats_same_rank",
)

# AIT expects all concat inputs to have the same dimension sizes except for the concat_dim.
with self.assertRaises(RuntimeError):
self._test_remove_no_op_concats_impl(
input_shapes=[[2, 4], [0, 0]],
should_keep_concat=False,
test_name="test_remove_no_ops_concat_same_dim_sizes",
)

def _test_remove_no_op_concats_impl(
self,
input_shapes: Sequence[Sequence[int]],
should_keep_concat: bool,
test_name: str,
):
inputs = [
Tensor(shape=shape, name=f"input_{i}", is_input=True)
for i, shape in enumerate(input_shapes)
]
concatenated = ops.concatenate()(inputs)
c = Tensor(shape=[1], name="input_const", is_input=True)
model_output = (concatenated * c) + (concatenated / c)
model_output._attrs["name"] = "output_0"
model_output._attrs["is_output"] = True

inputs_pt = {
f"input_{i}": get_random_torch_tensor(shape=shape)
for i, shape in enumerate(input_shapes)
}
concatenated_pt = torch.concat(list(inputs_pt.values()))
c_pt = get_random_torch_tensor(shape=[1])
Y_pt = (concatenated_pt * c_pt) + (concatenated_pt / c_pt)
Y_ait = torch.empty_like(Y_pt)

with compile_model(model_output, detect_target(), "./tmp", test_name) as module:
module.run_with_tensors(
{**inputs_pt, "input_const": c_pt}, {"output_0": Y_ait}
)

self.assertEquals(
graph_has_op(module.debug_sorted_graph, "concatenate"),
should_keep_concat,
)
self.assertTrue(torch.allclose(Y_pt, Y_ait, atol=1e-2, rtol=1e-2))

0 comments on commit e3a89be

Please sign in to comment.