Skip to content

Commit

Permalink
Add the shark compile downstream due to pytorch/pytorch#104185 (comment)
Browse files Browse the repository at this point in the history
  • Loading branch information
Prashant Kumar committed Jul 1, 2023
1 parent 6d286c0 commit 1c32915
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 58 deletions.
28 changes: 28 additions & 0 deletions shark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import importlib
import logging

from torch._dynamo import register_backend

log = logging.getLogger(__name__)


@register_backend
def shark(model, inputs, *, options=None):
    """``torch.compile`` backend entry point for SHARK.

    Args:
        model: the graph module handed over by dynamo.
        inputs: example inputs for the captured graph.
        options: optional backend-specific options dict forwarded to
            ``SharkBackend``. Defaults to ``None`` because dynamo invokes a
            registered backend as ``backend(model, inputs)`` when the user
            passes no ``options`` to ``torch.compile`` — a required
            keyword-only parameter would make that call raise TypeError.

    Raises:
        ImportError: when the SHARK distribution is not installed.
    """
    try:
        from shark.dynamo_backend.utils import SharkBackend
    except ImportError:
        # Fixed message: the two adjacent literals previously concatenated
        # with no separator ("...DistributionPlease install...").
        log.exception(
            "Unable to import SHARK - High Performance Machine Learning Distribution. "
            "Please install the right version of SHARK that matches the PyTorch version being used. "
            "Refer to https://github.com/nod-ai/SHARK/ for details."
        )
        raise
    # NOTE(review): SharkBackend presumably expects a dict of options —
    # normalize None to an empty dict to keep the no-options path working.
    return SharkBackend(model, inputs, options if options is not None else {})


def has_shark():
    """Report whether the SHARK package can be imported."""
    try:
        importlib.import_module("shark")
    except ImportError:
        return False
    return True
71 changes: 13 additions & 58 deletions shark/examples/shark_dynamo/basic_examples.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,25 @@
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler
import shark


import warnings, logging
def foo(x, a):
    """Return ``x + a`` when the leading dimension exceeds 3, else ``x + 3``."""
    if x.shape[0] <= 3:
        return x + 3
    return x + a

warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR

shark_options = {"device": "cpu"}
compiled = torch.compile(foo, backend="shark", options=shark_options)

torchdynamo.reset()
input = torch.ones(4)

x = compiled(input, input)

@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(t):
    """Return the input scaled by two."""
    doubled = 2 * t
    return doubled


example_input = torch.rand((2, 3))
x = foo(example_input)
print(x)

input = torch.ones(3)

torchdynamo.reset()


@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
)
def foo(a, b):
    """Return ``a / (a + 1) * b``, flipping ``b``'s sign first when its sum is negative."""
    ratio = a / (a + 1)
    if b.sum() < 0:
        b = -b
    return ratio * b


print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))


torchdynamo.reset()

x = compiled(input, input)

@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def foo(a):
    """Add 1.0 to ``a`` ten times (in place for tensors) and return it."""
    steps = 0
    while steps < 10:
        a += 1.0
        steps += 1
    return a


print(foo(torch.rand((1, 2))))

torchdynamo.reset()


@torchdynamo.optimize(
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
)
def test_unsupported_types(t, y):
return t, 2 * y


str_input = "hello"
tensor_input = torch.randn(2)
print(test_unsupported_types(str_input, tensor_input))
print(x)

0 comments on commit 1c32915

Please sign in to comment.