diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index 00c3faa2a7..0a905535d3 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -380,7 +380,7 @@ def is_valid_file(arg): ) p.add_argument( - "--iree_metal_target_platfrom", + "--iree_metal_target_platform", type=str, default="", help="Specify target triple for metal", diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 002a2b86a4..6d11f96d08 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -277,12 +277,12 @@ def set_init_device_flags(): args.device = "cuda" elif "metal" in args.device: device_name, args.device = map_device_to_name_path(args.device) - if not args.iree_metal_target_platfrom: + if not args.iree_metal_target_platform: triple = get_metal_target_triple(device_name) if triple is not None: - args.iree_metal_target_platfrom = triple + args.iree_metal_target_platform = triple print( - f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}." + f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}." ) elif "cpu" in args.device: args.device = "cpu" diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 7d2086a398..44e41f1d4c 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -34,7 +34,7 @@ # set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_iree_metal_target_platfrom = args.iree_metal_target_platfrom +init_iree_metal_target_platform = args.iree_metal_target_platform init_use_tuned = args.use_tuned init_import_mlir = args.import_mlir @@ -138,7 +138,7 @@ def txt2img_inf( args.width = width args.device = device.split("=>", 1)[1].strip() args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.iree_metal_target_platfrom = init_iree_metal_target_platfrom + args.iree_metal_target_platform = init_iree_metal_target_platform args.use_tuned = init_use_tuned args.import_mlir = init_import_mlir args.img_path = None diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index 2729af8088..8c79243129 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -101,11 +101,13 @@ def check_device_drivers(device): subprocess.check_output("nvidia-smi") except Exception: return True - elif device in ["metal", "vulkan"]: + elif device in ["vulkan"]: try: subprocess.check_output("vulkaninfo") except Exception: return True + elif device == "metal": + return False elif device in ["intel-gpu"]: try: subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"]) diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index 7f65488f7b..ef6cdfcc6e 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -83,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): print( f"Found metal device {metal_device}. Using metal target triple {triple}" ) - return f"-iree-metal-target-platfrom={triple}" + return f"-iree-metal-target-platform={triple}" print( """Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] @@ -101,16 +101,16 @@ def get_iree_metal_args(device_num=0, extra_args=[]): for arg in extra_args: if "-iree-metal-target-platform=" in arg: print(f"Using target triple {arg} from command line args") - meatal_triple_flag = arg + metal_triple_flag = arg break if metal_triple_flag is None: - meatal_triple_flag = get_metal_triple_flag( + metal_triple_flag = get_metal_triple_flag( device_num=device_num, extra_args=extra_args ) - if meatal_triple_flag is not None: - vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag) + if metal_triple_flag is not None: + vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag) res_metal_flag.append(vulkan_target_env) return res_metal_flag