Skip to content

Commit

Permalink
fixing for metal platform
Browse files Browse the repository at this point in the history
  • Loading branch information
Ranvirsv committed Jun 26, 2023
1 parent 009ff49 commit f52feb2
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,7 @@ def get_metal_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

else:
triple = None
return triple
return "macos"


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
Expand All @@ -81,7 +73,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
f"Found metal device {metal_device}. Using metal target platform {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
Expand All @@ -105,12 +97,10 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
break

if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)

if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
vulkan_target_env = get_vulkan_target_env_flag("-iree-vulkan-target-triple=m1-moltenvk-macos")
res_metal_flag.append(vulkan_target_env)
return res_metal_flag

Expand Down

0 comments on commit f52feb2

Please sign in to comment.