diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 8ddb530b38..78ee1ca6d5 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -19,6 +19,7 @@ import numpy as np import os import re +import tempfile # Get the iree-compile arguments given device. @@ -181,8 +182,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks): vmfb_file.close() config = get_iree_runtime_config(device) - vm_module = ireert.VmModule.from_flatbuffer( - config.vm_instance, flatbuffer_blob + vm_module = ireert.VmModule.from_buffer( + config.vm_instance, + flatbuffer_blob, + warn_if_copy=False, ) benchmark_cl = build_benchmark_args_non_tensor_input( @@ -313,8 +316,8 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None): config = ireert.Config(device=haldevice) else: config = get_iree_runtime_config(device) - vm_module = ireert.VmModule.from_flatbuffer( - config.vm_instance, flatbuffer_blob + vm_module = ireert.VmModule.from_buffer( + config.vm_instance, flatbuffer_blob, warn_if_copy=False ) ctx = ireert.SystemContext(config=config) ctx.add_vm_module(vm_module) @@ -322,6 +325,55 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None): return ModuleCompiled, config +def load_vmfb_using_mmap( + flatbuffer_blob_or_path, device: str, device_idx: int = None +): + instance = ireert.VmInstance() + device = iree_device_map(device) + haldriver = ireert.get_driver(device) + haldevice = haldriver.create_device_by_uri( + device, + allocators=[], + ) + # First get configs. + if device_idx is not None: + device = iree_device_map(device) + print("registering device id: ", device_idx) + haldriver = ireert.get_driver(device) + + haldevice = haldriver.create_device( + haldriver.query_available_devices()[device_idx]["device_id"], + allocators=shark_args.device_allocator, + ) + config = ireert.Config(device=haldevice) + else: + config = get_iree_runtime_config(device) + # Now load vmfb. + # Two scenarios we have here :- + # 1. We either have the vmfb already saved and therefore pass the path of it. + # (This would arise if we're invoking `load_module` from a SharkInference obj) + # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. + # (This would arise if we're invoking `compile` from a SharkInference obj) + temp_file_to_unlink = None + if ( + isinstance(flatbuffer_blob_or_path, str) + and ".vmfb" in flatbuffer_blob_or_path + ): + vmfb_file_path = flatbuffer_blob_or_path + mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path) + ctx = ireert.SystemContext(config=config) + ctx.add_vm_module(mmaped_vmfb) + mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) + else: + with tempfile.NamedTemporaryFile(delete=False) as tf: + tf.write(flatbuffer_blob_or_path) + tf.flush() + vmfb_file_path = tf.name + temp_file_to_unlink = vmfb_file_path + mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path) + return mmaped_vmfb, config, temp_file_to_unlink + + def get_iree_compiled_module( module, device: str, @@ -329,19 +381,58 @@ def get_iree_compiled_module( model_config_path: str = None, extra_args: list = [], device_idx: int = None, + mmap: bool = False, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( module, device, frontend, model_config_path, extra_args ) - return get_iree_module(flatbuffer_blob, device, device_idx=device_idx) - + temp_file_to_unlink = None + # TODO: Currently mmap=True control flow path has been switched off for mmap. + # Got to find a cleaner way to unlink/delete the temporary file since + # we're setting delete=False when creating NamedTemporaryFile. That's why + # I'm getting hold of the name of the temporary file in `temp_file_to_unlink`. + if mmap: + print(f"Will load the compiled module as a mmapped temporary file") + vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( + flatbuffer_blob, device, device_idx + ) + else: + vmfb, config = get_iree_module( + flatbuffer_blob, device, device_idx=device_idx + ) + ret_params = { + "vmfb": vmfb, + "config": config, + "temp_file_to_unlink": temp_file_to_unlink, + } + return ret_params -def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None): - with open(os.path.join(flatbuffer_path), "rb") as f: - flatbuffer_blob = f.read() - return get_iree_module(flatbuffer_blob, device, device_idx=device_idx) +def load_flatbuffer( + flatbuffer_path: str, + device: str, + device_idx: int = None, + mmap: bool = False, +): + temp_file_to_unlink = None + if mmap: + print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file") + vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( + flatbuffer_path, device, device_idx + ) + else: + with open(os.path.join(flatbuffer_path), "rb") as f: + flatbuffer_blob = f.read() + vmfb, config = get_iree_module( + flatbuffer_blob, device, device_idx=device_idx + ) + ret_params = { + "vmfb": vmfb, + "config": config, + "temp_file_to_unlink": temp_file_to_unlink, + } + return ret_params def export_iree_module_to_vmfb( diff --git a/shark/shark_inference.py b/shark/shark_inference.py index b4b5541149..671f5be9de 100644 --- a/shark/shark_inference.py +++ b/shark/shark_inference.py @@ -48,6 +48,8 @@ class SharkInference: Refer to {https://mlir.llvm.org/docs/Dialects/} is_benchmark: bool Whether this SharkInference module should be benchmark-enabled. + mmap: bool + Whether to load/run vmfb using mmap. It's `True` by default. Methods ------- @@ -70,6 +72,7 @@ def __init__( dispatch_benchmark: str = None, dispatch_benchmark_dir: str = "temp_dispatch_benchmarks", device_idx: int = None, + mmap: bool = True, ): self.mlir_module = mlir_module self.device = shark_args.device if device == "none" else device @@ -88,6 +91,7 @@ def __init__( ) self.shark_runner = None + self.mmap = mmap def compile(self, extra_args=[]): if self.dispatch_benchmarks is not None: @@ -201,12 +205,14 @@ def load_module(self, path, extra_args=[]): compile_vmfb=False, extra_args=extra_args, ) - ( - self.shark_runner.iree_compilation_module, - self.shark_runner.iree_config, - ) = load_flatbuffer( + params = load_flatbuffer( path, self.device, self.device_idx, + mmap=self.mmap, ) + self.shark_runner.iree_compilation_module = params["vmfb"] + self.shark_runner.iree_config = params["config"] + self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"] + del params return diff --git a/shark/shark_runner.py b/shark/shark_runner.py index 8cf1c84854..2552dd6a89 100644 --- a/shark/shark_runner.py +++ b/shark/shark_runner.py @@ -85,16 +85,17 @@ def __init__( if compile_vmfb == True: # Compile the module to get the .vmfb. - ( - self.iree_compilation_module, - self.iree_config, - ) = get_iree_compiled_module( + params = get_iree_compiled_module( self.mlir_module, self.device, self.mlir_dialect, extra_args=self.extra_args, device_idx=self.device_idx, ) + self.iree_compilation_module = params["vmfb"] + self.iree_config = params["config"] + self.temp_file_to_unlink = params["temp_file_to_unlink"] + del params def run(self, function_name, inputs: tuple, send_to_host=False): return get_results(