diff --git a/python/aitemplate/compiler/transform/remove_no_ops.py b/python/aitemplate/compiler/transform/remove_no_ops.py
index e53159632..8adb52c89 100644
--- a/python/aitemplate/compiler/transform/remove_no_ops.py
+++ b/python/aitemplate/compiler/transform/remove_no_ops.py
@@ -36,10 +36,48 @@
 from aitemplate.compiler.transform import transform_utils
-from aitemplate.utils import graph_utils
+from aitemplate.utils import graph_utils, shape_utils
 from aitemplate.utils.shape_utils import is_singleton_dimension
 
 
+def _remove_no_op_dynamic_slices(sorted_graph: List[Tensor]) -> List[Tensor]:
+    """
+    Remove any no-op slices from the graph. A no-op slice is one whose output
+    tensor is identical to its input tensor, which happens when the start and
+    end indices cover each dimension's entire length.
+
+    x = Tensor([1, 2, 3])
+    y = x[:]
+
+    xx = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
+    yy = xx[0:2, -4:4]
+    """
+
+    ops = graph_utils.get_sorted_ops(sorted_graph)
+    for op in ops:
+        if op._attrs["op"] != "dynamic_slice":
+            continue
+
+        inputs = op._attrs["inputs"]
+        assert len(inputs) == 1, "dynamic_slice must only have 1 input"
+
+        outputs = op._attrs["outputs"]
+        assert len(outputs) == 1, "dynamic_slice must only have 1 output"
+
+        slice_input, slice_output = inputs[0], outputs[0]
+        if (
+            not shape_utils.is_same_shape(slice_input.shape(), slice_output.shape())
+            or slice_output._attrs["is_output"]
+        ):
+            continue
+
+        for dst_op in slice_output.dst_ops():
+            transform_utils.replace_tensor_for_op(dst_op, slice_output, slice_input)
+        transform_utils.remove_tensor_from_sorted_graph(slice_output)
+
+    return transform_utils.sanitize_sorted_graph(sorted_graph)
+
+
 def _remove_no_op_splits(sorted_graph: List[Tensor]) -> List[Tensor]:
     """
     Remove any no-op split from the graph where the input tensor is non-jagged.
@@ -236,6 +274,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
         Graph after remove no-ops
     """
     passes = [
+        _remove_no_op_dynamic_slices,
         _remove_no_op_splits,
         _remove_no_op_expands,
         _fuse_expand_elementwise,
diff --git a/tests/unittest/compiler/test_remove_no_op_dynamic_slices.py b/tests/unittest/compiler/test_remove_no_op_dynamic_slices.py
new file mode 100644
index 000000000..274ccdc5a
--- /dev/null
+++ b/tests/unittest/compiler/test_remove_no_op_dynamic_slices.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+from typing import List
+
+import torch
+
+from aitemplate.compiler import compile_model, ops
+from aitemplate.compiler.ops.tensor.dynamic_slice import MAX_INT32
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import (
+    gen_input_tensor,
+    get_random_torch_tensor,
+    graph_has_op,
+)
+
+
+class TestRemoveNoOpDynamicSlices(unittest.TestCase):
+    """
+    Tests the compiler's behavior when removing no-op dynamic slices.
+    """
+
+    def test_remove_no_op_dynamic_slices(self):
+        TEST_CASES = (
+            # These are no-ops.
+            {
+                # X[:]
+                "input_shape": [100],
+                "start_indices": [None],
+                "end_indices": [None],
+                "should_keep_dynamic_slice": False,
+            },
+            {
+                # X[0:]
+                "input_shape": [100],
+                "start_indices": [0],
+                "end_indices": [None],
+                "should_keep_dynamic_slice": False,
+            },
+            {
+                # X[:2_147_483_647, 0:]
+                "input_shape": [100, 100],
+                "start_indices": [None, 0],
+                "end_indices": [MAX_INT32, None],
+                "should_keep_dynamic_slice": False,
+            },
+            # These are meaningful.
+            {
+                # X[-7:-7]
+                "input_shape": [10],
+                "start_indices": [-7],
+                "end_indices": [-7],
+                "should_keep_dynamic_slice": True,
+            },
+            {
+                # X[7:, -7:, 0:]
+                "input_shape": [10, 10, 10],
+                "start_indices": [7, -7, 0],
+                "end_indices": [None, None, None],
+                "should_keep_dynamic_slice": True,
+            },
+            {
+                # X[:7, :-7, :0]
+                "input_shape": [10, 10, 10],
+                "start_indices": [None, None, None],
+                "end_indices": [7, -7, 0],
+                "should_keep_dynamic_slice": True,
+            },
+            {
+                # X[0:7, 0:-7]
+                "input_shape": [10, 10],
+                "start_indices": [0, 0],
+                "end_indices": [7, -7],
+                "should_keep_dynamic_slice": True,
+            },
+            {
+                # X[-7:7, 7:-7]
+                "input_shape": [10, 10],
+                "start_indices": [-7, 7],
+                "end_indices": [7, -7],
+                "should_keep_dynamic_slice": True,
+            },
+            {
+                # X[-7:7, 7:-7, :]
+                "input_shape": [10, 10, 10],
+                "start_indices": [-7, 7, None],
+                "end_indices": [7, -7, None],
+                "should_keep_dynamic_slice": True,
+            },
+        )
+
+        for i, test_kwargs in enumerate(TEST_CASES):
+            start_indices = ",".join(map(str, test_kwargs["start_indices"]))
+            end_indices = ",".join(map(str, test_kwargs["end_indices"]))
+
+            with self.subTest(
+                start=start_indices,
+                end=end_indices,
+                keep=test_kwargs["should_keep_dynamic_slice"],
+            ):
+                self._test_remove_no_op_dynamic_slices_impl(
+                    **test_kwargs,
+                    test_name=f"test_remove_no_op_dynamic_slice_{i}",
+                )
+
+    def _test_remove_no_op_dynamic_slices_impl(
+        self,
+        input_shape: List[int],
+        start_indices: List[int],
+        end_indices: List[int],
+        should_keep_dynamic_slice: bool,
+        test_name: str,
+    ):
+        X = gen_input_tensor(shape=input_shape, name="input_0")
+        X_sliced = ops.dynamic_slice()(X, start_indices, end_indices)
+        c = gen_input_tensor(shape=[1], name="input_const")
+        model_output = (X_sliced * c) + (X_sliced / c)
+        model_output._attrs["name"] = "output_0"
+        model_output._attrs["is_output"] = True
+
+        X_pt = get_random_torch_tensor(shape=input_shape)
+        slices = tuple(slice(s, e) for s, e in zip(start_indices, end_indices))
+        X_sliced_pt = X_pt[slices]
+        c_pt = get_random_torch_tensor(shape=[1])
+        Y_pt = (X_sliced_pt * c_pt) + (X_sliced_pt / c_pt)
+        Y_ait = torch.empty_like(Y_pt)
+
+        # NOTE: We don't run every optimization pass, to avoid fusion between
+        # dynamic_slice and elementwise ops.
+        with compile_model(
+            model_output, detect_target(), "/tmp", test_name, do_optimize_graph=False
+        ) as module:
+            module.run_with_tensors(
+                {"input_0": X_pt, "input_const": c_pt}, {"output_0": Y_ait}
+            )
+
+            self.assertEqual(
+                graph_has_op(module.debug_sorted_graph, "dynamic_slice"),
+                should_keep_dynamic_slice,
+            )
+            self.assertTrue(torch.allclose(Y_pt, Y_ait, atol=1e-2, rtol=1e-3))
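
For reference, the condition the pass checks (input and output shapes identical) can be mirrored in plain Python: a slice is a no-op exactly when it keeps every element of a dimension. The following is a minimal, framework-free sketch, not part of the diff; slice_is_no_op is a hypothetical helper used only for illustration.

    MAX_INT32 = 2_147_483_647  # mirrors aitemplate.compiler.ops.tensor.dynamic_slice.MAX_INT32

    def slice_is_no_op(dim_length: int, start, end) -> bool:
        # slice.indices() normalizes None, negative, and out-of-range bounds
        # the same way the test cases above interpret them.
        norm_start, norm_stop, _ = slice(start, end).indices(dim_length)
        return (norm_stop - norm_start) == dim_length

    assert slice_is_no_op(100, None, None)       # X[:]
    assert slice_is_no_op(100, 0, None)          # X[0:]
    assert slice_is_no_op(100, None, MAX_INT32)  # X[:2_147_483_647]
    assert not slice_is_no_op(10, -7, -7)        # X[-7:-7] -> empty
    assert not slice_is_no_op(10, None, 7)       # X[:7]

The pass itself compares shapes via shape_utils.is_same_shape rather than normalizing indices, so it piggybacks on dynamic_slice's own shape inference instead of re-deriving the index arithmetic.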