From 27a08735dbb277d40280343bd9177b802dd2ee9f Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Mon, 26 Jun 2023 16:23:32 +0530
Subject: [PATCH] Add the shark backend for torch.compile API. (#1596)

---
 .../__init__.py               |   0
 shark/dynamo_backend/utils.py | 154 +++++++++++++++++
 shark/sharkdynamo/README.md   |  11 --
 shark/sharkdynamo/utils.py    | 163 ------------------
 4 files changed, 154 insertions(+), 174 deletions(-)
 rename shark/{sharkdynamo => dynamo_backend}/__init__.py (100%)
 create mode 100644 shark/dynamo_backend/utils.py
 delete mode 100644 shark/sharkdynamo/README.md
 delete mode 100644 shark/sharkdynamo/utils.py

diff --git a/shark/sharkdynamo/__init__.py b/shark/dynamo_backend/__init__.py
similarity index 100%
rename from shark/sharkdynamo/__init__.py
rename to shark/dynamo_backend/__init__.py
diff --git a/shark/dynamo_backend/utils.py b/shark/dynamo_backend/utils.py
new file mode 100644
index 0000000000..91876e3df5
--- /dev/null
+++ b/shark/dynamo_backend/utils.py
@@ -0,0 +1,154 @@
+from typing import List
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._functorch.compile_utils import strip_overloads
+from shark.shark_inference import SharkInference
+from torch._decomp import get_decompositions
+from torch.func import functionalize
+import io
+import torch_mlir
+
+
+# TODO: Control decompositions.
+def default_decompositions():
+    return get_decompositions(
+        [
+            torch.ops.aten.embedding_dense_backward,
+            torch.ops.aten.native_layer_norm_backward,
+            torch.ops.aten.slice_backward,
+            torch.ops.aten.select_backward,
+            torch.ops.aten.norm.ScalarOpt_dim,
+            torch.ops.aten.native_group_norm,
+            torch.ops.aten.upsample_bilinear2d.vec,
+            torch.ops.aten.split.Tensor,
+            torch.ops.aten.split_with_sizes,
+            torch.ops.aten.native_layer_norm,
+            torch.ops.aten.masked_fill.Tensor,
+            torch.ops.aten.masked_fill.Scalar,
+        ]
+    )
+
+
+def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
+    """Strip None results from the output node; return their indices."""
+    removed_indexes = []
+    for node in fx_g.graph.nodes:
+        if node.op == "output":
+            assert (
+                len(node.args) == 1
+            ), "Output node must have a single argument"
+            node_arg = node.args[0]
+            if isinstance(node_arg, (list, tuple)):
+                node_arg = list(node_arg)
+                node_args_len = len(node_arg)
+                # Walk backwards so popping does not shift pending indices.
+                for i in range(node_args_len):
+                    curr_index = node_args_len - (i + 1)
+                    if node_arg[curr_index] is None:
+                        removed_indexes.append(curr_index)
+                        node_arg.pop(curr_index)
+                node.args = (tuple(node_arg),)
+                break
+
+    if len(removed_indexes) > 0:
+        fx_g.graph.lint()
+        fx_g.graph.eliminate_dead_code()
+        fx_g.recompile()
+    removed_indexes.sort()
+    return removed_indexes
+
+
+def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
+    """Return True if the graph produces no outputs."""
+    for node in fx_g.graph.nodes:
+        if node.op == "output":
+            assert (
+                len(node.args) == 1
+            ), "Output node must have a single argument"
+            node_arg = node.args[0]
+            if isinstance(node_arg, tuple):
+                return len(node_arg) == 0
+    return False
+
+
+def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
+    """
+    Replace tuple with tuple element in functions that return one-element tuples.
+    Returns true if an unwrapping took place, and false otherwise.
+    """
+    unwrapped_tuple = False
+    for node in fx_g.graph.nodes:
+        if node.op == "output":
+            assert (
+                len(node.args) == 1
+            ), "Output node must have a single argument"
+            node_arg = node.args[0]
+            if isinstance(node_arg, tuple):
+                if len(node_arg) == 1:
+                    node.args = (node_arg[0],)
+                    unwrapped_tuple = True
+                break
+
+    if unwrapped_tuple:
+        fx_g.graph.lint()
+        fx_g.recompile()
+    return unwrapped_tuple
+
+
+class SharkBackend:
+    def __init__(
+        self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
+    ):
+        self.fx_g = fx_g
+        self.inputs = inputs
+        self.shark_module = None
+        self.device: str = options.get("device", "cpu")
+        self.was_unwrapped: bool = False
+        self.none_indices: list = []
+        self._modify_fx_g()
+        self.compile()
+
+    def _modify_fx_g(self):
+        self.none_indices = _remove_nones(self.fx_g)
+        self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
+
+    def compile(self):
+        # Functionalize and retrace with the default decompositions applied.
+        gm = make_fx(
+            functionalize(self.fx_g),
+            decomposition_table=default_decompositions(),
+        )(*self.inputs)
+        gm.graph.set_codegen(torch.fx.graph.CodeGen())
+        gm.recompile()
+        strip_overloads(gm)
+        ts_g = torch.jit.script(gm)
+        mlir_module = torch_mlir.compile(
+            ts_g, self.inputs, output_type="linalg-on-tensors"
+        )
+        bytecode_stream = io.BytesIO()
+        mlir_module.operation.write_bytecode(bytecode_stream)
+        bytecode = bytecode_stream.getvalue()
+        shark_module = SharkInference(
+            mlir_module=bytecode,
+            device=self.device,
+            mlir_dialect="tm_tensor",
+        )
+        shark_module.compile(extra_args=[])
+        self.shark_module = shark_module
+
+    def __call__(self, *inputs):
+        np_inputs = [x.detach().cpu().numpy() for x in inputs]
+        np_outs = self.shark_module("forward", np_inputs)
+        if self.was_unwrapped:
+            np_outs = [
+                np_outs,
+            ]
+
+        if not isinstance(np_outs, list):
+            return torch.from_numpy(np_outs)
+
+        result = [torch.from_numpy(x) for x in np_outs]
+        # Re-insert the Nones stripped from the graph outputs earlier.
+        for r_in in self.none_indices:
+            result.insert(r_in, None)
+        return tuple(result)
diff --git a/shark/sharkdynamo/README.md b/shark/sharkdynamo/README.md
deleted file mode 100644
index 095cb63a96..0000000000
--- a/shark/sharkdynamo/README.md
+++ /dev/null
@@ -1,11 +0,0 @@
-1. Install torchdynamo
-   - `git clone https://github.com/pytorch/torchdynamo.git`
-   - `cd torchdynamo`
-   - `python -m pip install -r requirements.txt`
-   - `python setup.py develop`
-
-2. Install functorch
-   - `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
-
-3. Run examples.
-   - `python shark/examples/shark_dynamo/basic_examples.py`
diff --git a/shark/sharkdynamo/utils.py b/shark/sharkdynamo/utils.py
deleted file mode 100644
index 11815fb685..0000000000
--- a/shark/sharkdynamo/utils.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import functools
-import time
-from typing import List, Optional
-import torch
-from torch.fx.experimental.proxy_tensor import make_fx
-from torch._functorch.compile_utils import strip_overloads
-from shark.shark_inference import SharkInference
-from torch._decomp import get_decompositions
-
-import torch_mlir
-
-
-# TODO: Control decompositions.
-def default_decompositions(): - return get_decompositions( - [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - ] - ) - - -def timeit(*, append_time_to: Optional[List] = None): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - start_time = time.time_ns() - result = func(*args, **kwargs) - end_time = time.time_ns() - - if append_time_to is not None: - append_time_to.append(end_time - start_time) - return result - - return wrapper - - return decorator - - -def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool: - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - return len(node_arg) == 0 - return False - - -def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: - """ - Replace tuple with tuple element in functions that return one-element tuples. - Returns true if an unwrapping took place, and false otherwise. - """ - unwrapped_tuple = False - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - if len(node_arg) == 1: - node.args = (node_arg[0],) - unwrapped_tuple = True - break - - if unwrapped_tuple: - fx_g.graph.lint() - fx_g.recompile() - return unwrapped_tuple - - -def make_shark_compiler(use_tracing: bool, device: str, verbose=False): - def compiler( - fx_graph: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - ): - """Compile GraphModule using torch-mlir + SHARK.""" - if verbose: - print("Compiling graph...") - - if _returns_nothing(fx_graph): - return fx_graph - - was_unwrapped = _unwrap_single_tuple_return(fx_graph) - fx_graph = make_fx( - fx_graph, decomposition_table=default_decompositions() - )(*example_inputs) - strip_overloads(fx_graph) - - if verbose: - print("torch.fx graph:") - print(fx_graph.graph) - - ts_compiler = torch.jit.trace if use_tracing else torch.jit.script - ts_graph = ts_compiler(fx_graph, example_inputs) - - if verbose: - torch_mlir_module = torch_mlir.compile( - ts_graph, - example_inputs, - output_type=torch_mlir.OutputType.TORCH, - ) - print("\n\ntorch-mlir backend contract graph:") - print(torch_mlir_module) - - linalg_module = torch_mlir.compile( - ts_graph, - example_inputs, - output_type=torch_mlir.OutputType.LINALG_ON_TENSORS, - ) - import io - - bytecode_stream = io.BytesIO() - linalg_module.operation.write_bytecode(bytecode_stream) - mlir_module = bytecode_stream.getvalue() - - shark_module = SharkInference( - mlir_module, mlir_dialect="linalg", device=device - ) - shark_module.compile() - - def forward(*inputs): - result = shark_module("forward", inputs) - result = tuple() if result is None else result - return (result,) if was_unwrapped else result - - return forward - - return compiler - - -def check_results(compiled_results, eager_results): - for compiled_result, eager_result in zip(compiled_results, eager_results): - if not torch.allclose( - compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5 - ): - print("Compiled result does not match eager result") - return - print("Compiled result matches eager 
result!") - - -def print_time_stats(times): - times_tensor = torch.tensor(times) - - def quantile_ms(q): - return torch.quantile(times_tensor.to(float), q).item() / 1e6 - - print(f"Median: {quantile_ms(0.5)} ms") - print(f"10%ile: {quantile_ms(0.1)} ms") - print(f"90%ile: {quantile_ms(0.9)} ms") - print(f"Total: {torch.sum(times_tensor) / 1e6} ms") - print()