diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml
index 341b8739c8..fa22fb1a89 100644
--- a/.github/workflows/test-models.yml
+++ b/.github/workflows/test-models.yml
@@ -35,6 +35,8 @@ jobs:
         include:
           - os: ubuntu-latest
             suite: lint
+          - os: MacStudio
+            suite: metal
         exclude:
           - os: ubuntu-latest
             suite: vulkan
@@ -46,6 +48,8 @@ jobs:
             suite: cuda
           - os: MacStudio
            suite: cpu
+          - os: MacStudio
+            suite: vulkan
           - os: icelake
             suite: vulkan
           - os: icelake
@@ -125,15 +129,14 @@ jobs:
           # python build_tools/stable_diffusion_testing.py --device=cuda
 
       - name: Validate Vulkan Models (MacOS)
-        if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
+        if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
         run: |
           cd $GITHUB_WORKSPACE
           PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
           source shark.venv/bin/activate
-          export DYLD_LIBRARY_PATH=/usr/local/lib/
           echo $PATH
           pip list | grep -E "torch|iree"
-          pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
+          pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
 
       - name: Validate Vulkan Models (a100)
         if: matrix.suite == 'vulkan' && matrix.os == 'a100'
diff --git a/.gitignore b/.gitignore
index efc8970565..395a677ba6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,8 @@
 __pycache__/
 *.py[cod]
 *$py.class
+*.mlir
+*.vmfb
 
 # C extensions
 *.so
diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index e566d44a39..4745504ae5 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -28,10 +28,10 @@ def __init__(
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
-        first_vicuna_mlir_path=Path("first_vicuna.mlir"),
-        second_vicuna_mlir_path=Path("second_vicuna.mlir"),
-        first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
-        second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
+        first_vicuna_mlir_path=None,
+        second_vicuna_mlir_path=None,
+        first_vicuna_vmfb_path=None,
+        second_vicuna_vmfb_path=None,
         load_mlir_from_shark_tank=True,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
@@ -42,9 +42,27 @@ def __init__(
         self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
         self.first_vicuna_mlir_path = first_vicuna_mlir_path
         self.second_vicuna_mlir_path = second_vicuna_mlir_path
+        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
+        if self.first_vicuna_mlir_path == None:
+            self.first_vicuna_mlir_path = self.get_model_path()
+        if self.second_vicuna_mlir_path == None:
+            self.second_vicuna_mlir_path = self.get_model_path("second")
+        if self.first_vicuna_vmfb_path == None:
+            self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
+        if self.second_vicuna_vmfb_path == None:
+            self.second_vicuna_vmfb_path = self.get_model_path(
+                "second", "vmfb"
+            )
         self.tokenizer = self.get_tokenizer()
         self.shark_model = self.compile()
-        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
+
+    def get_model_path(self, model_number="first", suffix="mlir"):
+        safe_device = "_".join(self.device.split("-"))
+        if suffix == "mlir":
+            return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
+        return Path(
+            f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
+        )
 
     def get_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(
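A minimal sketch (not part of this patch) of the default artifact names the new get_model_path produces, assuming a hypothetical pipeline instance `vic` constructed with device="cpu-task" and precision="fp16":

    # MLIR names encode only the precision; vmfb names also encode the device,
    # with "-" rewritten to "_" so the file name stays filesystem-friendly.
    vic.get_model_path()                     # Path("first_vicuna_fp16.mlir")
    vic.get_model_path("second")             # Path("second_vicuna_fp16.mlir")
    vic.get_model_path(suffix="vmfb")        # Path("first_vicuna_cpu_task_fp16.vmfb")
    vic.get_model_path("second", "vmfb")     # Path("second_vicuna_cpu_task_fp16.vmfb")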
diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py
index 00c3faa2a7..0a905535d3 100644
--- a/apps/stable_diffusion/src/utils/stable_args.py
+++ b/apps/stable_diffusion/src/utils/stable_args.py
@@ -380,7 +380,7 @@ def is_valid_file(arg):
 )
 
 p.add_argument(
-    "--iree_metal_target_platfrom",
+    "--iree_metal_target_platform",
     type=str,
     default="",
     help="Specify target triple for metal",
diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py
index 2ca9b9aef4..6d11f96d08 100644
--- a/apps/stable_diffusion/src/utils/utils.py
+++ b/apps/stable_diffusion/src/utils/utils.py
@@ -277,12 +277,12 @@ def set_init_device_flags():
         args.device = "cuda"
     elif "metal" in args.device:
         device_name, args.device = map_device_to_name_path(args.device)
-        if not args.iree_metal_target_platfrom:
+        if not args.iree_metal_target_platform:
             triple = get_metal_target_triple(device_name)
             if triple is not None:
-                args.iree_metal_target_platfrom = triple
+                args.iree_metal_target_platform = triple
         print(
-            f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
+            f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
         )
     elif "cpu" in args.device:
         args.device = "cpu"
@@ -445,7 +445,10 @@ def get_devices_by_name(driver_name):
     available_devices.extend(metal_devices)
     cuda_devices = get_devices_by_name("cuda")
     available_devices.extend(cuda_devices)
-    available_devices.append("device => cpu")
+    cpu_device = get_devices_by_name("cpu-sync")
+    available_devices.extend(cpu_device)
+    cpu_device = get_devices_by_name("cpu-task")
+    available_devices.extend(cpu_device)
     return available_devices
 
 
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index 3f2edf68ee..a0cbd59a62 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -1,6 +1,11 @@
 from multiprocessing import Process, freeze_support
 import os
 import sys
+
+if sys.platform == "darwin":
+    # import before IREE to avoid torch-MLIR library issues
+    import torch_mlir
+
 import shutil
 import PIL, transformers  # ensures inclusion in pysintaller exe generation
 from apps.stable_diffusion.src import args, clear_all
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 6735d5b4ff..0e5cf4092d 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -41,17 +41,21 @@ def chat(curr_system_message, history, model, device, precision):
     curr_system_message = start_message_vicuna
     if vicuna_model == 0:
-        first_vic_vmfb_path = Path("first_vicuna.vmfb")
-        second_vic_vmfb_path = Path("second_vicuna.vmfb")
         if "cuda" in device:
             device = "cuda"
+        elif "sync" in device:
+            device = "cpu-sync"
+        elif "task" in device:
+            device = "cpu-task"
+        elif "vulkan" in device:
+            device = "vulkan"
+        else:
+            print("unrecognized device")
         vicuna_model = Vicuna(
             "vicuna",
             hf_model_path=model,
             device=device,
             precision=precision,
-            first_vicuna_vmfb_path=first_vic_vmfb_path,
-            second_vicuna_vmfb_path=second_vic_vmfb_path,
         )
 
     messages = curr_system_message + "".join(
         [
@@ -120,9 +124,7 @@ def chat(curr_system_message, history, model, device, precision):
                     "TheBloke/vicuna-7B-1.1-HF",
                 ],
             )
-            supported_devices = [
-                device for device in available_devices if "cuda" in device
-            ]
+            supported_devices = available_devices
             enabled = len(supported_devices) > 0
             device = gr.Dropdown(
                 label="Device",
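The dropdown now lists `cpu-sync` and `cpu-task` entries (see the utils.py change above), and chat() folds the selected string down to the coarse device name the Vicuna pipeline expects. A rough standalone restatement of that mapping, purely illustrative and not part of the patch:

    def normalize_device(device: str) -> str:
        # Hypothetical helper mirroring the branches added to chat() above;
        # dropdown entries look like "<gpu name> => vulkan://0" or "cpu-task".
        if "cuda" in device:
            return "cuda"
        elif "sync" in device:
            return "cpu-sync"
        elif "task" in device:
            return "cpu-task"
        elif "vulkan" in device:
            return "vulkan"
        print("unrecognized device")
        return device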
diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py
index 7d2086a398..44e41f1d4c 100644
--- a/apps/stable_diffusion/web/ui/txt2img_ui.py
+++ b/apps/stable_diffusion/web/ui/txt2img_ui.py
@@ -34,7 +34,7 @@
 # set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
 init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
-init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
+init_iree_metal_target_platform = args.iree_metal_target_platform
 init_use_tuned = args.use_tuned
 init_import_mlir = args.import_mlir
 
 
@@ -138,7 +138,7 @@ def txt2img_inf(
     args.width = width
     args.device = device.split("=>", 1)[1].strip()
     args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
-    args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
+    args.iree_metal_target_platform = init_iree_metal_target_platform
     args.use_tuned = init_use_tuned
     args.import_mlir = init_import_mlir
     args.img_path = None
diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py
index 2729af8088..8c79243129 100644
--- a/shark/iree_utils/_common.py
+++ b/shark/iree_utils/_common.py
@@ -101,11 +101,13 @@ def check_device_drivers(device):
             subprocess.check_output("nvidia-smi")
         except Exception:
             return True
-    elif device in ["metal", "vulkan"]:
+    elif device in ["vulkan"]:
         try:
             subprocess.check_output("vulkaninfo")
         except Exception:
             return True
+    elif device == "metal":
+        return False
     elif device in ["intel-gpu"]:
         try:
             subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 8ddb530b38..78ee1ca6d5 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -19,6 +19,7 @@
 import numpy as np
 import os
 import re
+import tempfile
 
 
 # Get the iree-compile arguments given device.
@@ -181,8 +182,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
             vmfb_file.close()
 
             config = get_iree_runtime_config(device)
-            vm_module = ireert.VmModule.from_flatbuffer(
-                config.vm_instance, flatbuffer_blob
+            vm_module = ireert.VmModule.from_buffer(
+                config.vm_instance,
+                flatbuffer_blob,
+                warn_if_copy=False,
             )
 
             benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -313,8 +316,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
         config = ireert.Config(device=haldevice)
     else:
         config = get_iree_runtime_config(device)
-    vm_module = ireert.VmModule.from_flatbuffer(
-        config.vm_instance, flatbuffer_blob
+    vm_module = ireert.VmModule.from_buffer(
+        config.vm_instance, flatbuffer_blob, warn_if_copy=False
     )
     ctx = ireert.SystemContext(config=config)
     ctx.add_vm_module(vm_module)
@@ -322,6 +325,55 @@
     return ModuleCompiled, config
 
 
+def load_vmfb_using_mmap(
+    flatbuffer_blob_or_path, device: str, device_idx: int = None
+):
+    instance = ireert.VmInstance()
+    device = iree_device_map(device)
+    haldriver = ireert.get_driver(device)
+    haldevice = haldriver.create_device_by_uri(
+        device,
+        allocators=[],
+    )
+    # First get configs.
+    if device_idx is not None:
+        device = iree_device_map(device)
+        print("registering device id: ", device_idx)
+        haldriver = ireert.get_driver(device)
+
+        haldevice = haldriver.create_device(
+            haldriver.query_available_devices()[device_idx]["device_id"],
+            allocators=shark_args.device_allocator,
+        )
+        config = ireert.Config(device=haldevice)
+    else:
+        config = get_iree_runtime_config(device)
+    # Now load the vmfb.
+    # Two scenarios can occur here:
+    # 1. We already have the vmfb saved, so we are passed its path.
+    #    (This would arise if we're invoking `load_module` from a SharkInference obj)
+    # 2. We are compiling on the fly, so we only have the flatbuffer blob to work with.
+    #    (This would arise if we're invoking `compile` from a SharkInference obj)
+    temp_file_to_unlink = None
+    if (
+        isinstance(flatbuffer_blob_or_path, str)
+        and ".vmfb" in flatbuffer_blob_or_path
+    ):
+        vmfb_file_path = flatbuffer_blob_or_path
+        mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
+        ctx = ireert.SystemContext(config=config)
+        ctx.add_vm_module(mmaped_vmfb)
+        mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
+    else:
+        with tempfile.NamedTemporaryFile(delete=False) as tf:
+            tf.write(flatbuffer_blob_or_path)
+            tf.flush()
+            vmfb_file_path = tf.name
+        temp_file_to_unlink = vmfb_file_path
+        mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
+    return mmaped_vmfb, config, temp_file_to_unlink
+
+
 def get_iree_compiled_module(
     module,
     device: str,
@@ -329,19 +381,58 @@
     model_config_path: str = None,
     extra_args: list = [],
     device_idx: int = None,
+    mmap: bool = False,
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
         module, device, frontend, model_config_path, extra_args
     )
-    return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
-
+    temp_file_to_unlink = None
+    # TODO: The mmap=True control-flow path is currently switched off by default.
+    # We still need a cleaner way to unlink/delete the temporary file, since the
+    # NamedTemporaryFile is created with delete=False; that is why its name is
+    # captured in `temp_file_to_unlink`.
+    if mmap:
+        print(f"Will load the compiled module as a mmapped temporary file")
+        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
+            flatbuffer_blob, device, device_idx
+        )
+    else:
+        vmfb, config = get_iree_module(
+            flatbuffer_blob, device, device_idx=device_idx
+        )
+    ret_params = {
+        "vmfb": vmfb,
+        "config": config,
+        "temp_file_to_unlink": temp_file_to_unlink,
+    }
+    return ret_params
 
-def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
-    with open(os.path.join(flatbuffer_path), "rb") as f:
-        flatbuffer_blob = f.read()
-    return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
+
+def load_flatbuffer(
+    flatbuffer_path: str,
+    device: str,
+    device_idx: int = None,
+    mmap: bool = False,
+):
+    temp_file_to_unlink = None
+    if mmap:
+        print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
+        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
+            flatbuffer_path, device, device_idx
+        )
+    else:
+        with open(os.path.join(flatbuffer_path), "rb") as f:
+            flatbuffer_blob = f.read()
+        vmfb, config = get_iree_module(
+            flatbuffer_blob, device, device_idx=device_idx
+        )
+    ret_params = {
+        "vmfb": vmfb,
+        "config": config,
+        "temp_file_to_unlink": temp_file_to_unlink,
+    }
+    return ret_params
 
 
 def export_iree_module_to_vmfb(
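As a rough illustration (not part of this patch) of the new loading path in compile_utils.py; the vmfb name is hypothetical and device_idx is left at its default:

    from shark.iree_utils.compile_utils import load_flatbuffer

    # mmap=True routes through load_vmfb_using_mmap() instead of reading the
    # whole flatbuffer into memory with open()/read().
    params = load_flatbuffer(
        "first_vicuna_cpu_task_fp32.vmfb",  # illustrative path to a saved vmfb
        device="cpu-task",
        mmap=True,
    )
    module, config = params["vmfb"], params["config"]
    # A real .vmfb path was passed, so no temporary file was created:
    assert params["temp_file_to_unlink"] is None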
diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py
index 7f65488f7b..ef6cdfcc6e 100644
--- a/shark/iree_utils/metal_utils.py
+++ b/shark/iree_utils/metal_utils.py
@@ -83,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
         print(
             f"Found metal device {metal_device}. Using metal target triple {triple}"
         )
-        return f"-iree-metal-target-platfrom={triple}"
+        return f"-iree-metal-target-platform={triple}"
     print(
         """Optimized kernel for your target device is not added yet.
         Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
@@ -101,16 +101,16 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
 
     for arg in extra_args:
         if "-iree-metal-target-platform=" in arg:
             print(f"Using target triple {arg} from command line args")
-            meatal_triple_flag = arg
+            metal_triple_flag = arg
             break
 
     if metal_triple_flag is None:
-        meatal_triple_flag = get_metal_triple_flag(
+        metal_triple_flag = get_metal_triple_flag(
             device_num=device_num, extra_args=extra_args
         )
 
-    if meatal_triple_flag is not None:
-        vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
+    if metal_triple_flag is not None:
+        vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
         res_metal_flag.append(vulkan_target_env)
     return res_metal_flag
diff --git a/shark/shark_eager/shark_eager.py b/shark/shark_eager/shark_eager.py
new file mode 100644
index 0000000000..807c6c014e
--- /dev/null
+++ b/shark/shark_eager/shark_eager.py
@@ -0,0 +1,206 @@
+from typing import Any, Dict, List, Tuple
+from collections import defaultdict
+from shark.shark_importer import import_with_fx
+import torchvision.models as models
+import copy
+import io
+import numpy as np
+import sys
+import torch
+import torch.fx
+from torch.fx.node import Node
+from typing import Dict
+import torch_mlir
+
+
+def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
+    mlir_module = torch_mlir.compile(
+        fx_g, inputs, output_type="linalg-on-tensors"
+    )
+    bytecode_stream = io.BytesIO()
+    mlir_module.operation.write_bytecode(bytecode_stream)
+    bytecode = bytecode_stream.getvalue()
+    from shark.shark_inference import SharkInference
+
+    shark_module = SharkInference(
+        mlir_module=bytecode,
+        device=device,
+        mlir_dialect="tm_tensor",
+    )
+    shark_module.compile(extra_args=[])
+    return shark_module
+
+
+def _make_single_op_gm(node, captured_val, compiled_graph):
+    """Make a GraphModule that just executes the given node."""
+    g = torch.fx.Graph()
+    env = {}
+    inputs = []
+    for arg in node.args:
+        if arg and hasattr(arg, "name"):
+            env[arg.name] = g.placeholder(arg.name)
+            if isinstance(captured_val[arg.name], (list, tuple)):
+                for val in captured_val[arg.name]:
+                    inputs.append(val)
+            else:
+                inputs.append(captured_val[arg.name])
+
+    call = g.node_copy(node, lambda n: env[n.name])
+    g.output(call)
+    g.lint()
+    single_node = torch.fx.GraphModule(torch.nn.Module(), g)
+    compiled_module = shark_backend(single_node, inputs)
+    compiled_graph[node.name] = {
+        "module": compiled_module,
+        "inputs": [i for i in env],
+        "result": None,
+    }
+    return
+
+
+def compiled_graph(gm: torch.fx.GraphModule, attr_info):
+    compiled_graph = {}
+    g = gm.graph
+    for node in g.nodes:
+        if node.op == "call_function":
+            if not (
+                node.target in [torch.ops.aten.empty]
+                or node.name.startswith("getitem")
+            ):
+                _make_single_op_gm(node, attr_info, compiled_graph)
+
+            # Currently torch.aten.empty has a compilation issue, so it is run natively.
+            elif node.target in [torch.ops.aten.empty]:
+                compiled_graph[node.name] = {
+                    "target": node.target,
+                    "args": node.args,
+                    "kwargs": node.kwargs,
+                    "result": None,
+                }
+            # getitem is a simple case: it takes a tuple and returns the tensor at a particular index.
+            elif node.name.startswith("getitem"):
+                compiled_graph[node.name] = {
+                    "input": node.args[0].name,
+                    "pos": node.args[1],
+                    "result": None,
+                }
+
+    return compiled_graph
+
+
+class ShapeProp:
+    """
+    Shape propagation. This class takes a `GraphModule`.
+    Then, its `propagate` method executes the `GraphModule`
+    node-by-node with the given arguments. As each operation
+    executes, the ShapeProp class stores away the shape and
+    element type for the output values of each operation on
+    the `shape` and `dtype` attributes of the operation's
+    `Node`.
+    """
+
+    def __init__(self, mod):
+        self.mod = mod
+        self.graph = mod.graph
+        self.modules = dict(self.mod.named_modules())
+
+    def propagate(self, *args):
+        args_iter = iter(args)
+        env: Dict[str, Node] = {}
+
+        def load_arg(a):
+            return torch.fx.graph.map_arg(a, lambda n: env[n.name])
+
+        def fetch_attr(target: str):
+            target_atoms = target.split(".")
+            attr_itr = self.mod
+            for i, atom in enumerate(target_atoms):
+                if not hasattr(attr_itr, atom):
+                    raise RuntimeError(
+                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+                    )
+                attr_itr = getattr(attr_itr, atom)
+            return attr_itr
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                result = next(args_iter)
+            elif node.op == "get_attr":
+                result = fetch_attr(node.target)
+            elif node.op == "call_function":
+                result = node.target(
+                    *load_arg(node.args), **load_arg(node.kwargs)
+                )
+            elif node.op == "call_method":
+                self_obj, *args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = getattr(self_obj, node.target)(*args, **kwargs)
+            elif node.op == "call_module":
+                result = self.modules[node.target](
+                    *load_arg(node.args), **load_arg(node.kwargs)
+                )
+
+            # This is the only code specific to shape propagation.
+            # You can delete this `if` branch and this becomes
+            # a generic GraphModule interpreter.
+            if isinstance(result, torch.Tensor):
+                node.shape = result.shape
+                node.dtype = result.dtype
+
+            env[node.name] = result
+
+        return env
+
+        # return load_arg(self.graph.result)
+
+
+resnet18 = models.resnet18(pretrained=True)
+resnet18.train(False)
+input = (torch.randn(1, 3, 224, 224),)
+
+print(resnet18(input[0]))
+
+fx_graph = import_with_fx(resnet18, input, mlir_type="fx")
+
+shape_prop = ShapeProp(fx_graph)
+
+x = shape_prop.propagate(input[0])
+
+shark_graph = compiled_graph(fx_graph, x)
+
+
+for key in shark_graph:
+    if key.startswith("getitem"):
+        input_val = shark_graph[key]["input"]
+        pos = shark_graph[key]["pos"]
+        if input_val not in shark_graph:
+            shark_graph[key]["result"] = x[input_val][pos].detach()
+        else:
+            shark_graph[key]["result"] = shark_graph[input_val]["result"][
+                pos
+            ].detach()
+    elif key.startswith("empty"):
+        operator = shark_graph[key]["target"]
+        args = shark_graph[key]["args"]
+        kwargs = shark_graph[key]["kwargs"]
+        shark_graph[key]["result"] = operator(*args, **kwargs).detach()
+    else:
+        input_val = shark_graph[key]["inputs"]
+        input_tensors = []
+        for input in input_val:
+            if input not in shark_graph:
+                input_tensors.append(x[input].detach())
+            else:
+                input_tensors.append(shark_graph[input]["result"])
+
+        val = shark_graph[key]["module"]("forward", input_tensors)
+        if isinstance(val, (tuple, list)):
+            list_val = []
+            for v in val:
+                list_val.append(torch.from_numpy(v))
+            shark_graph[key]["result"] = list_val
+        else:
+            shark_graph[key]["result"] = torch.from_numpy(val)
+
+
+print(shark_graph)
diff --git a/shark/shark_importer.py b/shark/shark_importer.py
index 64480d02ce..e12f7c0922 100644
--- a/shark/shark_importer.py
+++ b/shark/shark_importer.py
@@ -555,6 +555,9 @@ def strip_overloads(gm):
         add_upcast(fx_g)
         fx_g.recompile()
 
+    if mlir_type == "fx":
+        return fx_g
+
     if training:
         change_fx_graph_return_to_tuple(fx_g)
         inputs = flatten_training_input(inputs)
diff --git a/shark/shark_inference.py b/shark/shark_inference.py
index b4b5541149..671f5be9de 100644
--- a/shark/shark_inference.py
+++ b/shark/shark_inference.py
@@ -48,6 +48,8 @@ class SharkInference:
                      Refer to {https://mlir.llvm.org/docs/Dialects/}
     is_benchmark: bool
         Whether this SharkInference module should be benchmark-enabled.
+    mmap: bool
+        Whether to load/run the vmfb using mmap. It's `True` by default.
 
     Methods
     -------
@@ -70,6 +72,7 @@ def __init__(
         dispatch_benchmark: str = None,
         dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
         device_idx: int = None,
+        mmap: bool = True,
     ):
         self.mlir_module = mlir_module
         self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@ def __init__(
         )
 
         self.shark_runner = None
+        self.mmap = mmap
 
     def compile(self, extra_args=[]):
         if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ def load_module(self, path, extra_args=[]):
             compile_vmfb=False,
             extra_args=extra_args,
         )
-        (
-            self.shark_runner.iree_compilation_module,
-            self.shark_runner.iree_config,
-        ) = load_flatbuffer(
+        params = load_flatbuffer(
             path,
             self.device,
             self.device_idx,
+            mmap=self.mmap,
         )
+        self.shark_runner.iree_compilation_module = params["vmfb"]
+        self.shark_runner.iree_config = params["config"]
+        self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
+        del params
         return
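A sketch (not part of this patch) of how the new default behaves from the caller's side; the vmfb name is illustrative and reuses the vicuna naming scheme introduced earlier in this patch:

    from shark.shark_inference import SharkInference

    shark_module = SharkInference(
        mlir_module=None,        # nothing to compile here; we only load a prebuilt vmfb
        device="cpu-task",
        mlir_dialect="tm_tensor",
        mmap=True,               # default; load_module() now routes through load_vmfb_using_mmap
    )
    shark_module.load_module("second_vicuna_cpu_task_fp32.vmfb")
    # shark_module("forward", inputs) then runs against the mmapped module.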
diff --git a/shark/shark_runner.py b/shark/shark_runner.py
index 8cf1c84854..2552dd6a89 100644
--- a/shark/shark_runner.py
+++ b/shark/shark_runner.py
@@ -85,16 +85,17 @@ def __init__(
 
         if compile_vmfb == True:
             # Compile the module to get the .vmfb.
-            (
-                self.iree_compilation_module,
-                self.iree_config,
-            ) = get_iree_compiled_module(
+            params = get_iree_compiled_module(
                 self.mlir_module,
                 self.device,
                 self.mlir_dialect,
                 extra_args=self.extra_args,
                 device_idx=self.device_idx,
             )
+            self.iree_compilation_module = params["vmfb"]
+            self.iree_config = params["config"]
+            self.temp_file_to_unlink = params["temp_file_to_unlink"]
+            del params
 
     def run(self, function_name, inputs: tuple, send_to_host=False):
         return get_results(