diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index bb45f92fce..00c3faa2a7 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -379,6 +379,13 @@ def is_valid_file(arg): help="Specify target triple for vulkan", ) +p.add_argument( + "--iree_metal_target_platfrom", + type=str, + default="", + help="Specify target triple for metal", +) + p.add_argument( "--vulkan_debug_utils", default=False, diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 85d943d93c..3d08b530a6 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -18,6 +18,7 @@ set_iree_vulkan_runtime_flags, get_vulkan_target_triple, ) +from shark.iree_utils.metal_utils import get_metal_target_triple from shark.iree_utils.gpu_utils import get_cuda_sm_cc from apps.stable_diffusion.src.utils.stable_args import args from apps.stable_diffusion.src.utils.resources import opt_flags @@ -274,6 +275,16 @@ def set_init_device_flags(): ) elif "cuda" in args.device: args.device = "cuda" + elif "metal" in args.device: + print("\n\n yaha \n\n") + device_name, args.device = map_device_to_name_path(args.device) + if not args.iree_metal_target_platfrom: + triple = get_metal_target_triple(device_name) + if triple is not None: + args.iree_metal_target_platfrom = triple + print( + f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}." + ) elif "cpu" in args.device: args.device = "cpu" @@ -426,6 +437,8 @@ def get_devices_by_name(driver_name): available_devices = [] vulkan_devices = get_devices_by_name("vulkan") available_devices.extend(vulkan_devices) + metal_devices = get_devices_by_name("metal") + available_devices.extend(metal_devices) cuda_devices = get_devices_by_name("cuda") available_devices.extend(cuda_devices) available_devices.append("device => cpu") diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index a42e74b311..7d2086a398 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -34,6 +34,7 @@ # set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. init_iree_vulkan_target_triple = args.iree_vulkan_target_triple +init_iree_metal_target_platfrom = args.iree_metal_target_platfrom init_use_tuned = args.use_tuned init_import_mlir = args.import_mlir @@ -137,6 +138,7 @@ def txt2img_inf( args.width = width args.device = device.split("=>", 1)[1].strip() args.iree_vulkan_target_triple = init_iree_vulkan_target_triple + args.iree_metal_target_platfrom = init_iree_metal_target_platfrom args.use_tuned = init_use_tuned args.import_mlir = init_import_mlir args.img_path = None