diff --git a/python/aitemplate/compiler/transform/remove_no_ops.py b/python/aitemplate/compiler/transform/remove_no_ops.py
index 8adb52c89..fd5d325dd 100644
--- a/python/aitemplate/compiler/transform/remove_no_ops.py
+++ b/python/aitemplate/compiler/transform/remove_no_ops.py
@@ -31,7 +31,7 @@
 """
 from typing import List
 
-from aitemplate.compiler.base import IntVar, JaggedIntVar, Operator, Tensor
+from aitemplate.compiler.base import IntImm, IntVar, JaggedIntVar, Operator, Tensor
 from aitemplate.compiler.ops.tensor.expand import ExpandDimensionType
 from aitemplate.compiler.transform import transform_utils
 
@@ -40,6 +40,60 @@
 from aitemplate.utils.shape_utils import is_singleton_dimension
 
 
+def _remove_no_op_concats(sorted_graph: List[Tensor]) -> List[Tensor]:
+    """
+    Remove no-op concats from the graph. A no-op concat is one whose output
+    tensor is identical to one of its inputs and is not a model output.
+    This is the case when:
+    1. There is a single input tensor.
+    2. There is a single non-empty input tensor and the remaining input tensors
+       are empty.
+
+    x = Tensor(shape=[7])
+    empty1 = Tensor(shape=[0], value=[])
+    empty2 = Tensor(shape=[0], value=[])
+
+    y1 = ops.concatenate()([x])  # Case 1
+    y2 = ops.concatenate()([empty1])  # Case 1
+    y3 = ops.concatenate()([empty1, x, empty2])  # Case 2
+    """
+
+    def is_dim_gt_zero(dim):
+        if isinstance(dim, IntImm):
+            return dim.value() > 0
+        elif isinstance(dim, IntVar):
+            return dim.lower_bound() > 0
+
+    ops = graph_utils.get_sorted_ops(sorted_graph)
+    for op in ops:
+        if op._attrs["op"] != "concatenate":
+            continue
+
+        inputs = op._attrs["inputs"]
+        assert len(inputs) >= 1, "concat must have at least 1 input"
+
+        outputs = op._attrs["outputs"]
+        assert len(outputs) == 1, "concat must have a single output"
+        concat_output = outputs[0]
+
+        # Assume non-empty tensors have only non-zero dimensions and
+        # empty tensors have at least one dimension of size 0.
+        is_input_non_empty = [
+            all(is_dim_gt_zero(dim) for dim in tensor.shape()) for tensor in inputs
+        ]
+        n_non_empty = sum(is_input_non_empty)
+        if (len(inputs) > 1 and n_non_empty > 1) or concat_output._attrs["is_output"]:
+            continue
+
+        idx = is_input_non_empty.index(True) if n_non_empty == 1 else 0
+        concat_input = inputs[idx]
+        for dst_op in concat_output.dst_ops():
+            transform_utils.replace_tensor_for_op(dst_op, concat_output, concat_input)
+        transform_utils.remove_tensor_from_sorted_graph(concat_output)
+
+    return transform_utils.sanitize_sorted_graph(sorted_graph)
+
+
 def _remove_no_op_dynamic_slices(sorted_graph: List[Tensor]) -> List[Tensor]:
     """
     Remove any no-op slices from the graph. A no-op slice is when the input tensor
@@ -274,6 +328,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
         Graph after remove no-ops
     """
     passes = [
+        _remove_no_op_concats,
         _remove_no_op_dynamic_slices,
         _remove_no_op_splits,
         _remove_no_op_expands,
diff --git a/tests/unittest/compiler/test_remove_no_op_concats.py b/tests/unittest/compiler/test_remove_no_op_concats.py
new file mode 100644
index 000000000..d1d5d2d28
--- /dev/null
+++ b/tests/unittest/compiler/test_remove_no_op_concats.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+from typing import Sequence
+
+import torch
+
+from aitemplate.compiler import compile_model, ops
+from aitemplate.compiler.base import Tensor
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import get_random_torch_tensor, graph_has_op
+
+
+class TestRemoveNoOpConcats(unittest.TestCase):
+    """
+    Tests the compiler's behavior of removing no-op concats.
+
+    NOTE: Whenever we include an empty input tensor, the non-empty input tensor
+    must be rank 1. That's because AIT's concat expects all its inputs to have
+    the same rank and matching dimension sizes except along the
+    concatenating dimension.
+
+    We run the following tests:
+    # These are no-ops
+    1. inputs=[non-empty]
+    2. inputs=[rank-1 empty, rank-1 non-empty, rank-1 empty]
+    3. inputs=[empty]
+    4. inputs=[empty, empty]
+
+    # These are meaningful
+    5. inputs=[non-empty, non-empty]
+    6. inputs=[non-empty, empty, non-empty]
+
+    # These should raise exceptions
+    7. inputs=[rank-2 non-empty, rank-1 empty]
+    8. inputs=[rank-2 non-empty, rank-2 empty]
+    """
+
+    def test_remove_no_op_concats_no_ops(self):
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[2, 4, 6]],
+            should_keep_concat=False,
+            test_name="test_remove_no_op_concats_single_non_empty",
+        )
+
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[0], [3], [0]],
+            should_keep_concat=False,
+            test_name="test_remove_no_op_concats_single_non_empty_and_double_empty",
+        )
+
+    def test_remove_no_op_concats_no_ops_all_empty(self):
+        """Test the cases where all of the input tensors are empty. fx2ait
+        fails on these cases, but such a concat can be created directly in
+        AIT, so we cover it here and treat it as a no-op.
+        """
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[0, 0, 0]],
+            should_keep_concat=False,
+            test_name="test_remove_no_op_concats_single_empty",
+        )
+
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[0, 0, 0], [0, 0, 0]],
+            should_keep_concat=False,
+            test_name="test_remove_no_op_concats_double_empty",
+        )
+
+    def test_remove_no_op_concats_meaningful(self):
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[3, 5], [3, 5]],
+            should_keep_concat=True,
+            test_name="test_remove_no_op_concats_double_non_empty",
+        )
+
+        self._test_remove_no_op_concats_impl(
+            input_shapes=[[3], [0], [5]],
+            should_keep_concat=True,
+            test_name="test_remove_no_op_concats_two_non_empty_and_empty",
+        )
+
+    def test_remove_no_op_concats_exceptions(self):
+        """Each of these test cases is expected to raise an exception."""
+
+        # AIT expects all concat inputs to have the same rank.
+        with self.assertRaises(RuntimeError):
+            self._test_remove_no_op_concats_impl(
+                input_shapes=[[2, 4], [0]],
+                should_keep_concat=False,
+                test_name="test_remove_no_op_concats_same_rank",
+            )
+
+        # AIT expects all concat inputs to have the same dimension sizes except for the concat_dim.
+        with self.assertRaises(RuntimeError):
+            self._test_remove_no_op_concats_impl(
+                input_shapes=[[2, 4], [0, 0]],
+                should_keep_concat=False,
+                test_name="test_remove_no_ops_concat_same_dim_sizes",
+            )
+
+    def _test_remove_no_op_concats_impl(
+        self,
+        input_shapes: Sequence[Sequence[int]],
+        should_keep_concat: bool,
+        test_name: str,
+    ):
+        inputs = [
+            Tensor(shape=shape, name=f"input_{i}", is_input=True)
+            for i, shape in enumerate(input_shapes)
+        ]
+        concatenated = ops.concatenate()(inputs)
+        c = Tensor(shape=[1], name="input_const", is_input=True)
+        model_output = (concatenated * c) + (concatenated / c)
+        model_output._attrs["name"] = "output_0"
+        model_output._attrs["is_output"] = True
+
+        inputs_pt = {
+            f"input_{i}": get_random_torch_tensor(shape=shape)
+            for i, shape in enumerate(input_shapes)
+        }
+        concatenated_pt = torch.concat(list(inputs_pt.values()))
+        c_pt = get_random_torch_tensor(shape=[1])
+        Y_pt = (concatenated_pt * c_pt) + (concatenated_pt / c_pt)
+        Y_ait = torch.empty_like(Y_pt)
+
+        with compile_model(model_output, detect_target(), "./tmp", test_name) as module:
+            module.run_with_tensors(
+                {**inputs_pt, "input_const": c_pt}, {"output_0": Y_ait}
+            )
+
+        self.assertEqual(
+            graph_has_op(module.debug_sorted_graph, "concatenate"),
+            should_keep_concat,
+        )
+        self.assertTrue(torch.allclose(Y_pt, Y_ait, atol=1e-2, rtol=1e-2))
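
Reviewer note (not part of the patch): below is a minimal sketch of how the new pass behaves on a hand-built graph, exercised through the public remove_no_ops() entry point added to the passes list above. The import path for toposort is an assumption; the unit test in this diff goes through compile_model instead, which also runs the other graph passes, so treat this as illustrative only.

# Minimal sketch, assuming aitemplate.compiler.transform.toposort exposes toposort().
from aitemplate.compiler import ops
from aitemplate.compiler.base import Tensor
from aitemplate.compiler.transform.remove_no_ops import remove_no_ops
from aitemplate.compiler.transform.toposort import toposort

x = Tensor(shape=[7], name="x", is_input=True)
c = Tensor(shape=[1], name="c", is_input=True)

# Case 1 from the docstring: a single-input concat is a pure pass-through.
cat = ops.concatenate()([x])
out = cat * c
out._attrs["name"] = "out"
out._attrs["is_output"] = True

sorted_graph = toposort([out])
sorted_graph = remove_no_ops(sorted_graph)

# The concat should be gone; the elementwise op now consumes "x" directly.
remaining_ops = {op._attrs["op"] for t in sorted_graph for op in t.src_ops()}
assert "concatenate" not in remaining_ops

Because _remove_no_op_concats is the first entry in the passes list, concat pass-throughs are collapsed before the slice/split/expand cleanups run.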