diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index dbc400edf4..7f65488f7b 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -14,7 +14,6 @@ # All the iree_vulkan related functionalities go here. -from os import linesep from shark.iree_utils._common import run_cmd import iree.runtime as ireert from sys import platform @@ -22,17 +21,19 @@ def get_metal_device_name(device_num=0): - vulkaninfo_dump, _ = run_cmd("vulkaninfo") - vulkaninfo_dump = vulkaninfo_dump.split(linesep) - vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s] - if len(vulkaninfo_list) == 0: - raise ValueError("No device name found in VulkanInfo!") - if len(vulkaninfo_list) > 1: + iree_device_dump = run_cmd("iree-run-module --dump_devices") + iree_device_dump = iree_device_dump[0].split("\n\n") + metal_device_list = [ + s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s + ] + if len(metal_device_list) == 0: + raise ValueError("No device name found in device dump!") + if len(metal_device_list) > 1: print("Following devices found:") - for i, dname in enumerate(vulkaninfo_list): + for i, dname in enumerate(metal_device_list): print(f"{i}. {dname}") - print(f"Choosing device: {vulkaninfo_list[device_num]}") - return vulkaninfo_list[device_num] + print(f"Choosing device: {metal_device_list[device_num]}") + return metal_device_list[device_num] def get_os_name(): @@ -82,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): print( f"Found metal device {metal_device}. Using metal target triple {triple}" ) - return f"-iree-metal-target-triple={triple}" + return f"-iree-metal-target-platfrom={triple}" print( """Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]