Add pass to merge consecutive view ops #768

Closed
wants to merge 1 commit
2 changes: 2 additions & 0 deletions python/aitemplate/compiler/transform/optimize_graph.py
@@ -48,6 +48,7 @@
)
from aitemplate.compiler.transform.split_large_split_ops import split_large_split_ops
from aitemplate.compiler.transform.transform_memory_ops import transform_memory_ops
from aitemplate.compiler.transform.transform_merge_view_ops import merge_view_ops
from aitemplate.compiler.transform.transform_odd_alignment import (
    transform_odd_alignment,
)
@@ -105,6 +106,7 @@ def optimize_graph(
        fuse_mm_reshape_permute,
        # make sure we run move_view_op_before_concat before transform_memory_ops
        move_view_op_before_concat,
        merge_view_ops,
        transform_memory_ops,
        fuse_ops,
        fuse_elementwise,
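For orientation, a minimal sketch of how a pass list like this is consumed: each entry maps a sorted graph to a sorted graph and is applied in order, so placing merge_view_ops just before transform_memory_ops lets the memory pass see one collapsed reshape instead of a chain of view ops. The driver below is illustrative only (run_passes is a hypothetical name, not the real optimize_graph loop):

from typing import Callable, List

GraphPass = Callable[[list], list]

def run_passes(sorted_graph: list, passes: List[GraphPass]) -> list:
    # Apply each optimization pass in the listed order.
    for apply_pass in passes:
        sorted_graph = apply_pass(sorted_graph)
    return sorted_graph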
101 changes: 101 additions & 0 deletions python/aitemplate/compiler/transform/transform_merge_view_ops.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This file implements a pass that merges consecutive view ops if possible.
"""
from typing import List, Set

from aitemplate.compiler import ops
from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.transform import transform_utils
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.utils.shape_utils import convert_shape_to_IntVarTensor


_VIEW_OPS = {"reshape", "flatten", "squeeze", "unsqueeze"}


def _is_inout(t: Tensor):
    return t._attrs["is_input"] or t._attrs["is_output"]


def _merge_view_ops_for(graph: List[Tensor], tensor: Tensor) -> List[Tensor]:
    """
    `tensor` should have exactly 1 src op, and that op must be a view op. We
    will look for view ops in the dst ops and merge them with the src view op
    by creating a new reshape op.
    """
    src_op = tensor._attrs["src_ops"][0]
    in_tensor = src_op._attrs["inputs"][0]
    dst_ops = tensor._attrs["dst_ops"]
    removed_ops: Set[Operator] = set()
    for op in dst_ops:
        if op._attrs["op"] not in _VIEW_OPS:
            continue
        out_tensor = op._attrs["outputs"][0]
        in_shape = in_tensor._attrs["shape"]
        out_shape = out_tensor._attrs["shape"]
        if out_shape == in_shape and not (
            _is_inout(in_tensor) and _is_inout(out_tensor)
        ):
            # If the shapes are identical, we can eliminate both view ops
            transform_utils.replace_tensor(out_tensor, in_tensor)
        else:
            # Otherwise, create a new reshape op to replace the two view ops
            out_shape = convert_shape_to_IntVarTensor(out_tensor)
            new_out_tensor = ops.reshape()(in_tensor, out_shape)
            if out_tensor._attrs["is_output"]:
                new_out_tensor._attrs["is_output"] = True
                new_out_tensor._attrs["name"] = out_tensor._attrs["name"]
            transform_utils.replace_tensor(out_tensor, new_out_tensor)
            graph.append(new_out_tensor)
        graph.remove(out_tensor)
        removed_ops.add(op)
    for op in removed_ops:
        transform_utils.remove_view_op_from_sorted_graph(op)
    return graph


def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
    """
    Merge consecutive view ops.
    """
    changed = False
    # Find pairs of consecutive view ops and merge them, iterating to a
    # fixpoint.
    # TODO: Instead of merging pairs of view ops, we should look for entire
    # chains of view ops and merge them all at once.
    while True:
        for tensor in sorted_graph:
            src_ops = tensor._attrs["src_ops"]
            if len(src_ops) != 1:
                continue
            src_op = list(src_ops)[0]
            if src_op._attrs["op"] not in _VIEW_OPS:
                continue
            dst_ops = tensor._attrs["dst_ops"]
            if any(op._attrs["op"] in _VIEW_OPS for op in dst_ops):
                # NOTE: _merge_view_ops_for does *not* return a sorted graph
                sorted_graph = _merge_view_ops_for(sorted_graph, tensor)
                changed = True
                break
        else:
            break

    if changed:
        # Prune tensors that may have become unused after view op merging
        sorted_graph = toposort([t for t in sorted_graph if t._attrs["is_output"]])
        return transform_utils.sanitize_sorted_graph(toposort(sorted_graph))
    return sorted_graph
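To make the intent concrete, here is a minimal standalone sketch (plain Python, not AITemplate APIs; Node, view, and merge_consecutive_views are hypothetical names) of the transformation this pass performs: a chain of view ops collapses into a single reshape of the original input, and vanishes entirely when the final shape matches the input's shape.

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Node:
    """Toy IR node: a tensor with a shape and, if it is a view, its producer."""

    name: str
    shape: Tuple[int, ...]
    view_src: Optional["Node"] = None


def view(src: Node, shape: Tuple[int, ...], name: str) -> Node:
    # reshape/flatten/squeeze/unsqueeze all reduce to "same data, new shape"
    return Node(name=name, shape=shape, view_src=src)


def merge_consecutive_views(out: Node) -> Node:
    # A view of a view collapses into a single view of the root tensor;
    # iterate to a fixpoint, mirroring the pair-wise merging of the pass.
    while out.view_src is not None and out.view_src.view_src is not None:
        out = Node(name=out.name, shape=out.shape, view_src=out.view_src.view_src)
    # If the final shape equals the root's shape, the view is a no-op: drop it.
    if out.view_src is not None and out.shape == out.view_src.shape:
        return out.view_src
    return out


x = Node("x", (2, 3, 4))
y = view(x, (6, 4), "flatten_out")    # flatten(x)
z = view(y, (2, 12), "reshape_out")   # reshape(flatten(x))
merged = merge_consecutive_views(z)
assert merged.view_src is x and merged.shape == (2, 12)

# A round trip back to the original shape is eliminated entirely.
w = view(view(x, (24,), "a"), (2, 3, 4), "b")
assert merge_consecutive_views(w) is x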
26 changes: 26 additions & 0 deletions python/aitemplate/utils/shape_utils.py
@@ -20,6 +20,8 @@

import sympy

from aitemplate.compiler.base import IntVar, IntVarTensor, Tensor


def gen_int_var(
    values: List[int], name: str = None, symbolic_value: Optional[sympy.Basic] = None
@@ -155,6 +157,30 @@ def convert_shape_to_IntVar(shape):
    return ret


def convert_shape_to_IntVarTensor(tensor: Tensor):
    """
    Map IntVars in the tensor's shape to their corresponding IntVarTensors, if any.
    """
    shape = tensor._attrs["shape"]
    if not any(isinstance(v, IntVar) for v in shape):
        return shape

    intvar_to_tensor = {}
    for op in tensor.src_ops():
        for t in op._attrs["inputs"]:
            if isinstance(t, IntVarTensor):
                intvar_to_tensor[t._attrs["int_var"]] = t

    ret = []
    for v in shape:
        # Using type() instead of isinstance() because we don't want to include IntImms
        if type(v) is IntVar:
            ret.append(intvar_to_tensor.get(v, v))
        else:
            ret.append(v)
    return ret
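A standalone sketch (toy classes, not the real aitemplate.compiler.base types) of why the lookup uses type(v) is IntVar: IntImm (a static dimension) subclasses IntVar, so an isinstance check would also swap out static dims, while the exact-type check only remaps truly dynamic dimensions to their IntVarTensor wrappers.

class IntVar:  # toy stand-in: dynamic dimension
    def __init__(self, name: str):
        self.name = name


class IntImm(IntVar):  # toy stand-in: static dimension, subclasses IntVar
    def __init__(self, value: int):
        super().__init__(str(value))
        self.value = value


class IntVarTensor:  # toy stand-in: tensor wrapper carrying an IntVar
    def __init__(self, int_var: IntVar):
        self.int_var = int_var


batch = IntVar("batch")
shape = [batch, IntImm(3), IntImm(4)]
# In the real pass this map is built from the src op's IntVarTensor inputs.
intvar_to_tensor = {batch: IntVarTensor(batch)}

resolved = [
    intvar_to_tensor.get(v, v) if type(v) is IntVar else v  # IntImms stay put
    for v in shape
]
assert isinstance(resolved[0], IntVarTensor)
assert isinstance(resolved[1], IntImm) and isinstance(resolved[2], IntImm)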


def convert_IntVar_to_int(var) -> int:
    """
    Try to convert an IntVar (or an IntVar wrapped in a IntVarTensor) to