Skip to content

Commit

Permalink
adding metal to txt2img pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Ranvirsv committed Jun 20, 2023
1 parent 948cef5 commit 51f98b9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
7 changes: 7 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ def is_valid_file(arg):
help="Specify target triple for vulkan",
)

p.add_argument(
"--iree_metal_target_platfrom",
type=str,
default="",
help="Specify target triple for metal",
)

p.add_argument(
"--vulkan_debug_utils",
default=False,
Expand Down
13 changes: 13 additions & 0 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
Expand Down Expand Up @@ -274,6 +275,16 @@ def set_init_device_flags():
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
print("\n\n yaha \n\n")
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platfrom:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platfrom = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
)
elif "cpu" in args.device:
args.device = "cpu"

Expand Down Expand Up @@ -426,6 +437,8 @@ def get_devices_by_name(driver_name):
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
Expand Down
2 changes: 2 additions & 0 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir

Expand Down Expand Up @@ -137,6 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
Expand Down

0 comments on commit 51f98b9

Please sign in to comment.