From 64b1573627472c50d3dca659a5c52e83b0211733 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 22 Jun 2023 14:03:19 -0700 Subject: [PATCH] better implementation of pruning --- .../transform/transform_merge_view_ops.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/aitemplate/compiler/transform/transform_merge_view_ops.py b/python/aitemplate/compiler/transform/transform_merge_view_ops.py index f2ae55dd6..8e0b988be 100644 --- a/python/aitemplate/compiler/transform/transform_merge_view_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_view_ops.py @@ -30,7 +30,7 @@ def _is_inout(t: Tensor): return t._attrs["is_input"] or t._attrs["is_output"] -def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tensor]: +def _merge_view_ops_for(graph: List[Tensor], tensor: Tensor) -> List[Tensor]: """ `tensor` should have exactly 1 src op, and that op must be a view op. We will look for view ops in the dst ops and merge them with the src view op @@ -58,11 +58,11 @@ def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tens new_out_tensor._attrs["is_output"] = True new_out_tensor._attrs["name"] = out_tensor._attrs["name"] transform_utils.replace_tensor(out_tensor, new_out_tensor) - sorted_graph.append(new_out_tensor) - sorted_graph.remove(out_tensor) + graph.append(new_out_tensor) + graph.remove(out_tensor) removed_ops.add(op) dst_ops -= removed_ops - return sorted_graph + return graph def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]: @@ -84,6 +84,7 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens continue dst_ops = tensor._attrs["dst_ops"] if any(op._attrs["op"] in _VIEW_OPS for op in dst_ops): + # NOTE: _merge_view_ops_for does *not* return a sorted graph sorted_graph = _merge_view_ops_for(sorted_graph, tensor) changed = True break @@ -92,8 +93,6 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens if changed: # Prune tensors that may have become unused after view op merging - for t in sorted_graph: - if len(t._attrs["dst_ops"]) == 0 and not t._attrs["is_output"]: - transform_utils.remove_tensor_from_sorted_graph(t) - return transform_utils.sanitize_sorted_graph(toposort(sorted_graph)) + sorted_graph = toposort([t for t in sorted_graph if t._attrs["is_output"]]) + return transform_utils.sanitize_sorted_graph(sorted_graph) return sorted_graph