Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored May 23, 2024
1 parent 17b98e7 commit 04b5897
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from cpuinfo import get_cpu_info

# TODO: migrate these utils to studio

from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)

def get_available_devices():
def get_devices_by_name(driver_name):
Expand Down Expand Up @@ -44,6 +48,8 @@ def get_devices_by_name(driver_name):
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list

set_iree_runtime_flags()

available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
Expand Down Expand Up @@ -72,6 +78,53 @@ def get_devices_by_name(driver_name):
available_devices.extend(cpu_device)
return available_devices

def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()

# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple

triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"


def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)

def parse_device(device_str):
from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map
Expand Down

0 comments on commit 04b5897

Please sign in to comment.