diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py index bde6ae966510..7ff760028374 100644 --- a/runtime/bindings/python/tests/vm_test.py +++ b/runtime/bindings/python/tests/vm_test.py @@ -8,6 +8,7 @@ import logging import numpy as np +import tempfile import unittest import iree.compiler @@ -106,6 +107,27 @@ def test_add_scalar_new_abi(self): logging.info("result: %s", result) self.assertEqual(result, 11) + def test_mmap(self): + binary = iree.compiler.compile_str( + """ + func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 : i32 + return %0 : i32 + } + """, + target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS) + with tempfile.NamedTemporaryFile() as tf: + tf.write(binary) + tf.flush() + m = iree.runtime.VmModule.mmap(self.instance, tf.name) + context = iree.runtime.VmContext(self.instance, + modules=[self.hal_module, m]) + f = m.lookup_function("add_scalar") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + result = finv(5, 6) + logging.info("result: %s", result) + self.assertEqual(result, 11) + def test_synchronous_dynamic_shape_invoke_function_new_abi(self): m = create_simple_dynamic_abs_module(self.instance) context = iree.runtime.VmContext(self.instance, diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc index b72ecdea1989..9a3a4fdf2f7e 100644 --- a/runtime/bindings/python/vm.cc +++ b/runtime/bindings/python/vm.cc @@ -150,6 +150,23 @@ VmModule VmModule::ResolveModuleDependency(VmInstance* instance, return py_module; } +VmModule VmModule::MMap(VmInstance* instance, std::string filepath) { + IREE_TRACE_SCOPE_NAMED("VmModule::MMap"); + auto mmap_module = py::module::import("mmap"); + auto open_func = py::module::import("io").attr("open"); + auto file_obj = open_func(filepath, "r+b"); + auto flags = py::cast(mmap_module.attr("MAP_SHARED")); + // MAP_POPULATE isn't available on all versions/platforms. + if (py::hasattr(mmap_module, "MAP_POPULATE")) { + flags |= py::cast(mmap_module.attr("MAP_POPULATE")); + } + auto prot = py::cast(mmap_module.attr("PROT_READ")); + auto mapped_file = + mmap_module.attr("mmap")(file_obj.attr("fileno")(), 0, flags, prot); + mapped_file.attr("madvise")(mmap_module.attr("MADV_RANDOM")); + return FromFlatbufferBlob(instance, mapped_file); +} + VmModule VmModule::FromFlatbufferBlob(VmInstance* instance, py::object flatbuffer_blob_object) { IREE_TRACE_SCOPE_NAMED("VmModule::FromFlatbufferBlob"); @@ -661,6 +678,7 @@ void SetupVmBindings(pybind11::module m) { .def_static("resolve_module_dependency", &VmModule::ResolveModuleDependency) .def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob) + .def_static("mmap", &VmModule::MMap) .def_property_readonly("name", &VmModule::name) .def_property_readonly("version", [](VmModule& self) { diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h index 4b097e7e35e1..9bf25f462ba9 100644 --- a/runtime/bindings/python/vm.h +++ b/runtime/bindings/python/vm.h @@ -133,6 +133,7 @@ class VmModule : public ApiRefCounted { const std::string& name, uint32_t minimum_version); + static VmModule MMap(VmInstance* instance, std::string filepath); static VmModule FromFlatbufferBlob(VmInstance* instance, py::object flatbuffer_blob_object);