diff --git a/python/aitemplate/compiler/transform/transform_merge_view_ops.py b/python/aitemplate/compiler/transform/transform_merge_view_ops.py index f2ae55dd6..488d7658e 100644 --- a/python/aitemplate/compiler/transform/transform_merge_view_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_view_ops.py @@ -30,7 +30,7 @@ def _is_inout(t: Tensor): return t._attrs["is_input"] or t._attrs["is_output"] -def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tensor]: +def _merge_view_ops_for(graph: List[Tensor], tensor: Tensor) -> List[Tensor]: """ `tensor` should have exactly 1 src op, and that op must be a view op. We will look for view ops in the dst ops and merge them with the src view op @@ -58,11 +58,12 @@ def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tens new_out_tensor._attrs["is_output"] = True new_out_tensor._attrs["name"] = out_tensor._attrs["name"] transform_utils.replace_tensor(out_tensor, new_out_tensor) - sorted_graph.append(new_out_tensor) - sorted_graph.remove(out_tensor) + graph.append(new_out_tensor) + graph.remove(out_tensor) removed_ops.add(op) - dst_ops -= removed_ops - return sorted_graph + for op in removed_ops: + transform_utils.remove_view_op_from_sorted_graph(op) + return graph def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]: @@ -84,6 +85,7 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens continue dst_ops = tensor._attrs["dst_ops"] if any(op._attrs["op"] in _VIEW_OPS for op in dst_ops): + # NOTE: _merge_view_ops_for does *not* return a sorted graph sorted_graph = _merge_view_ops_for(sorted_graph, tensor) changed = True break @@ -91,9 +93,5 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens break if changed: - # Prune tensors that may have become unused after view op merging - for t in sorted_graph: - if len(t._attrs["dst_ops"]) == 0 and not t._attrs["is_output"]: - transform_utils.remove_tensor_from_sorted_graph(t) return transform_utils.sanitize_sorted_graph(toposort(sorted_graph)) return sorted_graph