diff --git a/python/aitemplate/compiler/transform/transform_merge_view_ops.py b/python/aitemplate/compiler/transform/transform_merge_view_ops.py index 122fe7813..f2ae55dd6 100644 --- a/python/aitemplate/compiler/transform/transform_merge_view_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_view_ops.py @@ -91,5 +91,9 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens break if changed: + # Prune tensors that may have become unused after view op merging + for t in sorted_graph: + if len(t._attrs["dst_ops"]) == 0 and not t._attrs["is_output"]: + transform_utils.remove_tensor_from_sorted_graph(t) return transform_utils.sanitize_sorted_graph(toposort(sorted_graph)) return sorted_graph