diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py
index 5eb6a04c6..16244b777 100644
--- a/fx2ait/fx2ait/ait_splitter.py
+++ b/fx2ait/fx2ait/ait_splitter.py
@@ -115,12 +115,14 @@ def __init__(
         min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
         allow_int_inputs=False,
         debug_operator_range=None,
+        max_acc_splits=-1,
     ):
         super().__init__()
         self.min_acc_module_size = min_acc_module_size
         self.exclude_support_node_name: set = set()
         self.allow_int_inputs: bool = allow_int_inputs
         self.debug_operator_range = debug_operator_range
+        self.max_acc_splits = max_acc_splits
 
 
 class SelectedOperatorSupport(ops.OperatorSupportBase):
diff --git a/fx2ait/fx2ait/lower/lower.py b/fx2ait/fx2ait/lower/lower.py
index 2a7e4ab7b..42e734af4 100644
--- a/fx2ait/fx2ait/lower/lower.py
+++ b/fx2ait/fx2ait/lower/lower.py
@@ -100,6 +100,7 @@ def default_split_function(
     settings = AITSplitterSettings(
         min_acc_module_size=lower_settings.min_acc_module_size,
         allow_int_inputs=lower_settings.allow_int_inputs,
+        max_acc_splits=lower_settings.max_acc_splits,
     )
     splitter = AITSplitter(model, inputs, settings=settings)
     splitter.node_support_preview()
diff --git a/fx2ait/fx2ait/lower/lower_settings.py b/fx2ait/fx2ait/lower/lower_settings.py
index a63d19b7b..1c1ffbb21 100644
--- a/fx2ait/fx2ait/lower/lower_settings.py
+++ b/fx2ait/fx2ait/lower/lower_settings.py
@@ -61,6 +61,9 @@ class LowerSettings:
     max_batch_size: int = 2048
     min_acc_module_size: int = 10
+    # Maximum number of splits for the lowered module
+    # (e.g. a module split into _run_on_gpu_0 (unlowered submodule) and _run_on_acc_1 (lowered submodule) has 2 splits)
+    max_acc_splits: int = -1
     workdir: str = ""
     name: str = ""
     dll_name: str = "ait_engine.so"
diff --git a/fx2ait/fx2ait/test/test_ait_splitter.py b/fx2ait/fx2ait/test/test_ait_splitter.py
index bbb7a0ed3..7acfd09b5 100644
--- a/fx2ait/fx2ait/test/test_ait_splitter.py
+++ b/fx2ait/fx2ait/test/test_ait_splitter.py
@@ -250,3 +250,58 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             dict(split_results_relu_allowed.split_module.named_children()).keys(),
             {"_run_on_acc_0"},
         )
+
+    def test_fail_if_exceed_max_acc_split_limit(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, a):
+                b = torch.sin(a)
+                c = torch.relu(b)
+                d = torch.cos(c)
+                e = torch.sigmoid(d)
+                f = torch.tanh(e)
+                return f
+
+        # Support all ops
+        _support_dict = {
+            "acc_ops.sin": None,
+            "acc_ops.cos": None,
+            "acc_ops.relu": None,
+            "acc_ops.sigmoid": None,
+            "acc_ops.tanh": None,
+        }
+        custom_op_support = op_support.OperatorSupport(_support_dict)
+
+        # With no ops excluded, the entire module should be lowered
+        # into one acc graph
+        mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
+        settings = AITSplitterSettings(min_acc_module_size=0, max_acc_splits=1)
+        splitter = AITSplitter(
+            mod,
+            (torch.randn(2, 3),),
+            custom_op_support,
+            settings,
+        )
+
+        res_all_nodes_supported = splitter.generate_split_results()
+        split_named_mods = dict(res_all_nodes_supported.split_module.named_children())
+        self.assertEqual(len(split_named_mods), 1)
+        self.assertIn("_run_on_acc_0", split_named_mods)
+
+        # Add "relu" to exclude_support_node_name; the graph should now be
+        # split into 3 parts (_run_on_acc_0, _run_on_gpu_1, _run_on_acc_2)
+        mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)])
+        for node in mod.graph.nodes:
+            if node.target == acc_ops.relu:
+                settings.exclude_support_node_name.add(node.name)
+        splitter = AITSplitter(
+            mod,
+            (torch.randn(2, 3),),
+            custom_op_support,
+            settings,
+        )
+        # Split should fail now, since 3 parts exceed max_acc_splits=1
+        with self.assertRaisesRegex(
+            ValueError,
+            "Cannot fulfill max_acc_splits limit. This may cause split fragmentation and result in performance issues.",
+        ):
+            splitter.generate_split_results()
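
Note: this diff adds the max_acc_splits setting, its plumbing through LowerSettings
and AITSplitterSettings, and a test, but not the enforcement inside the splitter
itself. A minimal sketch of the check the test implies follows; the standalone
function and its signature are hypothetical illustrations, not actual fx2ait API.
Only the setting name, its -1 default, and the error message come from this diff.

    # Hedged sketch (assumption): how generate_split_results() presumably
    # enforces max_acc_splits. The helper below is hypothetical.
    from typing import Sequence

    def check_max_acc_splits(submodule_names: Sequence[str], max_acc_splits: int) -> None:
        # Each _run_on_acc_* / _run_on_gpu_* child counts as one split, so a
        # module split into _run_on_gpu_0 and _run_on_acc_1 has 2 splits.
        # A max_acc_splits of -1 (the LowerSettings default) disables the check.
        if max_acc_splits > 0 and len(submodule_names) > max_acc_splits:
            raise ValueError(
                "Cannot fulfill max_acc_splits limit. This may cause split "
                "fragmentation and result in performance issues."
            )

    # With max_acc_splits=1: ["_run_on_acc_0"] passes (1 split), while
    # ["_run_on_acc_0", "_run_on_gpu_1", "_run_on_acc_2"] raises ValueError
    # (3 splits), matching test_fail_if_exceed_max_acc_split_limit above.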