Skip to content

Commit

Permalink
Fix and re-enable eliminate_permutations graph transformation (#824)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #824

Fix eliminate_permutations graph transformation in AITemplate to retain successive cancelling permutations on strided input Tensor.

Re-enabled the transformation as well.

Differential Revision: D47371853

fbshipit-source-id: a2eb55549ca510482c57f2fa03b81d88642aff04
  • Loading branch information
amitaga authored and facebook-github-bot committed Jul 13, 2023
1 parent df4381b commit ddc78c4
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 9 deletions.
4 changes: 2 additions & 2 deletions python/aitemplate/compiler/transform/optimize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from aitemplate.compiler.transform.transform_odd_alignment import (
transform_odd_alignment,
)
from aitemplate.compiler.transform.transform_permutations import eliminate_permutations
from aitemplate.compiler.transform.transform_permute_to_reshape import (
transform_permute_to_reshape,
)
Expand Down Expand Up @@ -125,8 +126,7 @@ def optimize_graph(
split_large_split_ops,
transform_permute_to_reshape,
transform_memory_ops,
# FIXME: temporarily disable this due to some accuracy issue
# eliminate_permutations,
eliminate_permutations,
]

if not optimize:
Expand Down
36 changes: 30 additions & 6 deletions python/aitemplate/compiler/transform/transform_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform import transform_utils


Expand Down Expand Up @@ -60,6 +61,32 @@ def remove_second_permutation_from_graph(
transform_utils.remove_tensor_from_sorted_graph(output_tensor)


def _reshaped_or_strided_input_or_output_accessor(op: Operator) -> bool:
def _reshaped_or_strided_tensor_accessor(accessor: TensorAccessor) -> bool:
if (
accessor.actual_shapes is not None
and accessor.actual_shapes != accessor.original_shapes
):
return True

# Is it a strided accessor
if hasattr(accessor, "stride_dim") and accessor.stride_dim is not None:
return True

return False

input_accessors = op._attrs.get("input_accessors", None)
output_accessors = op._attrs.get("output_accessors", None)

return (
(input_accessors is not None)
and _reshaped_or_strided_tensor_accessor(input_accessors[0])
) or (
(output_accessors is not None)
and _reshaped_or_strided_tensor_accessor(output_accessors[0])
)


def eliminate_permutations(
sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
Expand All @@ -73,12 +100,7 @@ def eliminate_permutations(
continue
if not cur_op._attrs["op"].startswith("permute"):
continue
input_accessors = cur_op._attrs.get("input_accessors", None)
if (
input_accessors is not None
and hasattr(input_accessors[0], "strided_dim")
and input_accessors[0].strided_dim is not None
):
if _reshaped_or_strided_input_or_output_accessor(cur_op):
continue
curr_op_output = cur_op._attrs["outputs"][0]
dst_ops = curr_op_output._attrs["dst_ops"]
Expand All @@ -89,6 +111,8 @@ def eliminate_permutations(
for next_op in dst_ops:
if not next_op._attrs["op"].startswith("permute"):
continue
if _reshaped_or_strided_input_or_output_accessor(next_op):
continue
p1 = get_permutation(cur_op)
p2 = get_permutation(next_op)
if not np.all(np.array(p1)[p2] == np.arange(0, len(p1))):
Expand Down
121 changes: 120 additions & 1 deletion tests/unittest/compiler/test_eliminate_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)


@unittest.skip("Skip until we fix the accuracy issue")
class EliminatePermutationTestCase(unittest.TestCase):
def test_eliminate_permutation(self):
dtype = "float"
Expand Down Expand Up @@ -190,6 +189,126 @@ def test_eliminate_permutation_all_permutations(self):
self.assertEqual(len(result_graph), 3)
self.assertTrue(graph_has_op(result_graph, "permute"))

def test_do_not_eliminate_permutation_of_strided_input(self):
dtype = "float"
shape = [3, 2, 4]
new_shape = [3, 2 * 2]
target = detect_target()

x = Tensor(shape, name="x", dtype=dtype, is_input=True)
s1 = ops.dynamic_slice()(
x, start_indices=[0, 0, 2], end_indices=[2147483647, 2147483647, 4]
)
p1 = ops.permute()(s1, dims=[0, 2, 1])
p2 = ops.permute()(p1, dims=[0, 2, 1])
z = ops.reshape()(p2, new_shape)
z._attrs["is_output"] = True
z._attrs["name"] = "z"

with compile_model(
z, target, "./tmp", "test_do_not_eliminate_permutation_of_strided_input"
) as module:
# Verify the generated graph.
sorted_graph = module.debug_sorted_graph
self.assertEqual(len(sorted_graph), 4)
self.assertTrue(graph_has_op(sorted_graph, "permute021"))

x_pt = get_random_torch_tensor(shape, dtype)
z_pt = get_torch_empty_tensor(new_shape, dtype)

module.run_with_tensors({"x": x_pt}, {"z": z_pt})

self.assertTrue(
torch.equal(
torch.reshape(torch.split(x_pt, 2, dim=2)[1], new_shape), z_pt
)
)

def test_do_not_eliminate_permutation_of_strided_input2(self):
dtype = "float"
shape = [3, 4, 2]
new_shape = [3, 2 * 2]
target = detect_target()

x = Tensor(shape, name="x", dtype=dtype, is_input=True)
p1 = ops.permute()(x, dims=[0, 2, 1])
s1 = ops.dynamic_slice()(
p1, start_indices=[0, 0, 2], end_indices=[2147483647, 2147483647, 4]
)
p2 = ops.permute()(s1, dims=[0, 2, 1])
z = ops.reshape()(p2, new_shape)
z._attrs["is_output"] = True
z._attrs["name"] = "z"

with compile_model(
z, target, "./tmp", "test_do_not_eliminate_permutation_of_strided_input2"
) as module:
# Verify the generated graph.
sorted_graph = module.debug_sorted_graph
self.assertEqual(len(sorted_graph), 4)
self.assertTrue(graph_has_op(sorted_graph, "permute021"))

x_pt = get_random_torch_tensor(shape, dtype)
z_pt = get_torch_empty_tensor(new_shape, dtype)

module.run_with_tensors({"x": x_pt}, {"z": z_pt})

self.assertTrue(
torch.equal(
torch.reshape(
torch.permute(
torch.split(torch.permute(x_pt, (0, 2, 1)), 2, dim=2)[1],
(0, 2, 1),
),
new_shape,
),
z_pt,
)
)

def test_do_not_eliminate_permutation_of_reshaped_input(self):
dtype = "float"
shape = [3, 2, 4]
new_shape = [3, 2, 4]
target = detect_target()

x = Tensor(shape, name="x", dtype=dtype, is_input=True)
p1 = ops.permute()(x, dims=[0, 2, 1])
r1 = ops.reshape()(p1, new_shape)
p2 = ops.permute()(r1, dims=[0, 2, 1])
z = ops.dynamic_slice()(
p2, start_indices=[0, 0, 1], end_indices=[2147483647, 2147483647, 2]
)
z._attrs["is_output"] = True
z._attrs["name"] = "z"

with compile_model(
z, target, "./tmp", "test_do_not_eliminate_permutation_of_reshaped_input"
) as module:
# Verify the generated graph.
sorted_graph = module.debug_sorted_graph
self.assertEqual(len(sorted_graph), 4)
self.assertTrue(graph_has_op(sorted_graph, "permute021"))

x_pt = get_random_torch_tensor(shape, dtype)
z_pt = get_torch_empty_tensor([3, 4, 1], dtype)

module.run_with_tensors({"x": x_pt}, {"z": z_pt})

self.assertTrue(
torch.equal(
torch.split(
torch.permute(
torch.reshape(torch.permute(x_pt, (0, 2, 1)), new_shape),
(0, 2, 1),
),
1,
dim=2,
)[1],
z_pt,
)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ddc78c4

Please sign in to comment.