Skip to content

Commit

Permalink
Metal typo fix (#1572)
Browse files Browse the repository at this point in the history
* fixing typos for metal changes

* black formating
  • Loading branch information
Ranvirsv authored Jun 22, 2023
1 parent a202bb4 commit 18c8e9e
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def is_valid_file(arg):
)

p.add_argument(
"--iree_metal_target_platfrom",
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal",
Expand Down
6 changes: 3 additions & 3 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,12 @@ def set_init_device_flags():
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platfrom:
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platfrom = triple
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
Expand Down
4 changes: 2 additions & 2 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir

Expand Down Expand Up @@ -138,7 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
Expand Down
4 changes: 3 additions & 1 deletion shark/iree_utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ def check_device_drivers(device):
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["metal", "vulkan"]:
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
Expand Down
10 changes: 5 additions & 5 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platfrom={triple}"
return f"-iree-metal-target-platform={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
Expand All @@ -101,16 +101,16 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
meatal_triple_flag = arg
metal_triple_flag = arg
break

if metal_triple_flag is None:
meatal_triple_flag = get_metal_triple_flag(
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)

if meatal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag

Expand Down

0 comments on commit 18c8e9e

Please sign in to comment.