diff --git a/shark/examples/shark_inference/mega_test.py b/shark/examples/shark_inference/mega_test.py
index efc5e70b79..a4e6f6b406 100644
--- a/shark/examples/shark_inference/mega_test.py
+++ b/shark/examples/shark_inference/mega_test.py
@@ -1,10 +1,7 @@
 import torch
 import torch_mlir
 from shark.shark_inference import SharkInference
-from apps.stable_diffusion.src.utils import (
-    compile_through_fx,
-    args,
-)
+from shark.shark_compile import shark_compile_through_fx
 from MEGABYTE_pytorch import MEGABYTE
 
 import os
@@ -37,23 +34,22 @@ def forward(self, input):
 
 
 megaModel = MegaModel()
-input = [torch.randint(0, 16000, (1, 1024, 4))]
+inputs = [torch.randint(0, 16000, (1, 1024, 4))]
 
 # CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
 # 1. aten.alias
-shark_module, _ = compile_through_fx(
-    megaModel,
-    inputs=input,
+shark_module, _ = shark_compile_through_fx(
+    model=megaModel,
+    inputs=inputs,
     extended_model_name="mega_shark",
-    debug=False,
-    generate_vmfb=True,
+    is_f16=False,
+    f16_input_mask=None,
     save_dir=os.getcwd(),
+    debug=False,
+    generate_or_load_vmfb=True,
     extra_args=[],
-    base_model_id=None,
-    model_name="mega_shark",
-    precision=None,
-    return_mlir=True,
     device="cuda",
+    mlir_dialect="tm_tensor",
 )
 
 # logits = model(x)
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
     print("\n\t", output.shape)
 
 
-ans = shark_module("forward", input)
+ans = shark_module("forward", inputs)
 print_output_info(torch.from_numpy(ans), "SHARK's output")
 
-ans = megaModel.forward(*input)
+ans = megaModel.forward(*inputs)
 print_output_info(ans, "ORIGINAL Model's output")
 
 # and sample from the logits accordingly
diff --git a/shark/shark_compile.py b/shark/shark_compile.py
new file mode 100644
index 0000000000..79431155f5
--- /dev/null
+++ b/shark/shark_compile.py
@@ -0,0 +1,99 @@
+import os
+import tempfile
+from shark.shark_inference import SharkInference
+from shark.shark_importer import import_with_fx
+
+
+def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
+    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+    shark_module = None
+    if os.path.isfile(vmfb_path):
+        shark_module = SharkInference(
+            None,
+            device=device,
+            mlir_dialect=mlir_dialect,
+        )
+        print(f"loading existing vmfb from: {vmfb_path}")
+        shark_module.load_module(vmfb_path, extra_args=extra_args)
+    return shark_module
+
+
+def compile_module(
+    shark_module, extended_model_name, generate_vmfb, extra_args=[]
+):
+    if generate_vmfb:
+        vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+        if os.path.isfile(vmfb_path):
+            print(f"loading existing vmfb from: {vmfb_path}")
+            shark_module.load_module(vmfb_path, extra_args=extra_args)
+        else:
+            print(
+                "No vmfb found. Compiling and saving to {}".format(vmfb_path)
+            )
+            path = shark_module.save_module(
+                os.getcwd(), extended_model_name, extra_args
+            )
+            shark_module.load_module(path, extra_args=extra_args)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
+def shark_compile_through_fx(
+    model,
+    inputs,
+    extended_model_name,
+    is_f16=False,
+    f16_input_mask=None,
+    save_dir=tempfile.gettempdir(),
+    debug=False,
+    generate_or_load_vmfb=True,
+    extra_args=[],
+    device=None,
+    mlir_dialect="tm_tensor",
+):
+    if generate_or_load_vmfb:
+        shark_module = load_vmfb(
+            extended_model_name=extended_model_name,
+            device=device,
+            mlir_dialect=mlir_dialect,
+            extra_args=extra_args,
+        )
+        if shark_module:
+            return (
+                shark_module,
+                None,
+            )
+
+    from shark.parser import shark_args
+
+    if "cuda" in device:
+        shark_args.enable_tf32 = True
+
+    (
+        mlir_module,
+        _,
+    ) = import_with_fx(
+        model=model,
+        inputs=inputs,
+        is_f16=is_f16,
+        f16_input_mask=f16_input_mask,
+        debug=debug,
+        model_name=extended_model_name,
+        save_dir=save_dir,
+    )
+
+    shark_module = SharkInference(
+        mlir_module,
+        device=device,
+        mlir_dialect=mlir_dialect,
+    )
+    return (
+        compile_module(
+            shark_module,
+            extended_model_name,
+            generate_vmfb=generate_or_load_vmfb,
+            extra_args=extra_args,
+        ),
+        mlir_module,
+    )