[SharkInference-SharkRuntime] Adds capability to mmap vmfbs #1540

Merged
merged 1 commit into from
Jun 22, 2023
111 changes: 101 additions & 10 deletions shark/iree_utils/compile_utils.py
@@ -19,6 +19,7 @@
import numpy as np
import os
import re
import tempfile


# Get the iree-compile arguments given device.
@@ -181,8 +182,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
vmfb_file.close()

config = get_iree_runtime_config(device)
-vm_module = ireert.VmModule.from_flatbuffer(
-config.vm_instance, flatbuffer_blob
+vm_module = ireert.VmModule.from_buffer(
+config.vm_instance,
+flatbuffer_blob,
+warn_if_copy=False,
)

benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -313,35 +316,123 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
-vm_module = ireert.VmModule.from_flatbuffer(
-config.vm_instance, flatbuffer_blob
+vm_module = ireert.VmModule.from_buffer(
+config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config
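
Note: the move from `VmModule.from_flatbuffer` to `VmModule.from_buffer` tracks the newer IREE runtime Python API, and `warn_if_copy=False` suppresses the warning emitted when the runtime has to copy the backing buffer. A minimal sketch contrasting the two loading styles this PR uses, assuming a compiled `flatbuffer_blob` is in scope:

import iree.runtime as ireert

instance = ireert.VmInstance()
# Load from an in-memory flatbuffer; the runtime may copy the buffer.
vm_module = ireert.VmModule.from_buffer(
    instance, flatbuffer_blob, warn_if_copy=False
)
# Or map a .vmfb straight from disk, avoiding an in-memory copy.
vm_module = ireert.VmModule.mmap(instance, "/path/to/module.vmfb")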


def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First, get the config.
if device_idx is not None:
# `device` has already been mapped and the driver fetched above.
print("registering device id: ", device_idx)

haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
# Now load the vmfb. Two scenarios are possible here:
# 1. The vmfb is already saved to disk, so we are passed its path.
# (This arises when `load_module` is invoked on a SharkInference object.)
# 2. We are compiling on the fly, so we are passed the flatbuffer blob itself.
# (This arises when `compile` is invoked on a SharkInference object.)
temp_file_to_unlink = None
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
# Write the in-memory blob to a temporary file so it can be mmapped.
# delete=False so the file persists after close; its path is returned
# so the caller can unlink it once the module is no longer needed.
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
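
Note: a sketch of how the two scenarios behave; the device name is illustrative and `compiled_blob` stands in for an in-memory flatbuffer:

# Scenario 1: an on-disk .vmfb is mmapped directly; nothing to clean up.
module, config, tmp = load_vmfb_using_mmap("model.vmfb", "vulkan")
assert tmp is None

# Scenario 2: an in-memory blob is first spilled to a NamedTemporaryFile,
# whose path is returned so the caller can unlink it when done.
module, config, tmp = load_vmfb_using_mmap(compiled_blob, "vulkan")
assert tmp is not None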


def get_iree_compiled_module(
module,
device: str,
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
-return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)

temp_file_to_unlink = None
# TODO: The mmap=True control-flow path is currently switched off.
# We need a cleaner way to unlink/delete the temporary file, since
# NamedTemporaryFile is created with delete=False; for now its name is
# captured in `temp_file_to_unlink` so the caller can remove it.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
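
Note: call sites now unpack a dict instead of a tuple; a brief sketch of the new convention, with `mlir_module` and the device as placeholders:

params = get_iree_compiled_module(mlir_module, "cpu", mmap=True)
vmfb = params["vmfb"]
config = params["config"]
temp_file = params["temp_file_to_unlink"]  # None unless a temporary .vmfb was written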

-def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
-with open(os.path.join(flatbuffer_path), "rb") as f:
-flatbuffer_blob = f.read()
-
-return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str,
device: str,
device_idx: int = None,
mmap: bool = False,
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
else:
with open(flatbuffer_path, "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params
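
Note: because the mmap path may hand back a NamedTemporaryFile path, callers of `load_flatbuffer` own its cleanup; a minimal sketch of that contract, with placeholder paths:

import os

params = load_flatbuffer("/path/to/model.vmfb", "cpu", mmap=True)
try:
    module, config = params["vmfb"], params["config"]
    # ... run inference with the module ...
finally:
    if params["temp_file_to_unlink"] is not None:
        os.unlink(params["temp_file_to_unlink"])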


def export_iree_module_to_vmfb(
14 changes: 10 additions & 4 deletions shark/shark_inference.py
@@ -48,6 +48,8 @@ class SharkInference:
Refer to {https://mlir.llvm.org/docs/Dialects/}
is_benchmark: bool
Whether this SharkInference module should be benchmark-enabled.
mmap: bool
Whether to load/run the vmfb using mmap. Defaults to `True`.

Methods
-------
@@ -70,6 +72,7 @@ def __init__(
dispatch_benchmark: str = None,
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
device_idx: int = None,
mmap: bool = True,
):
self.mlir_module = mlir_module
self.device = shark_args.device if device == "none" else device
@@ -88,6 +91,7 @@
)

self.shark_runner = None
self.mmap = mmap

def compile(self, extra_args=[]):
if self.dispatch_benchmarks is not None:
@@ -201,12 +205,14 @@ def load_module(self, path, extra_args=[]):
compile_vmfb=False,
extra_args=extra_args,
)
-(
-self.shark_runner.iree_compilation_module,
-self.shark_runner.iree_config,
-) = load_flatbuffer(
+params = load_flatbuffer(
path,
self.device,
self.device_idx,
+mmap=self.mmap,
)
+self.shark_runner.iree_compilation_module = params["vmfb"]
+self.shark_runner.iree_config = params["config"]
+self.shark_runner.temp_file_to_unlink = params["temp_file_to_unlink"]
+del params
return
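
Note: the flag defaults to on, so existing SharkInference users pick up mmap loading transparently; a hypothetical end-to-end use, with the model and path as placeholders:

shark_module = SharkInference(mlir_module, device="cpu", mmap=True)
shark_module.compile()

# load_module() honors the flag and mmaps the .vmfb directly from disk:
shark_module.load_module("/path/to/model.vmfb")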
9 changes: 5 additions & 4 deletions shark/shark_runner.py
@@ -85,16 +85,17 @@ def __init__(

if compile_vmfb == True:
# Compile the module to get the .vmfb.
-(
-self.iree_compilation_module,
-self.iree_config,
-) = get_iree_compiled_module(
+params = get_iree_compiled_module(
self.mlir_module,
self.device,
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
)
+self.iree_compilation_module = params["vmfb"]
+self.iree_config = params["config"]
+self.temp_file_to_unlink = params["temp_file_to_unlink"]
+del params

def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(