diff --git a/python/aitemplate/compiler/transform/transform_merge_view_ops.py b/python/aitemplate/compiler/transform/transform_merge_view_ops.py index 8e0b988be..cfb426df6 100644 --- a/python/aitemplate/compiler/transform/transform_merge_view_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_view_ops.py @@ -61,7 +61,8 @@ def _merge_view_ops_for(graph: List[Tensor], tensor: Tensor) -> List[Tensor]: graph.append(new_out_tensor) graph.remove(out_tensor) removed_ops.add(op) - dst_ops -= removed_ops + for op in removed_ops: + transform_utils.remove_view_op_from_sorted_graph(op) return graph