Skip to content
This repository has been archived by the owner on Jul 30, 2024. It is now read-only.

Commit

Permalink
Integrate IREE at iree-org/iree@c121b86
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre committed Feb 13, 2024
1 parent 30291c3 commit d1c3d91
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion iree_simple_embedding/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ set(_TRANSLATE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)

# Define arguments passed to iree-compile
set(_ARGS)
list(APPEND _ARGS "-iree-input-type=mhlo")
list(APPEND _ARGS "-iree-input-type=stablehlo")
list(APPEND _ARGS "--output-format=vm-bytecode")
list(APPEND _ARGS "-iree-hal-target-backends=vmvx")
# Uncomment the line below to use vulkan-spirv backend
Expand Down
27 changes: 13 additions & 14 deletions iree_simple_embedding/simple_embedding.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ extern const iree_const_byte_span_t load_bytecode_module_data();

iree_status_t Run() {
iree_vm_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(
iree_vm_instance_create(iree_allocator_system(), &instance));
IREE_RETURN_IF_ERROR(iree_vm_instance_create(
IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance));
IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance));

iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
"create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_module_create(instance, device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
IREE_RETURN_IF_ERROR(iree_hal_module_create(
instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS,
iree_allocator_system(), &hal_module));

// Load bytecode module from the embedded data.
Expand Down Expand Up @@ -78,16 +78,16 @@ iree_status_t Run() {
iree_hal_dim_t shape[1] = {IREE_ARRAYSIZE(kFloat4)};
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
},
iree_make_const_byte_span(kFloat4, sizeof(kFloat4)), &arg0_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy(
device, iree_hal_device_allocator(device), IREE_ARRAYSIZE(shape), shape,
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
Expand All @@ -97,8 +97,8 @@ iree_status_t Run() {

// Setup call inputs with our buffers.
iree_vm_list_t* inputs = NULL;
IREE_RETURN_IF_ERROR(iree_vm_list_create(
/*element_type=*/NULL,
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(),
/*capacity=*/2, iree_allocator_system(), &inputs),
"can't allocate input vm list");

Expand All @@ -114,8 +114,8 @@ iree_status_t Run() {
// Prepare outputs list to accept the results from the invocation.
// The output vm list is allocated statically.
iree_vm_list_t* outputs = NULL;
IREE_RETURN_IF_ERROR(iree_vm_list_create(
/*element_type=*/NULL,
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(),
/*capacity=*/1, iree_allocator_system(), &outputs),
"can't allocate output vm list");

Expand All @@ -126,8 +126,7 @@ iree_status_t Run() {

// Get the result buffers from the invocation.
iree_hal_buffer_view_t* ret_buffer_view =
(iree_hal_buffer_view_t*)iree_vm_list_get_ref_deref(
outputs, 0, &iree_hal_buffer_view_descriptor);
iree_vm_list_get_buffer_view_assign(outputs, 0);
if (ret_buffer_view == NULL) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"can't find return buffer view");
Expand Down
2 changes: 1 addition & 1 deletion iree_simple_embedding/simple_embedding_test.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
{
%0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "stablehlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
2 changes: 1 addition & 1 deletion third_party/iree
Submodule iree updated from 0eeae4 to c121b8

0 comments on commit d1c3d91

Please sign in to comment.