Skip to content

Commit

Permalink
[SharkInference-SharkRuntime] Adds capability to mmap vmfbs
Browse files Browse the repository at this point in the history
-- This commit is based on the [VmModule.mmap() API](iree-org/iree#14124).
-- It thereby adds the capability to mmap vmfbs in SHARK.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma committed Jun 15, 2023
1 parent 38570a9 commit 189afec
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 6 deletions.
97 changes: 91 additions & 6 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import os
import re
import tempfile


# Get the iree-compile arguments given device.
Expand Down Expand Up @@ -316,26 +317,102 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
return ModuleCompiled, config


def load_vmfb_using_mmap(
    flatbuffer_blob_or_path, device: str, device_idx: int = None
):
    """Load a compiled vmfb via ``VmModule.mmap`` and return what is needed to run it.

    Args:
        flatbuffer_blob_or_path: Either a path to a saved ``.vmfb`` file
            (when invoked via ``SharkInference.load_module``) or an
            in-memory flatbuffer blob (when invoked via
            ``SharkInference.compile``).
        device: User-facing device name; translated through
            ``iree_device_map`` exactly once below.
        device_idx: Optional index selecting a specific device of the
            driver; when given, a dedicated HAL device/config is created.

    Returns:
        Tuple of (mmapped VmModule, runtime Config, VmContext, HAL device).
    """
    instance = ireert.VmInstance()
    # Map the device name to an IREE driver URI exactly once. (The original
    # code mapped it a second time inside the `device_idx` branch, feeding an
    # already-mapped URI back into `iree_device_map`.)
    device = iree_device_map(device)
    haldriver = ireert.get_driver(device)
    haldevice = haldriver.create_device_by_uri(
        device,
        allocators=[],
    )
    hal_module = ireert.create_hal_module(instance, haldevice)
    # First get configs.
    if device_idx is not None:
        print("registering device id: ", device_idx)
        # Re-bind `haldevice` to the indexed device; note the HAL module
        # above was created against the URI-derived device — preserved from
        # the original logic. TODO(review): confirm this mismatch is intended.
        haldevice = haldriver.create_device(
            haldriver.query_available_devices()[device_idx]["device_id"],
            allocators=shark_args.device_allocator,
        )
        config = ireert.Config(device=haldevice)
    else:
        config = get_iree_runtime_config(device)
    # Now load vmfb.
    # Two scenarios we have here :-
    # 1. We either have the vmfb already saved and therefore pass the path of it.
    #    (This would arise if we're invoking `load_module` from a SharkInference obj)
    # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
    #    (This would arise if we're invoking `compile` from a SharkInference obj)
    if (
        isinstance(flatbuffer_blob_or_path, str)
        and ".vmfb" in flatbuffer_blob_or_path
    ):
        mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
        context = ireert.VmContext(instance, modules=[hal_module, mmaped_vmfb])
    else:
        # Spill the in-memory blob to a named temp file so it can be mmap'd;
        # clean the file up ourselves (delete=False) once mapping succeeds —
        # or fails: the try/finally guarantees no temp-file leak on error.
        tmpf = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmpf.write(flatbuffer_blob_or_path)
            tmpf.flush()
            mmaped_vmfb = ireert.VmModule.mmap(instance, tmpf.name)
            context = ireert.VmContext(
                instance, modules=[hal_module, mmaped_vmfb]
            )
        finally:
            tmpf.close()
            os.unlink(tmpf.name)
    return mmaped_vmfb, config, context, haldevice


def get_iree_compiled_module(
    module,
    device: str,
    frontend: str = "torch",
    model_config_path: str = None,
    extra_args: list = [],
    device_idx: int = None,
    mmap: bool = True,
):
    """Given a module, compile it and return the loaded .vmfb plus configs.

    Args:
        module: The input module to compile (e.g. MLIR text/bytecode).
        device: Target device name.
        frontend: Source dialect/frontend of `module`.
        model_config_path: Optional model config path forwarded to the compiler.
        extra_args: Extra iree-compile flags.
        device_idx: Optional device index for multi-device drivers.
        mmap: When True (default), mmap the compiled vmfb instead of
            loading the flatbuffer into memory.

    Returns:
        A 4-tuple (module, config, context, haldevice). The last two are
        None when `mmap` is False.
    """
    flatbuffer_blob = compile_module_to_flatbuffer(
        module, device, frontend, model_config_path, extra_args
    )
    # NOTE: the stale unconditional `return get_iree_module(...)` that
    # preceded this print made everything below unreachable; removed.
    print(
        "Device from get_iree_compiled_module = ",
        device,
        " device_idx = ",
        device_idx,
    )
    if mmap:
        return load_vmfb_using_mmap(flatbuffer_blob, device, device_idx)
    # TODO: Returning of `None` is quite an ugly/hacky-way - will make it better in subsequent iteration.
    return (
        *get_iree_module(flatbuffer_blob, device, device_idx=device_idx),
        None,
        None,
    )


def load_flatbuffer(
    flatbuffer_path: str,
    device: str,
    device_idx: int = None,
    mmap: bool = True,
):
    """Load a saved .vmfb from `flatbuffer_path` and return runtime handles.

    Args:
        flatbuffer_path: Path to the saved .vmfb file.
        device: Target device name.
        device_idx: Optional device index for multi-device drivers.
        mmap: When True (default), mmap the vmfb instead of reading the
            whole flatbuffer into memory.

    Returns:
        A 4-tuple (module, config, context, haldevice). The last two are
        None when `mmap` is False.
    """
    print(
        "Device from load_flatbuffer = ", device, " device_idx = ", device_idx
    )
    if mmap:
        return load_vmfb_using_mmap(flatbuffer_path, device, device_idx)
    # `os.path.join` with a single argument was a no-op; read the path directly.
    with open(flatbuffer_path, "rb") as f:
        flatbuffer_blob = f.read()

    # TODO: Returning of `None` is quite an ugly/hacky-way - will make it better in subsequent iteration.
    return (
        *get_iree_module(flatbuffer_blob, device, device_idx=device_idx),
        None,
        None,
    )


def export_iree_module_to_vmfb(
Expand Down Expand Up @@ -384,10 +461,18 @@ def get_results(
config,
frontend="torch",
send_to_host=True,
mmap=True,
context=None,
haldevice=None,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
if mmap:
f = compiled_vm.lookup_function(function_name)
finv = ireert.FunctionInvoker(context, haldevice, f, tracer=None)
result = finv(*input)
else:
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
Expand Down
9 changes: 9 additions & 0 deletions shark/shark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class SharkInference:
Refer to {https://mlir.llvm.org/docs/Dialects/}
is_benchmark: bool
Whether this SharkInference module should be benchmark-enabled.
mmap: bool
Whether to load/run vmfb using mmap. It's `True` by default.
Methods
-------
Expand All @@ -70,6 +72,7 @@ def __init__(
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
Expand All @@ -88,6 +91,7 @@ def __init__(
)

self.shark_runner = None
self.mmap = mmap

def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
Expand Down Expand Up @@ -122,6 +126,7 @@ def compile(self, extra_args=[]):
self.mlir_dialect,
extra_args=extra_args,
device_idx=self.device_idx,
mmap=self.mmap,
)

if self.dispatch_benchmarks is not None:
Expand Down Expand Up @@ -200,13 +205,17 @@ def load_module(self, path, extra_args=[]):
device=self.device,
compile_vmfb=False,
extra_args=extra_args,
mmap=self.mmap,
)
(
self.shark_runner.iree_compilation_module,
self.shark_runner.iree_config,
self.shark_runner.context,
self.shark_runner.haldevice,
) = load_flatbuffer(
path,
self.device,
self.device_idx,
mmap=self.mmap,
)
return
12 changes: 12 additions & 0 deletions shark/shark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class SharkRunner:
mlir_dialect: str
The dialect in which the given mlir_module is in.
Refer to {https://mlir.llvm.org/docs/Dialects/}
mmap: bool
Whether to load/run vmfb using mmap. It's `True` by default.
When `True` - `iree_compilation_module` would contain the mmap'd vmfb
- `context` and `haldevice` would be not `None`.
Methods
-------
Expand All @@ -72,6 +76,7 @@ def __init__(
extra_args: list = [],
compile_vmfb: bool = True,
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
Expand All @@ -83,17 +88,21 @@ def __init__(
print(device_driver_info(self.device))
sys.exit(1)

self.mmap = mmap
if compile_vmfb == True:
# Compile the module to get the .vmfb.
(
self.iree_compilation_module,
self.iree_config,
self.context,
self.haldevice,
) = get_iree_compiled_module(
self.mlir_module,
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
mmap=mmap,
)

def run(self, function_name, inputs: tuple, send_to_host=False):
Expand All @@ -104,6 +113,9 @@ def run(self, function_name, inputs: tuple, send_to_host=False):
self.iree_config,
self.mlir_dialect,
send_to_host,
mmap=self.mmap,
context=self.context,
haldevice=self.haldevice,
)

# Get all function names defined within the compiled module.
Expand Down

0 comments on commit 189afec

Please sign in to comment.