diff --git a/python/aitemplate/compiler/transform/__init__.py b/python/aitemplate/compiler/transform/__init__.py index 3ff2d800a..c195d4087 100644 --- a/python/aitemplate/compiler/transform/__init__.py +++ b/python/aitemplate/compiler/transform/__init__.py @@ -37,7 +37,6 @@ from aitemplate.compiler.transform.optimize_graph import optimize_graph from aitemplate.compiler.transform.profile import profile from aitemplate.compiler.transform.refine_graph import refine_graph -from aitemplate.compiler.transform.remove_id_ops import remove_id_ops from aitemplate.compiler.transform.remove_no_ops import remove_no_ops from aitemplate.compiler.transform.remove_unused_ops import remove_unused_ops from aitemplate.compiler.transform.split_large_concat_ops import split_large_concat_ops diff --git a/python/aitemplate/compiler/transform/optimize_graph.py b/python/aitemplate/compiler/transform/optimize_graph.py index 70d51b34d..edf0eede5 100644 --- a/python/aitemplate/compiler/transform/optimize_graph.py +++ b/python/aitemplate/compiler/transform/optimize_graph.py @@ -41,7 +41,6 @@ from aitemplate.compiler.transform.remove_elementwise_no_ops import ( remove_elementwise_no_ops, ) -from aitemplate.compiler.transform.remove_id_ops import remove_id_ops from aitemplate.compiler.transform.split_large_concat_ops import split_large_concat_ops from aitemplate.compiler.transform.split_large_slice_scatter_ops import ( split_large_slice_scatter_ops, @@ -95,7 +94,6 @@ def optimize_graph( """ funcs = [ - remove_id_ops, remove_elementwise_no_ops, dedup_make_jagged_ops, fuse_permute_bmm_and_gemm, diff --git a/python/aitemplate/compiler/transform/remove_id_ops.py b/python/aitemplate/compiler/transform/remove_id_ops.py deleted file mode 100644 index 6b9057e67..000000000 --- a/python/aitemplate/compiler/transform/remove_id_ops.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Remove id ops from a sorted_graph. -""" -from typing import List - -from aitemplate.compiler.base import Tensor -from aitemplate.compiler.transform import transform_utils - - -def remove_id_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]: - """Remove id ops from the input sorted_graph.""" - for tensor in sorted_graph: - src_ops = tensor._attrs["src_ops"] - if len(src_ops) != 1: - continue - src_op = list(src_ops)[0] - if src_op._attrs["op"] != "identity": - continue - id_op = src_op - input_tensor = id_op._attrs["inputs"][0] - # skip a very special case where id takes an input and produces an output - if tensor._attrs["is_output"] and input_tensor._attrs["is_input"]: - continue - transform_utils.remove_single_tensor_op_from_sorted_graph(id_op) - - sorted_graph = transform_utils.sanitize_sorted_graph(sorted_graph) - return transform_utils.sanitize_sorted_graph(sorted_graph) diff --git a/python/aitemplate/compiler/transform/remove_no_ops.py b/python/aitemplate/compiler/transform/remove_no_ops.py index fd5d325dd..4fe11e7bb 100644 --- a/python/aitemplate/compiler/transform/remove_no_ops.py +++ b/python/aitemplate/compiler/transform/remove_no_ops.py @@ -40,6 +40,28 @@ from aitemplate.utils.shape_utils import is_singleton_dimension +def _remove_id_ops(sorted_graph: List[Tensor]) -> List[Tensor]: + """Remove identity ops.""" + ops = graph_utils.get_sorted_ops(sorted_graph) + for op in ops: + if op._attrs["op"] != "identity": + continue + + inputs = op._attrs["inputs"] + assert len(inputs) == 1, "identity must only have 1 input" + + outputs = op._attrs["outputs"] + identity_output = outputs[0] + assert len(inputs) == 1, "identity must only have 1 output" + + # skip a very special case where id takes an input and produces an output + if identity_output._attrs["is_output"] and inputs[0]._attrs["is_input"]: + continue + + transform_utils.remove_single_tensor_op_from_sorted_graph(op) + return transform_utils.sanitize_sorted_graph(sorted_graph) + + def _remove_no_op_concats(sorted_graph: List[Tensor]) -> List[Tensor]: """ Remove no-op concats from the graph. A no-op concat is where the output @@ -328,6 +350,7 @@ def remove_no_ops(sorted_graph: List[Tensor]) -> List[Tensor]: Graph after remove no-ops """ passes = [ + _remove_id_ops, _remove_no_op_concats, _remove_no_op_dynamic_slices, _remove_no_op_splits,