[SHARK] Add a compile API to use for quick testing of inference (#1606)
Abhishek-Varma committed Jun 28, 2023
1 parent 6274a81 commit d496053
Showing 2 changed files with 111 additions and 16 deletions.
28 changes: 12 additions & 16 deletions shark/examples/shark_inference/mega_test.py
@@ -1,10 +1,7 @@
 import torch
 import torch_mlir
 from shark.shark_inference import SharkInference
-from apps.stable_diffusion.src.utils import (
-    compile_through_fx,
-    args,
-)
+from shark.shark_compile import shark_compile_through_fx
 from MEGABYTE_pytorch import MEGABYTE
 
 import os
@@ -37,23 +34,22 @@ def forward(self, input):
 
 
 megaModel = MegaModel()
-input = [torch.randint(0, 16000, (1, 1024, 4))]
+inputs = [torch.randint(0, 16000, (1, 1024, 4))]
 
 # CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
 # 1. aten.alias
-shark_module, _ = compile_through_fx(
-    megaModel,
-    inputs=input,
+shark_module, _ = shark_compile_through_fx(
+    model=megaModel,
+    inputs=inputs,
     extended_model_name="mega_shark",
-    debug=False,
-    generate_vmfb=True,
     is_f16=False,
     f16_input_mask=None,
     save_dir=os.getcwd(),
+    debug=False,
+    generate_or_load_vmfb=True,
     extra_args=[],
-    base_model_id=None,
-    model_name="mega_shark",
-    precision=None,
-    return_mlir=True,
+    device="cuda",
+    mlir_dialect="tm_tensor",
 )
 # logits = model(x)
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
     print("\n\t", output.shape)
 
 
-ans = shark_module("forward", input)
+ans = shark_module("forward", inputs)
 print_output_info(torch.from_numpy(ans), "SHARK's output")
 
-ans = megaModel.forward(*input)
+ans = megaModel.forward(*inputs)
 print_output_info(ans, "ORIGINAL Model's output")
 
 # and sample from the logits accordingly
99 changes: 99 additions & 0 deletions shark/shark_compile.py
@@ -0,0 +1,99 @@
+import os
+import tempfile
+from shark.shark_inference import SharkInference
+from shark.shark_importer import import_with_fx
+
+
+def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
+    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+    shark_module = None
+    if os.path.isfile(vmfb_path):
+        shark_module = SharkInference(
+            None,
+            device=device,
+            mlir_dialect=mlir_dialect,
+        )
+        print(f"loading existing vmfb from: {vmfb_path}")
+        shark_module.load_module(vmfb_path, extra_args=extra_args)
+    return shark_module
+
+
+def compile_module(
+    shark_module, extended_model_name, generate_vmfb, extra_args=[]
+):
+    if generate_vmfb:
+        vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+        if os.path.isfile(vmfb_path):
+            print(f"loading existing vmfb from: {vmfb_path}")
+            shark_module.load_module(vmfb_path, extra_args=extra_args)
+        else:
+            print(
+                "No vmfb found. Compiling and saving to {}".format(vmfb_path)
+            )
+            path = shark_module.save_module(
+                os.getcwd(), extended_model_name, extra_args
+            )
+            shark_module.load_module(path, extra_args=extra_args)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
+def shark_compile_through_fx(
+    model,
+    inputs,
+    extended_model_name,
+    is_f16=False,
+    f16_input_mask=None,
+    save_dir=tempfile.gettempdir(),
+    debug=False,
+    generate_or_load_vmfb=True,
+    extra_args=[],
+    device=None,
+    mlir_dialect="tm_tensor",
+):
+    if generate_or_load_vmfb:
+        shark_module = load_vmfb(
+            extended_model_name=extended_model_name,
+            device=device,
+            mlir_dialect=mlir_dialect,
+            extra_args=extra_args,
+        )
+        if shark_module:
+            return (
+                shark_module,
+                None,
+            )
+
+    from shark.parser import shark_args
+
+    if "cuda" in device:
+        shark_args.enable_tf32 = True
+
+    (
+        mlir_module,
+        _,
+    ) = import_with_fx(
+        model=model,
+        inputs=inputs,
+        is_f16=is_f16,
+        f16_input_mask=f16_input_mask,
+        debug=debug,
+        model_name=extended_model_name,
+        save_dir=save_dir,
+    )
+
+    shark_module = SharkInference(
+        mlir_module,
+        device=device,
+        mlir_dialect=mlir_dialect,
+    )
+    return (
+        compile_module(
+            shark_module,
+            extended_model_name,
+            generate_vmfb=generate_or_load_vmfb,
+            extra_args=extra_args,
+        ),
+        mlir_module,
+    )
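
For quick reference, here is a minimal sketch of calling the new API on a toy module. SmallModel, its input shape, and the "cpu" device choice are illustrative assumptions, not part of this commit; only the keyword arguments mirror the shark_compile_through_fx signature added above.

# Hypothetical quick-test driver for the new API (not part of this commit).
# SmallModel and its input are placeholders; the keyword arguments mirror
# the shark_compile_through_fx signature defined in shark/shark_compile.py.
import os

import torch

from shark.shark_compile import shark_compile_through_fx


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)


inputs = [torch.randn(1, 16)]
shark_module, mlir_module = shark_compile_through_fx(
    model=SmallModel(),
    inputs=inputs,
    extended_model_name="small_model_cpu",
    save_dir=os.getcwd(),
    generate_or_load_vmfb=True,  # reuses ./small_model_cpu.vmfb on later runs
    device="cpu",  # must be a str: the function tests ("cuda" in device)
    mlir_dialect="tm_tensor",
)
print(shark_module("forward", inputs))

On the first run this traces the module through FX, compiles a .vmfb into the current working directory, and loads it; on later runs load_vmfb short-circuits compilation by reloading the cached artifact, in which case the returned mlir_module is None.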
