Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VmModule.mmap() to Python API. #14124

Merged
merged 2 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions runtime/bindings/python/tests/vm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import numpy as np
import tempfile
import unittest

import iree.compiler
Expand Down Expand Up @@ -106,6 +107,27 @@ def test_add_scalar_new_abi(self):
logging.info("result: %s", result)
self.assertEqual(result, 11)

def test_mmap(self):
binary = iree.compiler.compile_str(
"""
func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
%0 = arith.addi %arg0, %arg1 : i32
return %0 : i32
}
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS)
with tempfile.NamedTemporaryFile() as tf:
tf.write(binary)
tf.flush()
m = iree.runtime.VmModule.mmap(self.instance, tf.name)
Comment on lines +119 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May need some tweaking for Windows

Indeed. https://github.com/openxla/iree/actions/runs/5275962562/jobs/9542060597#step:9:2321

ERROR: test_mmap (__main__.VmTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\a\iree\iree\runtime\bindings\python\tests\vm_test.py", line 122, in test_mmap
    m = iree.runtime.VmModule.mmap(self.instance, tf.name)
PermissionError: [Errno 13] Permission denied: 'C:\\Users\\RUNNER~1\\AppData\\Local\\Temp\\tmpr4u5ldjo'

Maybe related to #13148 (comment)

See https://stackoverflow.com/a/23212515 - NamedTemporaryFile creates and opens the file, and the file cannot be opened again... on Windows (it can be opened again on Unix).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a fix for that in SHARK to deal with this very issue - perhaps that can be used.

But the other issue in Windows which I bumped into is mentioned here - it's API related instead of the use case.

CC: @powderluv

context = iree.runtime.VmContext(self.instance,
modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)

def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
m = create_simple_dynamic_abs_module(self.instance)
context = iree.runtime.VmContext(self.instance,
Expand Down
18 changes: 18 additions & 0 deletions runtime/bindings/python/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ VmModule VmModule::ResolveModuleDependency(VmInstance* instance,
return py_module;
}

VmModule VmModule::MMap(VmInstance* instance, std::string filepath) {
IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
auto mmap_module = py::module::import("mmap");
auto open_func = py::module::import("io").attr("open");
auto file_obj = open_func(filepath, "r+b");
auto flags = py::cast<int64_t>(mmap_module.attr("MAP_SHARED"));
// MAP_POPULATE isn't available on all versions/platforms.
if (py::hasattr(mmap_module, "MAP_POPULATE")) {
flags |= py::cast<int64_t>(mmap_module.attr("MAP_POPULATE"));
}
auto prot = py::cast<int64_t>(mmap_module.attr("PROT_READ"));
auto mapped_file =
mmap_module.attr("mmap")(file_obj.attr("fileno")(), 0, flags, prot);
mapped_file.attr("madvise")(mmap_module.attr("MADV_RANDOM"));
return FromFlatbufferBlob(instance, mapped_file);
}

VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromFlatbufferBlob");
Expand Down Expand Up @@ -661,6 +678,7 @@ void SetupVmBindings(pybind11::module m) {
.def_static("resolve_module_dependency",
&VmModule::ResolveModuleDependency)
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
.def_static("mmap", &VmModule::MMap)
.def_property_readonly("name", &VmModule::name)
.def_property_readonly("version",
[](VmModule& self) {
Expand Down
1 change: 1 addition & 0 deletions runtime/bindings/python/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
const std::string& name,
uint32_t minimum_version);

static VmModule MMap(VmInstance* instance, std::string filepath);
static VmModule FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object);

Expand Down