Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and re-enable eliminate_permutations graph transformation #824

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/aitemplate/compiler/transform/optimize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from aitemplate.compiler.transform.transform_odd_alignment import (
transform_odd_alignment,
)
from aitemplate.compiler.transform.transform_permutations import eliminate_permutations
from aitemplate.compiler.transform.transform_permute_to_reshape import (
transform_permute_to_reshape,
)
Expand Down Expand Up @@ -125,8 +126,7 @@ def optimize_graph(
split_large_split_ops,
transform_permute_to_reshape,
transform_memory_ops,
# FIXME: temporarily disable this due to some accuracy issue
# eliminate_permutations,
eliminate_permutations,
]

if not optimize:
Expand Down
36 changes: 30 additions & 6 deletions python/aitemplate/compiler/transform/transform_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.tensor_accessor import TensorAccessor
from aitemplate.compiler.transform import transform_utils


Expand Down Expand Up @@ -60,6 +61,32 @@ def remove_second_permutation_from_graph(
transform_utils.remove_tensor_from_sorted_graph(output_tensor)


def _reshaped_or_strided_input_or_output_accessor(op: "Operator") -> bool:
    """Tell whether any tensor accessor attached to ``op`` is reshaped or strided.

    ``eliminate_permutations`` uses this as a guard: a permute op for which
    this returns True is skipped rather than folded away.

    Args:
        op: operator whose ``_attrs`` may carry ``"input_accessors"`` and/or
            ``"output_accessors"`` lists of tensor accessors.

    Returns:
        True if at least one input or output accessor either has
        ``actual_shapes`` set to something different from its
        ``original_shapes`` or has a non-None ``stride_dim``; False otherwise,
        including when the accessor lists are missing, None or empty.
    """

    def _reshaped_or_strided_tensor_accessor(accessor: "TensorAccessor") -> bool:
        # Reshaped: the accessor records a shape that differs from the
        # original one.
        actual_shapes = accessor.actual_shapes
        if actual_shapes is not None and actual_shapes != accessor.original_shapes:
            return True

        # Strided: stride_dim is set. getattr with a default tolerates
        # accessor objects that do not define the attribute at all.
        return getattr(accessor, "stride_dim", None) is not None

    for attr_name in ("input_accessors", "output_accessors"):
        # Missing key, None and an empty list all mean "nothing to inspect";
        # this avoids indexing into an empty list.
        accessors = op._attrs.get(attr_name) or ()
        # Inspect every accessor, not just the first: a single reshaped or
        # strided one is enough to make the op unsafe to eliminate.
        if any(_reshaped_or_strided_tensor_accessor(acc) for acc in accessors):
            return True

    return False


def eliminate_permutations(
sorted_graph: List[Tensor], workdir: str = None
) -> List[Tensor]:
Expand All @@ -73,12 +100,7 @@ def eliminate_permutations(
continue
if not cur_op._attrs["op"].startswith("permute"):
continue
input_accessors = cur_op._attrs.get("input_accessors", None)
if (
input_accessors is not None
and hasattr(input_accessors[0], "strided_dim")
and input_accessors[0].strided_dim is not None
):
if _reshaped_or_strided_input_or_output_accessor(cur_op):
continue
curr_op_output = cur_op._attrs["outputs"][0]
dst_ops = curr_op_output._attrs["dst_ops"]
Expand All @@ -89,6 +111,8 @@ def eliminate_permutations(
for next_op in dst_ops:
if not next_op._attrs["op"].startswith("permute"):
continue
if _reshaped_or_strided_input_or_output_accessor(next_op):
continue
p1 = get_permutation(cur_op)
p2 = get_permutation(next_op)
if not np.all(np.array(p1)[p2] == np.arange(0, len(p1))):
Expand Down
121 changes: 120 additions & 1 deletion tests/unittest/compiler/test_eliminate_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)


@unittest.skip("Skip until we fix the accuracy issue")
class EliminatePermutationTestCase(unittest.TestCase):
def test_eliminate_permutation(self):
dtype = "float"
Expand Down Expand Up @@ -190,6 +189,126 @@ def test_eliminate_permutation_all_permutations(self):
self.assertEqual(len(result_graph), 3)
self.assertTrue(graph_has_op(result_graph, "permute"))

def test_do_not_eliminate_permutation_of_strided_input(self):
    """slice -> permute -> permute -> reshape must keep a permute op.

    Builds a graph where the first of two mutually inverse ``[0, 2, 1]``
    permutes consumes a ``dynamic_slice`` of the input, compiles it, and
    checks that the compiled graph still contains a ``permute021`` op and
    that the result matches the PyTorch reference (the second half of the
    last dimension, reshaped).
    """
    case_name = "test_do_not_eliminate_permutation_of_strided_input"
    dtype = "float"
    in_shape = [3, 2, 4]
    out_shape = [3, 2 * 2]
    target = detect_target()

    # Graph: x -> slice last dim [2:4] -> permute021 -> permute021 -> reshape.
    x = Tensor(in_shape, name="x", dtype=dtype, is_input=True)
    sliced = ops.dynamic_slice()(
        x, start_indices=[0, 0, 2], end_indices=[2147483647, 2147483647, 4]
    )
    first_permute = ops.permute()(sliced, dims=[0, 2, 1])
    second_permute = ops.permute()(first_permute, dims=[0, 2, 1])
    z = ops.reshape()(second_permute, out_shape)
    z._attrs["is_output"] = True
    z._attrs["name"] = "z"

    with compile_model(z, target, "./tmp", case_name) as module:
        # Verify the generated graph.
        compiled_graph = module.debug_sorted_graph
        self.assertEqual(len(compiled_graph), 4)
        self.assertTrue(graph_has_op(compiled_graph, "permute021"))

        x_pt = get_random_torch_tensor(in_shape, dtype)
        z_pt = get_torch_empty_tensor(out_shape, dtype)

        module.run_with_tensors({"x": x_pt}, {"z": z_pt})

        # The two permutes cancel out numerically, so the reference is just
        # the sliced half of x reshaped to the output shape.
        second_half = torch.split(x_pt, 2, dim=2)[1]
        expected = torch.reshape(second_half, out_shape)
        self.assertTrue(torch.equal(expected, z_pt))

def test_do_not_eliminate_permutation_of_strided_input2(self):
    """permute -> slice -> permute -> reshape must keep a permute op.

    Here the ``dynamic_slice`` sits between the two ``[0, 2, 1]`` permutes,
    so they are not directly chained. The test compiles the graph, checks
    that a ``permute021`` op survives, and compares the output against the
    same op sequence evaluated with PyTorch.
    """
    case_name = "test_do_not_eliminate_permutation_of_strided_input2"
    dtype = "float"
    in_shape = [3, 4, 2]
    out_shape = [3, 2 * 2]
    target = detect_target()

    # Graph: x -> permute021 -> slice last dim [2:4] -> permute021 -> reshape.
    x = Tensor(in_shape, name="x", dtype=dtype, is_input=True)
    first_permute = ops.permute()(x, dims=[0, 2, 1])
    sliced = ops.dynamic_slice()(
        first_permute,
        start_indices=[0, 0, 2],
        end_indices=[2147483647, 2147483647, 4],
    )
    second_permute = ops.permute()(sliced, dims=[0, 2, 1])
    z = ops.reshape()(second_permute, out_shape)
    z._attrs["is_output"] = True
    z._attrs["name"] = "z"

    with compile_model(z, target, "./tmp", case_name) as module:
        # Verify the generated graph.
        compiled_graph = module.debug_sorted_graph
        self.assertEqual(len(compiled_graph), 4)
        self.assertTrue(graph_has_op(compiled_graph, "permute021"))

        x_pt = get_random_torch_tensor(in_shape, dtype)
        z_pt = get_torch_empty_tensor(out_shape, dtype)

        module.run_with_tensors({"x": x_pt}, {"z": z_pt})

        # Reference: replay the same four steps with PyTorch.
        permuted = torch.permute(x_pt, (0, 2, 1))
        second_half = torch.split(permuted, 2, dim=2)[1]
        permuted_back = torch.permute(second_half, (0, 2, 1))
        expected = torch.reshape(permuted_back, out_shape)
        self.assertTrue(torch.equal(expected, z_pt))

def test_do_not_eliminate_permutation_of_reshaped_input(self):
    """permute -> reshape -> permute -> slice must keep a permute op.

    A ``reshape`` separates the two ``[0, 2, 1]`` permutes, and the result
    is sliced to a single column of the last dimension. The test compiles
    the graph, checks that a ``permute021`` op survives, and compares the
    output against the same op sequence evaluated with PyTorch.
    """
    case_name = "test_do_not_eliminate_permutation_of_reshaped_input"
    dtype = "float"
    in_shape = [3, 2, 4]
    reshaped_shape = [3, 2, 4]
    target = detect_target()

    # Graph: x -> permute021 -> reshape -> permute021 -> slice last dim [1:2].
    x = Tensor(in_shape, name="x", dtype=dtype, is_input=True)
    first_permute = ops.permute()(x, dims=[0, 2, 1])
    reshaped = ops.reshape()(first_permute, reshaped_shape)
    second_permute = ops.permute()(reshaped, dims=[0, 2, 1])
    z = ops.dynamic_slice()(
        second_permute,
        start_indices=[0, 0, 1],
        end_indices=[2147483647, 2147483647, 2],
    )
    z._attrs["is_output"] = True
    z._attrs["name"] = "z"

    with compile_model(z, target, "./tmp", case_name) as module:
        # Verify the generated graph.
        compiled_graph = module.debug_sorted_graph
        self.assertEqual(len(compiled_graph), 4)
        self.assertTrue(graph_has_op(compiled_graph, "permute021"))

        x_pt = get_random_torch_tensor(in_shape, dtype)
        z_pt = get_torch_empty_tensor([3, 4, 1], dtype)

        module.run_with_tensors({"x": x_pt}, {"z": z_pt})

        # Reference: replay the same four steps with PyTorch; index 1 of the
        # size-1 split along dim 2 is the [1:2] slice.
        permuted = torch.permute(x_pt, (0, 2, 1))
        reshaped_pt = torch.reshape(permuted, reshaped_shape)
        permuted_back = torch.permute(reshaped_pt, (0, 2, 1))
        expected = torch.split(permuted_back, 1, dim=2)[1]
        self.assertTrue(torch.equal(expected, z_pt))


if __name__ == "__main__":
unittest.main()