Skip to content

Commit

Permalink
Range operator lowering minimizer
Browse files Browse the repository at this point in the history
Summary:
During our perf tuning, we are seeing mtml_instagram_model lag behind AIT by around 23% on mergenet (525662456/1670).
How do we compare the perf between AIT and AOTI? Following an idea from Oleg, I implemented a minimizer that leverages the acc tracer + fx splitter. Fortunately, these two backends can still use the same acc tracer, so we can make sure the front-end graphs are exactly the same. We can then leverage the splitter to define a range of operators to lower. This way, we can directly compare the perf of a small subgraph between the two backends.

How to use it?
(1) Attach an arbitrary range, e.g. --debug_operator_range="0, 100", and you will see the graph nodes printed out in the log (P1233104423).

(2) Attach the precise range you want. For example, with --debug_operator_range="1613,1635" you will see logs like this:
```
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_122 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_124 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_125 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_126 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_127 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_128 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_129 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_130 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_131 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_132 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_133 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_134 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_135 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_136 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_137 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_138 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_139 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_140 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_141 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_142 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_143 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_144 to selected nodes
INFO:fx2ait.ait_splitter:Add n.op=call_function, n.target=<function linear at 0x7f01014a7b50>, n.name=linear_145 to selected nodes
```

Reviewed By: hl475

Differential Revision: D57030900

fbshipit-source-id: b55ac5b908d9752ee7448da51f32130accd19839
  • Loading branch information
frank-wei authored and facebook-github-bot committed May 8, 2024
1 parent 0973303 commit 6983d9e
Showing 1 changed file with 62 additions and 9 deletions.
71 changes: 62 additions & 9 deletions fx2ait/fx2ait/ait_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, Iterable, Mapping, Sequence
import logging
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Set

import torch
import torch.fx.passes.operator_support as ops
Expand All @@ -23,8 +24,9 @@
from fx2ait.converters.converter_registry import AIT_CONVERTERS
from fx2ait.fx2ait import AITInterpreter
from torch.fx.passes.operator_support import create_op_support, OperatorSupportBase
from torch.fx.passes.tools_common import get_acc_ops_name
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, get_acc_ops_name

logger: logging.Logger = logging.getLogger(__name__)

_VIEW_OPS = frozenset(
(
Expand Down Expand Up @@ -109,29 +111,80 @@ def create_ait_operator_support(
class AITSplitterSettings(splitter_base._SplitterSettingBase):
# TODO: Fix this once pytorch nightly is updated
def __init__(
self, min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, allow_int_inputs=False
self,
min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
allow_int_inputs=False,
debug_operator_range=None,
):
super().__init__()
self.min_acc_module_size = min_acc_module_size
self.exclude_support_node_name: set = set()
self.allow_int_inputs: bool = allow_int_inputs
self.debug_operator_range = debug_operator_range


class SelectedOperatorSupport(ops.OperatorSupportBase):
    """Operator support that accepts exactly a pre-chosen set of graph nodes.

    A node is reported as supported if and only if it is a member of the
    ``selected_nodes`` set handed to the constructor.
    """

    def __init__(self, selected_nodes: Set[torch.fx.Node]) -> None:
        # The set is stored as given (no copy is made).
        self.selected_nodes = selected_nodes

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Return True when ``node`` is one of the selected nodes.

        ``submodules`` is part of the ``OperatorSupportBase`` interface and is
        not consulted here.
        """
        is_selected = node in self.selected_nodes
        return is_selected


def _range_operator_support(
    module: torch.fx.GraphModule, start: int, end: int
) -> ops.OperatorSupportBase:
    """Build an operator support that only accepts nodes in an index range.

    Debug helper: every node of ``module.graph`` is logged together with its
    positional index, then the nodes whose index lies in the inclusive range
    ``[start, end]`` and whose ``op`` is in ``CALLABLE_NODE_OPS`` are collected
    and wrapped in a ``SelectedOperatorSupport``.

    Args:
        module: Graph module whose nodes are indexed in graph order.
        start: Index of the first node to consider (inclusive).
        end: Index of the last node to consider (inclusive).

    Returns:
        An operator support that reports only the selected nodes as supported.

    Raises:
        ValueError: If ``start`` is greater than ``end``, if ``end`` is not a
            valid index into the graph's nodes, or if no callable node falls
            inside the range.
    """
    # Materialise once: the length is checked and the nodes are walked twice.
    graph_nodes = list(module.graph.nodes)
    num_nodes = len(graph_nodes)

    # start == end is a valid single-node range, so only start > end is rejected.
    if start > end:
        raise ValueError(f"Start {start} is greater than end {end}")
    if end >= num_nodes:
        raise ValueError(
            f"End {end} must be less than the number of nodes in the graph {num_nodes}"
        )

    logger.info("===Enter into debug mode: set range operator for lowering===")
    # Dump the whole graph first so a precise range can be picked from the log.
    for i, n in enumerate(graph_nodes):
        logger.info(
            "Index:%s, n.op=%s, n.target=%s, n.name=%s", i, n.op, n.target, n.name
        )

    selected_nodes: Set[torch.fx.Node] = set()
    for i, n in enumerate(graph_nodes):
        if not (start <= i <= end):
            continue
        # Only callable nodes are selectable; other node kinds in the range are skipped.
        if n.op in CALLABLE_NODE_OPS:
            selected_nodes.add(n)
            logger.info(
                "Add n.op=%s, n.target=%s, n.name=%s to selected nodes",
                n.op,
                n.target,
                n.name,
            )

    if not selected_nodes:
        raise ValueError("No nodes are selected")
    return SelectedOperatorSupport(selected_nodes)


class AITSplitter(splitter_base._SplitterBase):
def __init__(
self,
module: torch.fx.GraphModule,
sample_input: Sequence[Any],
operator_support: ops.OperatorSupportBase = None,
settings: AITSplitterSettings = None,
operator_support: Optional[ops.OperatorSupportBase] = None,
settings: Optional[AITSplitterSettings] = None,
):
if not settings:
settings = AITSplitterSettings()
if not operator_support:
operator_support = create_ait_operator_support(
op_lowering_disallow_list=settings.exclude_support_node_name,
allow_int_inputs=settings.allow_int_inputs,
)
if settings.debug_operator_range:
min_range, max_range = tuple(
int(x) for x in settings.debug_operator_range.split(",")
)
operator_support = _range_operator_support(
module=module,
start=min_range,
end=max_range,
)
else:
operator_support = create_ait_operator_support(
op_lowering_disallow_list=settings.exclude_support_node_name,
allow_int_inputs=settings.allow_int_inputs,
)
else:
operator_support = ops.chain(
operator_support,
Expand Down

0 comments on commit 6983d9e

Please sign in to comment.