diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 7241bf3691..f213320045 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -12,7 +12,11 @@ from cpuinfo import get_cpu_info # TODO: migrate these utils to studio - +from shark.iree_utils.vulkan_utils import ( + set_iree_vulkan_runtime_flags, + get_vulkan_target_triple, + get_iree_vulkan_runtime_flags, +) def get_available_devices(): def get_devices_by_name(driver_name): @@ -44,6 +48,8 @@ def get_devices_by_name(driver_name): device_list.append(f"{device_name} => {driver_name}://{i}") return device_list + set_iree_runtime_flags() + available_devices = [] from shark.iree_utils.vulkan_utils import ( get_all_vulkan_devices, @@ -72,6 +78,53 @@ def get_devices_by_name(driver_name): available_devices.extend(cpu_device) return available_devices +def set_init_device_flags(): + if "vulkan" in cmd_opts.device: + # set runtime flags for vulkan. + set_iree_runtime_flags() + + # set triple flag to avoid multiple calls to get_vulkan_triple_flag + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_vulkan_target_triple: + triple = get_vulkan_target_triple(device_name) + if triple is not None: + cmd_opts.iree_vulkan_target_triple = triple + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_vulkan_target_triple}." + ) + elif "cuda" in cmd_opts.device: + cmd_opts.device = "cuda" + elif "metal" in cmd_opts.device: + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_metal_target_platform: + from shark.iree_utils.metal_utils import get_metal_target_triple + + triple = get_metal_target_triple(device_name) + if triple is not None: + cmd_opts.iree_metal_target_platform = triple.split("-")[-1] + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_metal_target_platform}." + ) + elif "cpu" in cmd_opts.device: + cmd_opts.device = "cpu" + + +def set_iree_runtime_flags(): + # TODO: This function should be device-agnostic and piped properly + # to general runtime driver init. + vulkan_runtime_flags = get_iree_vulkan_runtime_flags() + if cmd_opts.enable_rgp: + vulkan_runtime_flags += [ + f"--enable_rgp=true", + f"--vulkan_debug_utils=true", + ] + if cmd_opts.device_allocator_heap_key: + vulkan_runtime_flags += [ + f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", + ] + set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) def parse_device(device_str): from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map