From 2cecba26af875a26ff1573e4f6f6c4bc23400ccd Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Fri, 14 Jul 2023 11:47:47 -0700 Subject: [PATCH] Refactor SimpleDisjointSet in fuse_elementwise to use sets (#830) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/830 The `SimpleDisjointSet` used in the `fuse_elementwise` transformation pass uses lists to collect groups of disjoint nodes. In certain cases, this can lead to heavy duplication in the node groups, resulting in the `fuse_elementwise` pass taking much longer than it should. This diff refactors the `SimpleDisjointSet` to use sets instead of lists to alleviate the node duplication and the resulting long pass execution. Reviewed By: sgrigory Differential Revision: D47466832 fbshipit-source-id: 13044f18f72763a4de02f769782300b690d7faf3 --- .../aitemplate/compiler/transform/fuse_ops.py | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/python/aitemplate/compiler/transform/fuse_ops.py b/python/aitemplate/compiler/transform/fuse_ops.py index 40af68f35..60a496553 100644 --- a/python/aitemplate/compiler/transform/fuse_ops.py +++ b/python/aitemplate/compiler/transform/fuse_ops.py @@ -18,7 +18,7 @@ import collections import logging from dataclasses import dataclass -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Optional, Set from aitemplate.compiler.base import Operator, Tensor from aitemplate.compiler.ops.common import elementwise, fused_elementwise @@ -38,39 +38,40 @@ class SimpleDisjointSet: def __init__(self): - self.node_to_list_mapping: Dict[Any, List[Any]] = {} + self.node_to_set_mapping: Dict[Any, Set[Any]] = {} - def add(self, node: Any, dependent_nodes: Set[Any]) -> None: - if node in self.node_to_list_mapping: + def add(self, node: Any, dependent_nodes: Optional[Set[Any]]) -> None: + if node in self.node_to_set_mapping: return if dependent_nodes is None or len(dependent_nodes) == 0: - 
self.node_to_list_mapping[node] = [node] + self.node_to_set_mapping[node] = {node} return - current_list = [ - node # node should also be considered to decide if a new_list can be added. - ] - for dependent in list(dependent_nodes): - if dependent is None or dependent not in self.node_to_list_mapping: + current_set = { + node # node should also be considered to decide if a new_set can be added. + } + for dependent in dependent_nodes: + if dependent is None or dependent not in self.node_to_set_mapping: continue - new_list = self.node_to_list_mapping.get(dependent) + new_set = self.node_to_set_mapping.get(dependent) - if _detect_cycle(current_list + new_list): + if _detect_cycle(current_set | new_set): continue - current_list.extend(new_list) - for new_node in new_list: - self.node_to_list_mapping[new_node] = current_list - self.node_to_list_mapping[node] = current_list - def get_node_groups(self) -> List[List[Any]]: + current_set.update(new_set) + for new_node in new_set: + self.node_to_set_mapping[new_node] = current_set + self.node_to_set_mapping[node] = current_set + + def get_node_groups(self) -> List[Set[Any]]: node_groups = [] visited = set() - for groups in self.node_to_list_mapping.values(): - addr = id(groups) + for group in self.node_to_set_mapping.values(): + addr = id(group) if addr not in visited: visited.add(addr) - node_groups.append(groups) + node_groups.append(group) return node_groups @@ -146,7 +147,7 @@ class FusedElementwiseInfo: external_outputs: Set[Tensor] -def _partition_subgraphs(ops: List[Operator]) -> Dict[str, Set[Operator]]: +def _partition_subgraphs(ops: Set[Operator]) -> Dict[str, Set[Operator]]: """ Given ops of candidate graph of fused_elementwise op graph and partition into subgraph based on output shape, returns dict of @@ -283,7 +284,7 @@ def _create_fuse_ops(info_list: List[FusedElementwiseInfo]) -> None: ) -def _detect_cycle(group: List[Operator]) -> bool: +def _detect_cycle(group: Set[Operator]) -> bool: """ Given a group of 
ops, to detect if they would form cycles, i.e. --> group_ops @@ -294,7 +295,7 @@ def _detect_cycle(group: List[Operator]) -> bool: """ parents = [o for op1 in group for i in op1._attrs["inputs"] for o in i.src_ops()] for op1 in group: - for op2 in set(parents) - set(group): + for op2 in set(parents) - group: if transform_utils.is_ancestor(op1, op2): return True return False @@ -322,7 +323,7 @@ def fuse_elementwise(sorted_graph: List[Tensor], workdir: str = None) -> List[Te # Partition subgraph based on output shape. output_op_map = _partition_subgraphs(ops) # Collect information to create fuse ops. - info_list = _collect_info(output_op_map, set(ops), sorted_graph) + info_list = _collect_info(output_op_map, ops, sorted_graph) # Create fuse ops. _create_fuse_ops(info_list)