Skip to content

Commit

Permalink
Refactor SimpleDisjointSet in fuse_elementwise to use sets (#830)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #830

The `SimpleDisjointSet` used in the `fuse_elementwise` transformation pass is using lists to collect groups of the disjoint nodes. In certain cases, this can lead to a heavy duplication in the node groups resulting in the `fuse_elementwise` pass taking much longer time than it should.

This diff refactors the `SimpleDisjointSet` to use sets instead of lists to alleviate the node duplication and the resulting long pass execution.

Reviewed By: sgrigory

Differential Revision: D47466832

fbshipit-source-id: 13044f18f72763a4de02f769782300b690d7faf3
  • Loading branch information
aakhundov authored and facebook-github-bot committed Jul 14, 2023
1 parent 17e20b4 commit 2cecba2
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions python/aitemplate/compiler/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import collections
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Optional, Set

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.ops.common import elementwise, fused_elementwise
Expand All @@ -38,39 +38,40 @@

class SimpleDisjointSet:
def __init__(self):
self.node_to_list_mapping: Dict[Any, List[Any]] = {}
self.node_to_set_mapping: Dict[Any, Set[Any]] = {}

def add(self, node: Any, dependent_nodes: Set[Any]) -> None:
if node in self.node_to_list_mapping:
def add(self, node: Any, dependent_nodes: Optional[Set[Any]]) -> None:
if node in self.node_to_set_mapping:
return

if dependent_nodes is None or len(dependent_nodes) == 0:
self.node_to_list_mapping[node] = [node]
self.node_to_set_mapping[node] = {node}
return

current_list = [
node # node should also be considered to decide if a new_list can be added.
]
for dependent in list(dependent_nodes):
if dependent is None or dependent not in self.node_to_list_mapping:
current_set = {
node # node should also be considered to decide if a new_set can be added.
}
for dependent in dependent_nodes:
if dependent is None or dependent not in self.node_to_set_mapping:
continue
new_list = self.node_to_list_mapping.get(dependent)
new_set = self.node_to_set_mapping.get(dependent)

if _detect_cycle(current_list + new_list):
if _detect_cycle(current_set | new_set):
continue
current_list.extend(new_list)
for new_node in new_list:
self.node_to_list_mapping[new_node] = current_list
self.node_to_list_mapping[node] = current_list

def get_node_groups(self) -> List[List[Any]]:
current_set.update(new_set)
for new_node in new_set:
self.node_to_set_mapping[new_node] = current_set
self.node_to_set_mapping[node] = current_set

def get_node_groups(self) -> List[Set[Any]]:
node_groups = []
visited = set()
for groups in self.node_to_list_mapping.values():
addr = id(groups)
for group in self.node_to_set_mapping.values():
addr = id(group)
if addr not in visited:
visited.add(addr)
node_groups.append(groups)
node_groups.append(group)
return node_groups


Expand Down Expand Up @@ -146,7 +147,7 @@ class FusedElementwiseInfo:
external_outputs: Set[Tensor]


def _partition_subgraphs(ops: List[Operator]) -> Dict[str, Set[Operator]]:
def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]:
"""
Given ops of candidate graph of fused_elementwise op graph and partition
into subgraph based on output shape, returns dict of
Expand Down Expand Up @@ -283,7 +284,7 @@ def _create_fuse_ops(info_list: List[FusedElementwiseInfo]) -> None:
)


def _detect_cycle(group: List[Operator]) -> bool:
def _detect_cycle(group: Set[Operator]) -> bool:
"""
Given a group of ops, to detect if they would form cycles, i.e.
--> group_ops
Expand All @@ -294,7 +295,7 @@ def _detect_cycle(group: List[Operator]) -> bool:
"""
parents = [o for op1 in group for i in op1._attrs["inputs"] for o in i.src_ops()]
for op1 in group:
for op2 in set(parents) - set(group):
for op2 in set(parents) - group:
if transform_utils.is_ancestor(op1, op2):
return True
return False
Expand Down Expand Up @@ -322,7 +323,7 @@ def fuse_elementwise(sorted_graph: List[Tensor], workdir: str = None) -> List[Te
# Partition subgraph based on output shape.
output_op_map = _partition_subgraphs(ops)
# Collect information to create fuse ops.
info_list = _collect_info(output_op_map, set(ops), sorted_graph)
info_list = _collect_info(output_op_map, ops, sorted_graph)
# Create fuse ops.
_create_fuse_ops(info_list)

Expand Down

0 comments on commit 2cecba2

Please sign in to comment.