Skip to content

Commit

Permalink
Add VmModule.mmap() to Python API. (iree-org#14124)
Browse files Browse the repository at this point in the history
We really need page aligned flatbuffer blobs vs normal malloc alignment.
The best way to be loading a file is via mmap, so just make that
available as an API.

This could be done in Python by the caller but is error-prone. The
public API will make this more robust.

Provides the mechanism to fix iree-org#13887
  • Loading branch information
Stella Laurenzo authored and nhasabni committed Aug 24, 2023
1 parent faee657 commit d287cd3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
22 changes: 22 additions & 0 deletions runtime/bindings/python/tests/vm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import numpy as np
import tempfile
import unittest

import iree.compiler
Expand Down Expand Up @@ -106,6 +107,27 @@ def test_add_scalar_new_abi(self):
logging.info("result: %s", result)
self.assertEqual(result, 11)

def test_mmap(self):
binary = iree.compiler.compile_str(
"""
func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
%0 = arith.addi %arg0, %arg1 : i32
return %0 : i32
}
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS)
with tempfile.NamedTemporaryFile() as tf:
tf.write(binary)
tf.flush()
m = iree.runtime.VmModule.mmap(self.instance, tf.name)
context = iree.runtime.VmContext(self.instance,
modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)

def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
m = create_simple_dynamic_abs_module(self.instance)
context = iree.runtime.VmContext(self.instance,
Expand Down
18 changes: 18 additions & 0 deletions runtime/bindings/python/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ VmModule VmModule::ResolveModuleDependency(VmInstance* instance,
return py_module;
}

VmModule VmModule::MMap(VmInstance* instance, std::string filepath) {
IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
auto mmap_module = py::module::import("mmap");
auto open_func = py::module::import("io").attr("open");
auto file_obj = open_func(filepath, "r+b");
auto flags = py::cast<int64_t>(mmap_module.attr("MAP_SHARED"));
// MAP_POPULATE isn't available on all versions/platforms.
if (py::hasattr(mmap_module, "MAP_POPULATE")) {
flags |= py::cast<int64_t>(mmap_module.attr("MAP_POPULATE"));
}
auto prot = py::cast<int64_t>(mmap_module.attr("PROT_READ"));
auto mapped_file =
mmap_module.attr("mmap")(file_obj.attr("fileno")(), 0, flags, prot);
mapped_file.attr("madvise")(mmap_module.attr("MADV_RANDOM"));
return FromFlatbufferBlob(instance, mapped_file);
}

VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromFlatbufferBlob");
Expand Down Expand Up @@ -661,6 +678,7 @@ void SetupVmBindings(pybind11::module m) {
.def_static("resolve_module_dependency",
&VmModule::ResolveModuleDependency)
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
.def_static("mmap", &VmModule::MMap)
.def_property_readonly("name", &VmModule::name)
.def_property_readonly("version",
[](VmModule& self) {
Expand Down
1 change: 1 addition & 0 deletions runtime/bindings/python/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
const std::string& name,
uint32_t minimum_version);

static VmModule MMap(VmInstance* instance, std::string filepath);
static VmModule FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object);

Expand Down

0 comments on commit d287cd3

Please sign in to comment.