diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py
index 17add596f..9679b75fc 100644
--- a/fx2ait/fx2ait/ait_splitter.py
+++ b/fx2ait/fx2ait/ait_splitter.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Iterable, Mapping, Sequence
+import logging
+from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Set
 
 import torch
 import torch.fx.passes.operator_support as ops
@@ -23,8 +24,9 @@ from fx2ait.converters.converter_registry import AIT_CONVERTERS
 from fx2ait.fx2ait import AITInterpreter
 
 from torch.fx.passes.operator_support import create_op_support, OperatorSupportBase
-from torch.fx.passes.tools_common import get_acc_ops_name
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, get_acc_ops_name
 
+logger: logging.Logger = logging.getLogger(__name__)
 
 _VIEW_OPS = frozenset(
     (
@@ -109,12 +111,53 @@ def create_ait_operator_support(
 class AITSplitterSettings(splitter_base._SplitterSettingBase):
     # TODO: Fix this once pytorch nightly is updated
     def __init__(
-        self, min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, allow_int_inputs=False
+        self,
+        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
+        allow_int_inputs=False,
+        debug_operator_range=None,
     ):
         super().__init__()
         self.min_acc_module_size = min_acc_module_size
         self.exclude_support_node_name: set = set()
         self.allow_int_inputs: bool = allow_int_inputs
+        self.debug_operator_range = debug_operator_range
+
+
+class SelectedOperatorSupport(ops.OperatorSupportBase):
+    def __init__(self, selected_nodes: Set[torch.fx.Node]) -> None:
+        self.selected_nodes = selected_nodes
+
+    def is_node_supported(
+        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        return node in self.selected_nodes
+
+
+def _range_operator_support(
+    module: torch.fx.GraphModule, start: int, end: int
+) -> ops.OperatorSupportBase:
+    if start > end:
+        raise ValueError(f"Start {start} is greater than end {end}")
+    if end >= len(module.graph.nodes):
+        raise ValueError(
+            f"End {end} is greater than or equal to the number of nodes in the graph {len(module.graph.nodes)}"
+        )
+    logger.info("===Enter into debug mode: set range operator for lowering===")
+    for i, n in enumerate(module.graph.nodes):
+        logger.info(f"Index:{i}, n.op={n.op}, n.target={n.target}, n.name={n.name}")
+
+    selected_nodes: Set[torch.fx.Node] = set()
+    for i, n in enumerate(module.graph.nodes):
+        if i >= start and i <= end:
+            if n.op in CALLABLE_NODE_OPS:
+                selected_nodes.add(n)
+                logger.info(
+                    f"Add n.op={n.op}, n.target={n.target}, n.name={n.name} to selected nodes"
+                )
+
+    if len(selected_nodes) == 0:
+        raise ValueError("No nodes are selected")
+    return SelectedOperatorSupport(selected_nodes)
 
 
 class AITSplitter(splitter_base._SplitterBase):
@@ -122,16 +165,26 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         sample_input: Sequence[Any],
-        operator_support: ops.OperatorSupportBase = None,
-        settings: AITSplitterSettings = None,
+        operator_support: Optional[ops.OperatorSupportBase] = None,
+        settings: Optional[AITSplitterSettings] = None,
     ):
         if not settings:
             settings = AITSplitterSettings()
         if not operator_support:
-            operator_support = create_ait_operator_support(
-                op_lowering_disallow_list=settings.exclude_support_node_name,
-                allow_int_inputs=settings.allow_int_inputs,
-            )
+            if settings.debug_operator_range:
+                min_range, max_range = tuple(
+                    int(x) for x in settings.debug_operator_range.split(",")
+                )
+                operator_support = _range_operator_support(
+                    module=module,
+                    start=min_range,
+                    end=max_range,
+                )
+            else:
+                operator_support = create_ait_operator_support(
+                    op_lowering_disallow_list=settings.exclude_support_node_name,
+                    allow_int_inputs=settings.allow_int_inputs,
+                )
         else:
             operator_support = ops.chain(
                 operator_support,
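
A minimal usage sketch of the new `debug_operator_range` knob. Only `AITSplitter` and `AITSplitterSettings` come from this change; the toy module, its inputs, and the acc_tracer tracing step are illustrative assumptions:

```python
# Hypothetical sketch: bisect a lowering failure by marking only graph
# nodes with index 0..3 as supported. The toy model, inputs, and tracing
# step are assumptions for illustration, not part of this diff.
import torch

from fx2ait.acc_tracer import acc_tracer
from fx2ait.ait_splitter import AITSplitter, AITSplitterSettings


class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + torch.sigmoid(x)


# AIT lowering targets GPU, so a CUDA device and fp16 are assumed here.
inputs = [torch.randn(2, 3, device="cuda").half()]
traced = acc_tracer.trace(ToyModel().cuda().half(), inputs)

# "start,end" is parsed into node indices; only CALLABLE_NODE_OPS whose
# index falls in [start, end] are treated as supported, so everything
# else stays behind in eager PyTorch submodules.
settings = AITSplitterSettings(debug_operator_range="0,3")
splitter = AITSplitter(traced, inputs, settings=settings)
split_mod = splitter()  # _SplitterBase.__call__ returns the split module
```

Narrowing the range on successive runs should isolate which operator index breaks lowering; the per-node index log emitted by `_range_operator_support` maps those indices back to `n.target`/`n.name`.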