Skip to content

Commit

Permalink
Add shark_eager mode.
Browse files Browse the repository at this point in the history
-- Eager mode with step by step op compilation and execution.
  • Loading branch information
Prashant Kumar committed Jun 22, 2023
1 parent cdd505e commit a811b86
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
206 changes: 206 additions & 0 deletions shark/shark_eager/shark_eager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
import sys
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
import torch_mlir


def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
mlir_module = torch_mlir.compile(
fx_g, inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference

shark_module = SharkInference(
mlir_module=bytecode,
device=device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
return shark_module


def _make_single_op_gm(node, captured_val, compiled_graph):
"""Make a GraphModule that just executes the given node."""
g = torch.fx.Graph()
env = {}
inputs = []
for arg in node.args:
if arg and hasattr(arg, "name"):
env[arg.name] = g.placeholder(arg.name)
if isinstance(captured_val[arg.name], (list, tuple)):
for val in captured_val[arg.name]:
inputs.append(val)
else:
inputs.append(captured_val[arg.name])

call = g.node_copy(node, lambda n: env[n.name])
g.output(call)
g.lint()
single_node = torch.fx.GraphModule(torch.nn.Module(), g)
compiled_module = shark_backend(single_node, inputs)
compiled_graph[node.name] = {
"module": compiled_module,
"inputs": [i for i in env],
"result": None,
}
return


def compiled_graph(gm: torch.fx.GraphModule, attr_info):
compiled_graph = {}
g = gm.graph
for node in g.nodes:
if node.op == "call_function":
if not (
node.target in [torch.ops.aten.empty]
or node.name.startswith("getitem")
):
_make_single_op_gm(node, attr_info, compiled_graph)

# Currently torch.aten.empty has an compilation issue, so running natively.
elif node.target in [torch.ops.aten.empty]:
compiled_graph[node.name] = {
"target": node.target,
"args": node.args,
"kwargs": node.kwargs,
"result": None,
}
# Get item is a simple case takes a tuple and return the tensor at a particular index.
elif node.name.startswith("getitem"):
compiled_graph[node.name] = {
"input": node.args[0].name,
"pos": node.args[1],
"result": None,
}

return compiled_graph


class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""

def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())

def propagate(self, *args):
args_iter = iter(args)
env: Dict[str, Node] = {}

def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])

def fetch_attr(target: str):
target_atoms = target.split(".")
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr

for node in self.graph.nodes:
if node.op == "placeholder":
result = next(args_iter)
elif node.op == "get_attr":
result = fetch_attr(node.target)
elif node.op == "call_function":
result = node.target(
*load_arg(node.args), **load_arg(node.kwargs)
)
elif node.op == "call_method":
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == "call_module":
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)

# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype

env[node.name] = result

return env

# return load_arg(self.graph.result)


resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
input = (torch.randn(1, 3, 224, 224),)

print(resnet18(input[0]))

fx_graph = import_with_fx(resnet18, input, mlir_type="fx")

shape_prop = ShapeProp(fx_graph)

x = shape_prop.propagate(input[0])

shark_graph = compiled_graph(fx_graph, x)


for key in shark_graph:
if key.startswith("getitem"):
input_val = shark_graph[key]["input"]
pos = shark_graph[key]["pos"]
if input_val not in shark_graph:
shark_graph[key]["result"] = x[input_val][pos].detach()
else:
shark_graph[key]["result"] = shark_graph[input_val]["result"][
pos
].detach()
elif key.startswith("empty"):
operator = shark_graph[key]["target"]
args = shark_graph[key]["args"]
kwargs = shark_graph[key]["kwargs"]
shark_graph[key]["result"] = operator(*args, **kwargs).detach()
else:
input_val = shark_graph[key]["inputs"]
input_tensors = []
for input in input_val:
if input not in shark_graph:
input_tensors.append(x[input].detach())
else:
input_tensors.append(shark_graph[input]["result"])

val = shark_graph[key]["module"]("forward", input_tensors)
if isinstance(val, (tuple, list)):
list_val = []
for v in val:
list_val.append(torch.from_numpy(v))
shark_graph[key]["result"] = list_val
else:
shark_graph[key]["result"] = torch.from_numpy(val)


print(shark_graph)
3 changes: 3 additions & 0 deletions shark/shark_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,9 @@ def strip_overloads(gm):
add_upcast(fx_g)
fx_g.recompile()

if mlir_type == "fx":
return fx_g

if training:
change_fx_graph_return_to_tuple(fx_g)
inputs = flatten_training_input(inputs)
Expand Down

0 comments on commit a811b86

Please sign in to comment.