# ---------------------------------------------------------------------------
# Review header for this patch (text before the first `diff --git` line).
#
# Purpose: add an mmap-based way of loading and running a compiled .vmfb,
# enabled by default, across three files:
#
#   shark/iree_utils/compile_utils.py
#     * imports `tempfile`.
#     * adds load_vmfb_using_mmap(flatbuffer_blob_or_path, device, device_idx=None)
#       which returns the 4-tuple (mmaped_vmfb, config, context, haldevice).
#       A `str` argument containing ".vmfb" is mmapped directly as a path; any
#       other argument is written to a NamedTemporaryFile(delete=False), that
#       file is mmapped, and it is then closed and unlinked.
#     * get_iree_compiled_module / load_flatbuffer gain `mmap: bool = True`.
#       With mmap they return load_vmfb_using_mmap(...); without it they return
#       get_iree_module(...)'s result padded with two `None`s, so both paths
#       now yield 4 values instead of the previous 2.
#     * get_results gains `mmap=True, context=None, haldevice=None`. With mmap
#       it calls compiled_vm.lookup_function(function_name) and invokes it via
#       ireert.FunctionInvoker(context, haldevice, f, tracer=None) on the raw
#       `input`; otherwise it keeps the old asdevicearray + compiled_vm[...] path.
#
#   shark/shark_inference.py
#     * SharkInference.__init__ gains `mmap: bool = True`, stored as self.mmap
#       and forwarded to SharkRunner (compile, load_module) and load_flatbuffer.
#     * load_module additionally stores the returned context / haldevice on
#       self.shark_runner.
#
#   shark/shark_runner.py
#     * SharkRunner.__init__ gains `mmap: bool = True`, stores self.mmap, and
#       (only when compile_vmfb is true) unpacks context / haldevice from
#       get_iree_compiled_module; run() forwards mmap/context/haldevice to
#       get_results.
#
# NOTE(review): the line structure of this patch has been flattened in this
# copy (hunk bodies share physical lines), so the indentation and blank lines
# of context lines cannot be read back from here.
#
# NOTE(review): in load_vmfb_using_mmap, `hal_module` is created from the
# first `haldevice` (create_device_by_uri) *before* the `device_idx` branch
# rebinds `haldevice`. The VmContext is therefore built with a HAL module for
# the first device while the returned `haldevice` / `config` refer to the one
# selected by `device_idx` - looks unintended; please confirm.
#
# NOTE(review): `device = iree_device_map(device)` runs twice when
# `device_idx` is not None, and the already-mapped name is what reaches
# get_iree_runtime_config(device) otherwise - verify iree_device_map is
# idempotent and that get_iree_runtime_config accepts a mapped name.
#
# NOTE(review): the temp file is created with delete=False and is unlinked
# only on the success path; if write / mmap / VmContext raises, it is left on
# disk. It is also unlinked while the module is still mapped - confirm this is
# acceptable on every target platform.
#
# NOTE(review): the path check is `".vmfb" in flatbuffer_blob_or_path`, i.e. a
# substring test rather than a suffix test.
#
# NOTE(review): get_results defaults to mmap=True with context=None and
# haldevice=None, so a caller that does not pass the new keywords reaches
# FunctionInvoker(None, None, ...) - check callers outside this patch.
#
# NOTE(review): SharkRunner only assigns self.context / self.haldevice inside
# `if compile_vmfb == True`, yet run() always reads them; with
# compile_vmfb=False they must be set externally (load_module does so).
#
# NOTE(review): the added print(...) calls in get_iree_compiled_module and
# load_flatbuffer execute on every call - presumably debug output; confirm
# they are meant to stay.
# ---------------------------------------------------------------------------
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 391129961f..ee1d4a1fc1 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -19,6 +19,7 @@ import numpy as np import os import re +import tempfile # Get the iree-compile arguments given device. @@ -316,6 +317,54 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None): return ModuleCompiled, config +def load_vmfb_using_mmap( + flatbuffer_blob_or_path, device: str, device_idx: int = None +): + instance = ireert.VmInstance() + device = iree_device_map(device) + haldriver = ireert.get_driver(device) + haldevice = haldriver.create_device_by_uri( + device, + allocators=[], + ) + hal_module = ireert.create_hal_module(instance, haldevice) + # First get configs. + if device_idx is not None: + device = iree_device_map(device) + print("registering device id: ", device_idx) + haldriver = ireert.get_driver(device) + + haldevice = haldriver.create_device( + haldriver.query_available_devices()[device_idx]["device_id"], + allocators=shark_args.device_allocator, + ) + config = ireert.Config(device=haldevice) + else: + config = get_iree_runtime_config(device) + # Now load vmfb. + # Two scenarios we have here :- + # 1. We either have the vmfb already saved and therefore pass the path of it. + # (This would arise if we're invoking `load_module` from a SharkInference obj) + # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. 
+ # (This would arise if we're invoking `compile` from a SharkInference obj) + if ( + isinstance(flatbuffer_blob_or_path, str) + and ".vmfb" in flatbuffer_blob_or_path + ): + mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path) + context = ireert.VmContext(instance, modules=[hal_module, mmaped_vmfb]) + else: + tmpf = tempfile.NamedTemporaryFile(delete=False) + tmpf_name = tmpf.name + tmpf.write(flatbuffer_blob_or_path) + tmpf.flush() + mmaped_vmfb = ireert.VmModule.mmap(instance, tmpf_name) + context = ireert.VmContext(instance, modules=[hal_module, mmaped_vmfb]) + tmpf.close() + os.unlink(tmpf.name) + return mmaped_vmfb, config, context, haldevice + + def get_iree_compiled_module( module, device: str, @@ -323,19 +372,47 @@ def get_iree_compiled_module( model_config_path: str = None, extra_args: list = [], device_idx: int = None, + mmap: bool = True, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( module, device, frontend, model_config_path, extra_args ) - return get_iree_module(flatbuffer_blob, device, device_idx=device_idx) + print( + "Device from get_iree_compiled_module = ", + device, + " device_idx = ", + device_idx, + ) + if mmap: + return load_vmfb_using_mmap(flatbuffer_blob, device, device_idx) + # TODO: Returning of `None` is quite an ugly/hacky-way - will make it better in subsequent iteration. 
+ return ( + *get_iree_module(flatbuffer_blob, device, device_idx=device_idx), + None, + None, + ) -def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None): +def load_flatbuffer( + flatbuffer_path: str, + device: str, + device_idx: int = None, + mmap: bool = True, +): + print( + "Device from load_flatbuffer = ", device, " device_idx = ", device_idx + ) + if mmap: + return load_vmfb_using_mmap(flatbuffer_path, device, device_idx) with open(os.path.join(flatbuffer_path), "rb") as f: flatbuffer_blob = f.read() - - return get_iree_module(flatbuffer_blob, device, device_idx=device_idx) + # TODO: Returning of `None` is quite an ugly/hacky-way - will make it better in subsequent iteration. + return ( + *get_iree_module(flatbuffer_blob, device, device_idx=device_idx), + None, + None, + ) def export_iree_module_to_vmfb( @@ -384,10 +461,18 @@ def get_results( config, frontend="torch", send_to_host=True, + mmap=True, + context=None, + haldevice=None, ): """Runs a .vmfb file given inputs and config and returns output.""" - device_inputs = [ireert.asdevicearray(config.device, a) for a in input] - result = compiled_vm[function_name](*device_inputs) + if mmap: + f = compiled_vm.lookup_function(function_name) + finv = ireert.FunctionInvoker(context, haldevice, f, tracer=None) + result = finv(*input) + else: + device_inputs = [ireert.asdevicearray(config.device, a) for a in input] + result = compiled_vm[function_name](*device_inputs) result_tensors = [] if isinstance(result, tuple): if send_to_host: diff --git a/shark/shark_inference.py b/shark/shark_inference.py index b4b5541149..8946e389d6 100644 --- a/shark/shark_inference.py +++ b/shark/shark_inference.py @@ -48,6 +48,8 @@ class SharkInference: Refer to {https://mlir.llvm.org/docs/Dialects/} is_benchmark: bool Whether this SharkInference module should be benchmark-enabled. + mmap: bool + Whether to load/run vmfb using mmap. It's `True` by default. 
 Methods ------- @@ -70,6 +72,7 @@ def __init__( dispatch_benchmark: str = None, dispatch_benchmark_dir: str = "temp_dispatch_benchmarks", device_idx: int = None, + mmap: bool = True, ): self.mlir_module = mlir_module self.device = shark_args.device if device == "none" else device @@ -88,6 +91,7 @@ def __init__( ) self.shark_runner = None + self.mmap = mmap def compile(self, extra_args=[]): if self.dispatch_benchmarks is not None: @@ -122,6 +126,7 @@ def compile(self, extra_args=[]): self.mlir_dialect, extra_args=extra_args, device_idx=self.device_idx, + mmap=self.mmap, ) if self.dispatch_benchmarks is not None: @@ -200,13 +205,17 @@ def load_module(self, path, extra_args=[]): device=self.device, compile_vmfb=False, extra_args=extra_args, + mmap=self.mmap, ) ( self.shark_runner.iree_compilation_module, self.shark_runner.iree_config, + self.shark_runner.context, + self.shark_runner.haldevice, ) = load_flatbuffer( path, self.device, self.device_idx, + mmap=self.mmap, ) return diff --git a/shark/shark_runner.py b/shark/shark_runner.py index 8cf1c84854..4cd23f9464 100644 --- a/shark/shark_runner.py +++ b/shark/shark_runner.py @@ -52,6 +52,10 @@ class SharkRunner: mlir_dialect: str The dialect in which the given mlir_module is in. Refer to {https://mlir.llvm.org/docs/Dialects/} + mmap: bool + Whether to load/run vmfb using mmap. It's `True` by default. + When `True` - `iree_compilation_module` would contain the mmap'd vmfb + - `context` and `haldevice` would be not `None`. Methods ------- @@ -72,6 +76,7 @@ def __init__( extra_args: list = [], compile_vmfb: bool = True, device_idx: int = None, + mmap: bool = True, ): self.mlir_module = mlir_module self.device = shark_args.device if device == "none" else device @@ -83,17 +88,21 @@ def __init__( print(device_driver_info(self.device)) sys.exit(1) + self.mmap = mmap if compile_vmfb == True: # Compile the module to get the .vmfb. 
 ( self.iree_compilation_module, self.iree_config, + self.context, + self.haldevice, ) = get_iree_compiled_module( self.mlir_module, self.device, self.mlir_dialect, extra_args=self.extra_args, device_idx=self.device_idx, + mmap=mmap, ) def run(self, function_name, inputs: tuple, send_to_host=False): @@ -104,6 +113,9 @@ def run(self, function_name, inputs: tuple, send_to_host=False): self.iree_config, self.mlir_dialect, send_to_host, + mmap=self.mmap, + context=self.context, + haldevice=self.haldevice, ) # Get all function names defined within the compiled module.