From d6bf98b039dd29ff7aa8afe9a7c34aef239a61eb Mon Sep 17 00:00:00 2001 From: "Amit Agarwal (Ads ML Serving)" Date: Fri, 14 Jul 2023 17:30:09 -0700 Subject: [PATCH] Fix and re-enable eliminate_permutations graph transformation (#824) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/824 Fix the eliminate_permutations graph transformation in AITemplate so that a pair of successive cancelling permutations is retained when either permute op has a reshaped or strided input or output TensorAccessor. The previous guard only inspected the input accessor of the first permute op, and did so through a `strided_dim` attribute; the new guard checks `stride_dim` and compares `actual_shapes` against `original_shapes` on both ops. Re-enable the transformation in optimize_graph and un-skip its unit tests as well. Reviewed By: chenyang78 Differential Revision: D47371853 fbshipit-source-id: ed7a7724d13a6b729e75ad78af99f9f23969763e --- .../compiler/transform/optimize_graph.py | 4 +- .../transform/transform_permutations.py | 36 +++++- .../compiler/test_eliminate_permutations.py | 121 +++++++++++++++++- 3 files changed, 152 insertions(+), 9 deletions(-) diff --git a/python/aitemplate/compiler/transform/optimize_graph.py b/python/aitemplate/compiler/transform/optimize_graph.py index 16ca6c0e4..92bfe0dde 100644 --- a/python/aitemplate/compiler/transform/optimize_graph.py +++ b/python/aitemplate/compiler/transform/optimize_graph.py @@ -51,6 +51,7 @@ from aitemplate.compiler.transform.transform_odd_alignment import ( transform_odd_alignment, ) +from aitemplate.compiler.transform.transform_permutations import eliminate_permutations from aitemplate.compiler.transform.transform_permute_to_reshape import ( transform_permute_to_reshape, ) @@ -125,8 +126,7 @@ def optimize_graph( split_large_split_ops, transform_permute_to_reshape, transform_memory_ops, - # FIXME: temporarily disable this due to some accuracy issue - # eliminate_permutations, + eliminate_permutations, ] if not optimize: diff --git a/python/aitemplate/compiler/transform/transform_permutations.py b/python/aitemplate/compiler/transform/transform_permutations.py index ca6488e3e..3b81e68ad 100644 --- a/python/aitemplate/compiler/transform/transform_permutations.py +++ b/python/aitemplate/compiler/transform/transform_permutations.py @@ -17,6 
+17,7 @@ import numpy as np from aitemplate.compiler.base import Operator, Tensor +from aitemplate.compiler.tensor_accessor import TensorAccessor from aitemplate.compiler.transform import transform_utils @@ -60,6 +61,32 @@ def remove_second_permutation_from_graph( transform_utils.remove_tensor_from_sorted_graph(output_tensor) +def _reshaped_or_strided_input_or_output_accessor(op: Operator) -> bool: + def _reshaped_or_strided_tensor_accessor(accessor: TensorAccessor) -> bool: + if ( + accessor.actual_shapes is not None + and accessor.actual_shapes != accessor.original_shapes + ): + return True + + # Is it a strided accessor + if hasattr(accessor, "stride_dim") and accessor.stride_dim is not None: + return True + + return False + + input_accessors = op._attrs.get("input_accessors", None) + output_accessors = op._attrs.get("output_accessors", None) + + return ( + (input_accessors is not None) + and _reshaped_or_strided_tensor_accessor(input_accessors[0]) + ) or ( + (output_accessors is not None) + and _reshaped_or_strided_tensor_accessor(output_accessors[0]) + ) + + def eliminate_permutations( sorted_graph: List[Tensor], workdir: str = None ) -> List[Tensor]: @@ -73,12 +100,7 @@ def eliminate_permutations( continue if not cur_op._attrs["op"].startswith("permute"): continue - input_accessors = cur_op._attrs.get("input_accessors", None) - if ( - input_accessors is not None - and hasattr(input_accessors[0], "strided_dim") - and input_accessors[0].strided_dim is not None - ): + if _reshaped_or_strided_input_or_output_accessor(cur_op): continue curr_op_output = cur_op._attrs["outputs"][0] dst_ops = curr_op_output._attrs["dst_ops"] @@ -89,6 +111,8 @@ def eliminate_permutations( for next_op in dst_ops: if not next_op._attrs["op"].startswith("permute"): continue + if _reshaped_or_strided_input_or_output_accessor(next_op): + continue p1 = get_permutation(cur_op) p2 = get_permutation(next_op) if not np.all(np.array(p1)[p2] == np.arange(0, len(p1))): diff --git 
a/tests/unittest/compiler/test_eliminate_permutations.py b/tests/unittest/compiler/test_eliminate_permutations.py index 54ec64d6f..6c3117da2 100644 --- a/tests/unittest/compiler/test_eliminate_permutations.py +++ b/tests/unittest/compiler/test_eliminate_permutations.py @@ -28,7 +28,6 @@ ) -@unittest.skip("Skip until we fix the accuracy issue") class EliminatePermutationTestCase(unittest.TestCase): def test_eliminate_permutation(self): dtype = "float" @@ -190,6 +189,126 @@ def test_eliminate_permutation_all_permutations(self): self.assertEqual(len(result_graph), 3) self.assertTrue(graph_has_op(result_graph, "permute")) + def test_do_not_eliminate_permutation_of_strided_input(self): + dtype = "float" + shape = [3, 2, 4] + new_shape = [3, 2 * 2] + target = detect_target() + + x = Tensor(shape, name="x", dtype=dtype, is_input=True) + s1 = ops.dynamic_slice()( + x, start_indices=[0, 0, 2], end_indices=[2147483647, 2147483647, 4] + ) + p1 = ops.permute()(s1, dims=[0, 2, 1]) + p2 = ops.permute()(p1, dims=[0, 2, 1]) + z = ops.reshape()(p2, new_shape) + z._attrs["is_output"] = True + z._attrs["name"] = "z" + + with compile_model( + z, target, "./tmp", "test_do_not_eliminate_permutation_of_strided_input" + ) as module: + # Verify the generated graph. 
+ sorted_graph = module.debug_sorted_graph + self.assertEqual(len(sorted_graph), 4) + self.assertTrue(graph_has_op(sorted_graph, "permute021")) + + x_pt = get_random_torch_tensor(shape, dtype) + z_pt = get_torch_empty_tensor(new_shape, dtype) + + module.run_with_tensors({"x": x_pt}, {"z": z_pt}) + + self.assertTrue( + torch.equal( + torch.reshape(torch.split(x_pt, 2, dim=2)[1], new_shape), z_pt + ) + ) + + def test_do_not_eliminate_permutation_of_strided_input2(self): + dtype = "float" + shape = [3, 4, 2] + new_shape = [3, 2 * 2] + target = detect_target() + + x = Tensor(shape, name="x", dtype=dtype, is_input=True) + p1 = ops.permute()(x, dims=[0, 2, 1]) + s1 = ops.dynamic_slice()( + p1, start_indices=[0, 0, 2], end_indices=[2147483647, 2147483647, 4] + ) + p2 = ops.permute()(s1, dims=[0, 2, 1]) + z = ops.reshape()(p2, new_shape) + z._attrs["is_output"] = True + z._attrs["name"] = "z" + + with compile_model( + z, target, "./tmp", "test_do_not_eliminate_permutation_of_strided_input2" + ) as module: + # Verify the generated graph. 
+ sorted_graph = module.debug_sorted_graph + self.assertEqual(len(sorted_graph), 4) + self.assertTrue(graph_has_op(sorted_graph, "permute021")) + + x_pt = get_random_torch_tensor(shape, dtype) + z_pt = get_torch_empty_tensor(new_shape, dtype) + + module.run_with_tensors({"x": x_pt}, {"z": z_pt}) + + self.assertTrue( + torch.equal( + torch.reshape( + torch.permute( + torch.split(torch.permute(x_pt, (0, 2, 1)), 2, dim=2)[1], + (0, 2, 1), + ), + new_shape, + ), + z_pt, + ) + ) + + def test_do_not_eliminate_permutation_of_reshaped_input(self): + dtype = "float" + shape = [3, 2, 4] + new_shape = [3, 2, 4] + target = detect_target() + + x = Tensor(shape, name="x", dtype=dtype, is_input=True) + p1 = ops.permute()(x, dims=[0, 2, 1]) + r1 = ops.reshape()(p1, new_shape) + p2 = ops.permute()(r1, dims=[0, 2, 1]) + z = ops.dynamic_slice()( + p2, start_indices=[0, 0, 1], end_indices=[2147483647, 2147483647, 2] + ) + z._attrs["is_output"] = True + z._attrs["name"] = "z" + + with compile_model( + z, target, "./tmp", "test_do_not_eliminate_permutation_of_reshaped_input" + ) as module: + # Verify the generated graph. + sorted_graph = module.debug_sorted_graph + self.assertEqual(len(sorted_graph), 4) + self.assertTrue(graph_has_op(sorted_graph, "permute021")) + + x_pt = get_random_torch_tensor(shape, dtype) + z_pt = get_torch_empty_tensor([3, 4, 1], dtype) + + module.run_with_tensors({"x": x_pt}, {"z": z_pt}) + + self.assertTrue( + torch.equal( + torch.split( + torch.permute( + torch.reshape(torch.permute(x_pt, (0, 2, 1)), new_shape), + (0, 2, 1), + ), + 1, + dim=2, + )[1], + z_pt, + ) + ) + if __name__ == "__main__": unittest.main()