Skip to content

Commit

Permalink
Metal testing (#1595)
Browse files Browse the repository at this point in the history
* Fixing metal_platform and device selection

* fixing for metal platform

* fixed for black lint formating
  • Loading branch information
Ranvirsv committed Jul 8, 2023
1 parent 788d469 commit 9fcae4f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 18 deletions.
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def set_init_device_flags():
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platform = triple
args.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{args.iree_metal_target_platform}."
Expand Down
4 changes: 1 addition & 3 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args

return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

Expand Down
20 changes: 6 additions & 14 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,7 @@ def get_metal_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

else:
triple = None
return triple
return "macos"


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
Expand All @@ -81,7 +73,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
f"Found metal device {metal_device}. Using metal target platform {triple}"
)
return f"-iree-metal-target-platform={triple}"
print(
Expand All @@ -105,12 +97,12 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
break

if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)

if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag

Expand Down

0 comments on commit 9fcae4f

Please sign in to comment.