Skip to content

Commit

Permalink
better implementation of pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
int3 committed Jun 22, 2023
1 parent 55539c3 commit 64b1573
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions python/aitemplate/compiler/transform/transform_merge_view_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _is_inout(t: Tensor):
return t._attrs["is_input"] or t._attrs["is_output"]


def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tensor]:
def _merge_view_ops_for(graph: List[Tensor], tensor: Tensor) -> List[Tensor]:
"""
`tensor` should have exactly 1 src op, and that op must be a view op. We
will look for view ops in the dst ops and merge them with the src view op
Expand Down Expand Up @@ -58,11 +58,11 @@ def _merge_view_ops_for(sorted_graph: List[Tensor], tensor: Tensor) -> List[Tens
new_out_tensor._attrs["is_output"] = True
new_out_tensor._attrs["name"] = out_tensor._attrs["name"]
transform_utils.replace_tensor(out_tensor, new_out_tensor)
sorted_graph.append(new_out_tensor)
sorted_graph.remove(out_tensor)
graph.append(new_out_tensor)
graph.remove(out_tensor)
removed_ops.add(op)
dst_ops -= removed_ops
return sorted_graph
return graph


def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tensor]:
Expand All @@ -84,6 +84,7 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens
continue
dst_ops = tensor._attrs["dst_ops"]
if any(op._attrs["op"] in _VIEW_OPS for op in dst_ops):
# NOTE: _merge_view_ops_for does *not* return a sorted graph
sorted_graph = _merge_view_ops_for(sorted_graph, tensor)
changed = True
break
Expand All @@ -92,8 +93,6 @@ def merge_view_ops(sorted_graph: List[Tensor], workdir: str = None) -> List[Tens

if changed:
# Prune tensors that may have become unused after view op merging
for t in sorted_graph:
if len(t._attrs["dst_ops"]) == 0 and not t._attrs["is_output"]:
transform_utils.remove_tensor_from_sorted_graph(t)
return transform_utils.sanitize_sorted_graph(toposort(sorted_graph))
sorted_graph = toposort([t for t in sorted_graph if t._attrs["is_output"]])
return transform_utils.sanitize_sorted_graph(sorted_graph)
return sorted_graph

0 comments on commit 64b1573

Please sign in to comment.