Add the shark backend for torch.compile API. (#1596)
Prashant Kumar authored Jun 26, 2023
1 parent eaa49cc · commit 27a0873
Showing 4 changed files with 154 additions and 174 deletions.
File renamed without changes.
@@ -0,0 +1,154 @@
import io
from typing import List

import torch
import torch_mlir
from torch._decomp import get_decompositions
from torch._functorch.compile_utils import strip_overloads
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

from shark.shark_inference import SharkInference


# TODO: Control decompositions.
def default_decompositions():
    """Decompositions applied while retracing the FX graph with make_fx."""
    return get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.slice_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
            torch.ops.aten.native_layer_norm,
            torch.ops.aten.masked_fill.Tensor,
            torch.ops.aten.masked_fill.Scalar,
        ]
    )

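For context, the decomposition table above is handed to make_fx so that composite ATen ops are rewritten into simpler primitives while tracing. A minimal sketch, assuming the imports and default_decompositions above (the function f is hypothetical):

# Tracing layer_norm with the table above decomposes
# aten.native_layer_norm into primitive ops in the resulting graph.
def f(x):
    return torch.nn.functional.layer_norm(x, (4,))

gm = make_fx(f, decomposition_table=default_decompositions())(torch.randn(2, 4))
print(gm.graph)  # shows the decomposed primitives, not native_layer_norm
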
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
    """Drop None entries from the graph's output tuple.

    Returns the indices that were removed so the caller can reinsert
    None at the right positions after running the compiled module.
    """
    removed_indexes = []
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, (list, tuple)):
                node_arg = list(node_arg)
                node_args_len = len(node_arg)
                # Walk backwards so popping does not shift indices
                # that are still to be visited.
                for i in range(node_args_len):
                    curr_index = node_args_len - (i + 1)
                    if node_arg[curr_index] is None:
                        removed_indexes.append(curr_index)
                        node_arg.pop(curr_index)
                node.args = (tuple(node_arg),)
                break

    if len(removed_indexes) > 0:
        fx_g.graph.lint()
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    removed_indexes.sort()
    return removed_indexes

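A quick illustration of _remove_nones on a hypothetical traced function g:

# A graph whose output tuple contains a None: the None is dropped and
# its index recorded so __call__ can reinsert it after execution.
def g(x):
    return x + 1, None

fx_g = torch.fx.symbolic_trace(g)
print(_remove_nones(fx_g))   # [1]
print(fx_g(torch.ones(2)))   # (tensor([2., 2.]),)
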
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
    """Return True if the graph's output is an empty tuple."""
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                return len(node_arg) == 0
    return False


def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """
    Replace tuple with tuple element in functions that return one-element tuples.
    Returns true if an unwrapping took place, and false otherwise.
    """
    unwrapped_tuple = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                if len(node_arg) == 1:
                    node.args = (node_arg[0],)
                    unwrapped_tuple = True
                    break

    if unwrapped_tuple:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped_tuple

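And the unwrapping helper on a hypothetical graph that returns a one-element tuple:

# After unwrapping, the recompiled graph returns the bare tensor.
def h(x):
    return (x * 2,)

fx_g = torch.fx.symbolic_trace(h)
print(_unwrap_single_tuple_return(fx_g))  # True
print(fx_g(torch.ones(2)))                # tensor([2., 2.])
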
class SharkBackend:
    """Compiles an FX graph through torch-mlir and runs it via SharkInference."""

    def __init__(
        self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
    ):
        self.fx_g = fx_g
        self.inputs = inputs
        self.shark_module = None
        self.device: str = options.get("device", "cpu")
        self.was_unwrapped: bool = False
        self.none_indices: list = []
        self._modify_fx_g()
        self.compile()

    def _modify_fx_g(self):
        # Normalize the graph's return convention before lowering:
        # strip None outputs and unwrap one-element tuples.
        self.none_indices = _remove_nones(self.fx_g)
        self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)

    def compile(self):
        # Retrace with make_fx to apply the decompositions and functionalize
        # the graph (removing mutation) so it can be scripted and lowered.
        gm = make_fx(
            functionalize(self.fx_g),
            decomposition_table=default_decompositions(),
        )(*self.inputs)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())
        gm.recompile()
        strip_overloads(gm)
        ts_g = torch.jit.script(gm)
        # Lower the TorchScript module to the linalg-on-tensors dialect and
        # serialize it to MLIR bytecode for SharkInference.
        mlir_module = torch_mlir.compile(
            ts_g, self.inputs, output_type="linalg-on-tensors"
        )
        bytecode_stream = io.BytesIO()
        mlir_module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

        shark_module = SharkInference(
            mlir_module=bytecode,
            device=self.device,
            mlir_dialect="tm_tensor",
        )
        shark_module.compile(extra_args=[])
        self.shark_module = shark_module

    def __call__(self, *inputs):
        np_inputs = [x.detach().cpu().numpy() for x in inputs]
        np_outs = self.shark_module("forward", np_inputs)
        # If the single-element tuple was unwrapped before compilation,
        # re-wrap the lone output so the restoration logic below is uniform.
        if self.was_unwrapped:
            np_outs = [np_outs]

        if not isinstance(np_outs, list):
            return torch.from_numpy(np_outs)

        # Reinsert the Nones that _remove_nones stripped from the outputs.
        result = [torch.from_numpy(x) for x in np_outs]
        for r_in in self.none_indices:
            result.insert(r_in, None)
        return tuple(result)
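
As a hedged usage sketch: torch.compile accepts any callable that takes a GraphModule and example inputs, so the class above can be driven as below. The shark_compiler wrapper is hypothetical; the actual registration shipped with this commit lives in the other changed files and is not shown on this page.

def shark_compiler(fx_g: torch.fx.GraphModule, example_inputs):
    # Hypothetical glue: construct the backend with a default options dict.
    return SharkBackend(fx_g, tuple(example_inputs), options={"device": "cpu"})

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model, backend=shark_compiler)
out = compiled(torch.randn(2, 4))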
This file was deleted.
This file was deleted.