diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index bb45f92fce..00c3faa2a7 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -379,6 +379,13 @@ def is_valid_file(arg): help="Specify target triple for vulkan", ) +p.add_argument( + "--iree_metal_target_platfrom", + type=str, + default="", + help="Specify target triple for metal", +) + p.add_argument( "--vulkan_debug_utils", default=False, diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index aca9169a9c..2ca9b9aef4 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -18,6 +18,7 @@ set_iree_vulkan_runtime_flags, get_vulkan_target_triple, ) +from shark.iree_utils.metal_utils import get_metal_target_triple from shark.iree_utils.gpu_utils import get_cuda_sm_cc from apps.stable_diffusion.src.utils.stable_args import args from apps.stable_diffusion.src.utils.resources import opt_flags @@ -274,6 +275,15 @@ def set_init_device_flags(): ) elif "cuda" in args.device: args.device = "cuda" + elif "metal" in args.device: + device_name, args.device = map_device_to_name_path(args.device) + if not args.iree_metal_target_platfrom: + triple = get_metal_target_triple(device_name) + if triple is not None: + args.iree_metal_target_platfrom = triple + print( + f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}." + ) elif "cpu" in args.device: args.device = "cpu" @@ -431,6 +441,8 @@ def get_devices_by_name(driver_name): available_devices = [] vulkan_devices = get_devices_by_name("vulkan") available_devices.extend(vulkan_devices) + metal_devices = get_devices_by_name("metal") + available_devices.extend(metal_devices) cuda_devices = get_devices_by_name("cuda") available_devices.extend(cuda_devices) available_devices.append("device => cpu") diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index a42e74b311..7d2086a398 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -34,6 +34,7 @@ # set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. init_iree_vulkan_target_triple = args.iree_vulkan_target_triple +init_iree_metal_target_platfrom = args.iree_metal_target_platfrom init_use_tuned = args.use_tuned init_import_mlir = args.import_mlir @@ -137,6 +138,7 @@ def txt2img_inf( args.width = width args.device = device.split("=>", 1)[1].strip() args.iree_vulkan_target_triple = init_iree_vulkan_target_triple + args.iree_metal_target_platfrom = init_iree_metal_target_platfrom args.use_tuned = init_use_tuned args.import_mlir = init_import_mlir args.img_path = None diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index e090a097be..2729af8088 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Nod Team. All rights reserved. +# Copyright 2023 The Nod Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ def get_supported_device_list(): "cpu-sync": "local-sync", "cuda": "cuda", "vulkan": "vulkan", - "metal": "vulkan", + "metal": "metal", "rocm": "rocm", "intel-gpu": "level_zero", } @@ -84,7 +84,7 @@ def iree_target_map(device): "cpu-sync": "llvm-cpu", "cuda": "cuda", "vulkan": "vulkan", - "metal": "vulkan", + "metal": "metal", "rocm": "rocm", "intel-gpu": "opencl-spirv", } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 391129961f..8ddb530b38 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Nod Team. All rights reserved. +# Copyright 2023 The Nod Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,12 +43,18 @@ def get_iree_device_args(device, extra_args=[]): from shark.iree_utils.gpu_utils import get_iree_gpu_args return get_iree_gpu_args() - if device_uri[0] in ["metal", "vulkan"]: + if device_uri[0] == "vulkan": from shark.iree_utils.vulkan_utils import get_iree_vulkan_args return get_iree_vulkan_args( device_num=device_num, extra_args=extra_args ) + if device_uri[0] == "metal": + from shark.iree_utils.metal_utils import get_iree_metal_args + + return get_iree_metal_args( + device_num=device_num, extra_args=extra_args + ) if device_uri[0] == "rocm": from shark.iree_utils.gpu_utils import get_iree_rocm_args diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py new file mode 100644 index 0000000000..7f65488f7b --- /dev/null +++ b/shark/iree_utils/metal_utils.py @@ -0,0 +1,121 @@ +# Copyright 2023 The Nod Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# All the iree_vulkan related functionalities go here. + +from shark.iree_utils._common import run_cmd +import iree.runtime as ireert +from sys import platform +from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag + + +def get_metal_device_name(device_num=0): + iree_device_dump = run_cmd("iree-run-module --dump_devices") + iree_device_dump = iree_device_dump[0].split("\n\n") + metal_device_list = [ + s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s + ] + if len(metal_device_list) == 0: + raise ValueError("No device name found in device dump!") + if len(metal_device_list) > 1: + print("Following devices found:") + for i, dname in enumerate(metal_device_list): + print(f"{i}. {dname}") + print(f"Choosing device: {metal_device_list[device_num]}") + return metal_device_list[device_num] + + +def get_os_name(): + if platform.startswith("linux"): + return "linux" + elif platform == "darwin": + return "macos" + elif platform == "win32": + return "windows" + else: + print("Cannot detect OS type, defaulting to linux.") + return "linux" + + +def get_metal_target_triple(device_name): + """This method provides a target triple str for specified vulkan device. + + Args: + device_name (str): name of the hardware device to be used with vulkan + + Returns: + str or None: target triple or None if no match found for given name + """ + # Apple Targets + if all(x in device_name for x in ("Apple", "M1")): + triple = "m1-moltenvk-macos" + elif all(x in device_name for x in ("Apple", "M2")): + triple = "m1-moltenvk-macos" + + else: + triple = None + return triple + + +def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): + for flag in extra_args: + if "-iree-metal-target-platform=" in flag: + print(f"Using target triple {flag.split('=')[1]}") + return None + + if device_name == "" or device_name == [] or device_name is None: + metal_device = get_metal_device_name(device_num=device_num) + else: + metal_device = device_name + triple = get_metal_target_triple(metal_device) + if triple is not None: + print( + f"Found metal device {metal_device}. Using metal target triple {triple}" + ) + return f"-iree-metal-target-platfrom={triple}" + print( + """Optimized kernel for your target device is not added yet. + Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] + or pull up an issue.""" + ) + print(f"Target : {metal_device}") + return None + + +def get_iree_metal_args(device_num=0, extra_args=[]): + # res_metal_flag = ["--iree-flow-demote-i64-to-i32"] + + res_metal_flag = [] + metal_triple_flag = None + for arg in extra_args: + if "-iree-metal-target-platform=" in arg: + print(f"Using target triple {arg} from command line args") + meatal_triple_flag = arg + break + + if metal_triple_flag is None: + meatal_triple_flag = get_metal_triple_flag( + device_num=device_num, extra_args=extra_args + ) + + if meatal_triple_flag is not None: + vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag) + res_metal_flag.append(vulkan_target_env) + return res_metal_flag + + +def set_iree_metal_runtime_flags(flags): + for flag in flags: + ireert.flags.parse_flags(flag) + return