Skip to content

Commit

Permalink
Move remove_id_ops into remove_no_ops (facebookincubator#873)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#873

### This diff
Relocates `remove_id_ops` pass in `optimize_graph.py` call to `remove_no_ops.py`.

Reviewed By: muchulee8

Differential Revision: D47937117

fbshipit-source-id: cd7e5d793d34b1acf4a8a632e3a34031d136756d
  • Loading branch information
ColinPeppler authored and facebook-github-bot committed Aug 7, 2023
1 parent e3a89be commit aa730f3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 44 deletions.
1 change: 0 additions & 1 deletion python/aitemplate/compiler/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from aitemplate.compiler.transform.optimize_graph import optimize_graph
from aitemplate.compiler.transform.profile import profile
from aitemplate.compiler.transform.refine_graph import refine_graph
from aitemplate.compiler.transform.remove_id_ops import remove_id_ops
from aitemplate.compiler.transform.remove_no_ops import remove_no_ops
from aitemplate.compiler.transform.remove_unused_ops import remove_unused_ops
from aitemplate.compiler.transform.split_large_concat_ops import split_large_concat_ops
Expand Down
2 changes: 0 additions & 2 deletions python/aitemplate/compiler/transform/optimize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from aitemplate.compiler.transform.remove_elementwise_no_ops import (
remove_elementwise_no_ops,
)
from aitemplate.compiler.transform.remove_id_ops import remove_id_ops
from aitemplate.compiler.transform.split_large_concat_ops import split_large_concat_ops
from aitemplate.compiler.transform.split_large_slice_scatter_ops import (
split_large_slice_scatter_ops,
Expand Down Expand Up @@ -95,7 +94,6 @@ def optimize_graph(
"""

funcs = [
remove_id_ops,
remove_elementwise_no_ops,
dedup_make_jagged_ops,
fuse_permute_bmm_and_gemm,
Expand Down
41 changes: 0 additions & 41 deletions python/aitemplate/compiler/transform/remove_id_ops.py

This file was deleted.

23 changes: 23 additions & 0 deletions python/aitemplate/compiler/transform/remove_no_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@
from aitemplate.utils.shape_utils import is_singleton_dimension


def _remove_id_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
"""Remove identity ops."""
ops = graph_utils.get_sorted_ops(sorted_graph)
for op in ops:
if op._attrs["op"] != "identity":
continue

inputs = op._attrs["inputs"]
assert len(inputs) == 1, "identity must only have 1 input"

outputs = op._attrs["outputs"]
identity_output = outputs[0]
assert len(inputs) == 1, "identity must only have 1 output"

# skip a very special case where id takes an input and produces an output
if identity_output._attrs["is_output"] and inputs[0]._attrs["is_input"]:
continue

transform_utils.remove_single_tensor_op_from_sorted_graph(op)
return transform_utils.sanitize_sorted_graph(sorted_graph)


def _remove_no_op_concats(sorted_graph: List[Tensor]) -> List[Tensor]:
"""
Remove no-op concats from the graph. A no-op concat is where the output
Expand Down Expand Up @@ -328,6 +350,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]:
Graph after remove no-ops
"""
passes = [
_remove_id_ops,
_remove_no_op_concats,
_remove_no_op_dynamic_slices,
_remove_no_op_splits,
Expand Down

0 comments on commit aa730f3

Please sign in to comment.