From 6983d9ee23e8e8b6c735056a990695c1f13ab7ac Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Tue, 7 May 2024 22:38:42 -0700 Subject: [PATCH] Range operator lowering minimizer Summary: During our perf tuning, we are seeing mtml_instagram_model lag behind AIT by around 23% on mergenet (525662456/1670). How do we compare the perf between AIT and AOTI? Referring to an idea from Oleg, I tried to implement a minimizer leveraging acc tracer + fx splitter. It is good that these two backends can still use the same acc tracer, so we can make sure the front end graphs are exactly the same. Then we can leverage the splitter to define a range of operators to lower. In this way, we can directly compare the perf of a small subgraph between these two backends. How to use it? (1) attach --debug_operator_range="0, 100" randomly and you will see graph nodes printed out in log P1233104423 (2) attach the precise range you want. For example, --debug_operator_range="1613,1635" and you will see logs like this ``` INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_122 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_124 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_125 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_126 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_127 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_128 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_129 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_130 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_131 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_132 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, 
n.target=, n.name=linear_133 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_134 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_135 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_136 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_137 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_138 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_139 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_140 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_141 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_142 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_143 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_144 to selected nodes INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=, n.name=linear_145 to selected nodes ``` Reviewed By: hl475 Differential Revision: D57030900 fbshipit-source-id: b55ac5b908d9752ee7448da51f32130accd19839 --- fx2ait/fx2ait/ait_splitter.py | 71 ++++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py index 17add596f..9679b75fc 100644 --- a/fx2ait/fx2ait/ait_splitter.py +++ b/fx2ait/fx2ait/ait_splitter.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Any, Dict, Iterable, Mapping, Sequence +import logging +from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Set import torch import torch.fx.passes.operator_support as ops @@ -23,8 +24,9 @@ from fx2ait.converters.converter_registry import AIT_CONVERTERS from fx2ait.fx2ait import AITInterpreter from torch.fx.passes.operator_support import create_op_support, OperatorSupportBase -from torch.fx.passes.tools_common import get_acc_ops_name +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, get_acc_ops_name +logger: logging.Logger = logging.getLogger(__name__) _VIEW_OPS = frozenset( ( @@ -109,12 +111,53 @@ def create_ait_operator_support( class AITSplitterSettings(splitter_base._SplitterSettingBase): # TODO: Fix this once pytorch nightly is updated def __init__( - self, min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, allow_int_inputs=False + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + allow_int_inputs=False, + debug_operator_range=None, ): super().__init__() self.min_acc_module_size = min_acc_module_size self.exclude_support_node_name: set = set() self.allow_int_inputs: bool = allow_int_inputs + self.debug_operator_range = debug_operator_range + + +class SelectedOperatorSupport(ops.OperatorSupportBase): + def __init__(self, selected_nodes: Set[torch.fx.Node]) -> None: + self.selected_nodes = selected_nodes + + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return node in self.selected_nodes + + +def _range_operator_support( + module: torch.fx.GraphModule, start: int, end: int +) -> ops.OperatorSupportBase: + if start > end: + raise ValueError(f"Start {start} is greater or equal to end {end}") + if end >= len(module.graph.nodes): + raise ValueError( + f"End {end} is greater than number of nodes in the graph {len(module.graph.nodes)}" + ) + logger.info("===Enter into debug mode: set range operator for lowering===") + for i, n in 
enumerate(module.graph.nodes): + logger.info(f"Index:{i}, n.op={n.op}, n.target={n.target}, n.name={n.name}") + + selected_nodes: Set[torch.fx.Node] = set() + for i, n in enumerate(module.graph.nodes): + if i >= start and i <= end: + if n.op in CALLABLE_NODE_OPS: + selected_nodes.add(n) + logger.info( + f"Add n.op={n.op}, n.target={n.target}, n.name={n.name} to selected nodes" + ) + + if len(selected_nodes) == 0: + raise ValueError("No nodes are selected") + return SelectedOperatorSupport(selected_nodes) class AITSplitter(splitter_base._SplitterBase): @@ -122,16 +165,26 @@ def __init__( self, module: torch.fx.GraphModule, sample_input: Sequence[Any], - operator_support: ops.OperatorSupportBase = None, - settings: AITSplitterSettings = None, + operator_support: Optional[ops.OperatorSupportBase] = None, + settings: Optional[AITSplitterSettings] = None, ): if not settings: settings = AITSplitterSettings() if not operator_support: - operator_support = create_ait_operator_support( - op_lowering_disallow_list=settings.exclude_support_node_name, - allow_int_inputs=settings.allow_int_inputs, - ) + if settings.debug_operator_range: + min_range, max_range = tuple( + int(x) for x in settings.debug_operator_range.split(",") + ) + operator_support = _range_operator_support( + module=module, + start=min_range, + end=max_range, + ) + else: + operator_support = create_ait_operator_support( + op_lowering_disallow_list=settings.exclude_support_node_name, + allow_int_inputs=settings.allow_int_inputs, + ) else: operator_support = ops.chain( operator_support,