From 959be51b0d97bb8a6be42b4266ab40a47c77cd33 Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 10:41:24 -0700 Subject: [PATCH 1/9] Adding metal_utils for iree_utils --- shark/iree_utils/_common.py | 4 +- shark/iree_utils/compile_utils.py | 8 +- shark/iree_utils/metal_utils.py | 176 ++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 3 deletions(-) create mode 100644 shark/iree_utils/metal_utils.py diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index e090a097be..97c4a3f633 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -66,7 +66,7 @@ def get_supported_device_list(): "cpu-sync": "local-sync", "cuda": "cuda", "vulkan": "vulkan", - "metal": "vulkan", + "metal": "metal", "rocm": "rocm", "intel-gpu": "level_zero", } @@ -84,7 +84,7 @@ def iree_target_map(device): "cpu-sync": "llvm-cpu", "cuda": "cuda", "vulkan": "vulkan", - "metal": "vulkan", + "metal": "metal", "rocm": "rocm", "intel-gpu": "opencl-spirv", } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 391129961f..dd66641843 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -43,12 +43,18 @@ def get_iree_device_args(device, extra_args=[]): from shark.iree_utils.gpu_utils import get_iree_gpu_args return get_iree_gpu_args() - if device_uri[0] in ["metal", "vulkan"]: + if device_uri[0] == "vulkan": from shark.iree_utils.vulkan_utils import get_iree_vulkan_args return get_iree_vulkan_args( device_num=device_num, extra_args=extra_args ) + if device_uri[0] == "metal": + from shark.iree_utils.metal_utils import get_iree_metal_args + + return get_iree_metal_args( + device_num=device_num, extra_args=extra_args + ) if device_uri[0] == "rocm": from shark.iree_utils.gpu_utils import get_iree_rocm_args diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py new file mode 100644 index 0000000000..fb2bf86023 --- /dev/null +++ b/shark/iree_utils/metal_utils.py @@ -0,0 +1,176 @@ +# Copyright 2020 The Nod Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# All the iree_vulkan related functionalities go here. + +from os import linesep +from shark.iree_utils._common import run_cmd +import iree.runtime as ireert +from sys import platform +from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag + + +def get_metal_device_name(device_num=0): + vulkaninfo_dump, _ = run_cmd("vulkaninfo") + vulkaninfo_dump = vulkaninfo_dump.split(linesep) + vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s] + if len(vulkaninfo_list) == 0: + raise ValueError("No device name found in VulkanInfo!") + if len(vulkaninfo_list) > 1: + print("Following devices found:") + for i, dname in enumerate(vulkaninfo_list): + print(f"{i}. {dname}") + print(f"Choosing device: {vulkaninfo_list[device_num]}") + return vulkaninfo_list[device_num] + + +def get_os_name(): + if platform.startswith("linux"): + return "linux" + elif platform == "darwin": + return "macos" + elif platform == "win32": + return "windows" + else: + print("Cannot detect OS type, defaulting to linux.") + return "linux" + + +def get_metal_target_triple(device_name): + """This method provides a target triple str for specified vulkan device. + + Args: + device_name (str): name of the hardware device to be used with vulkan + + Returns: + str or None: target triple or None if no match found for given name + """ + system_os = get_os_name() + # Apple Targets + if all(x in device_name for x in ("Apple", "M1")): + triple = "m1-moltenvk-macos" + elif all(x in device_name for x in ("Apple", "M2")): + triple = "m1-moltenvk-macos" + + # Nvidia Targets + elif all(x in device_name for x in ("RTX", "2080")): + triple = f"turing-rtx2080-{system_os}" + elif all(x in device_name for x in ("A100", "SXM4")): + triple = f"ampere-a100-{system_os}" + elif all(x in device_name for x in ("RTX", "3090")): + triple = f"ampere-rtx3090-{system_os}" + elif all(x in device_name for x in ("RTX", "3080")): + triple = f"ampere-rtx3080-{system_os}" + elif all(x in device_name for x in ("RTX", "3070")): + triple = f"ampere-rtx3070-{system_os}" + elif all(x in device_name for x in ("RTX", "3060")): + triple = f"ampere-rtx3060-{system_os}" + elif all(x in device_name for x in ("RTX", "3050")): + triple = f"ampere-rtx3050-{system_os}" + # We use ampere until lovelace target triples are plumbed in. + elif all(x in device_name for x in ("RTX", "4090")): + triple = f"ampere-rtx4090-{system_os}" + elif all(x in device_name for x in ("RTX", "4080")): + triple = f"ampere-rtx4080-{system_os}" + elif all(x in device_name for x in ("RTX", "4070")): + triple = f"ampere-rtx4070-{system_os}" + elif all(x in device_name for x in ("RTX", "4000")): + triple = f"turing-rtx4000-{system_os}" + elif all(x in device_name for x in ("RTX", "5000")): + triple = f"turing-rtx5000-{system_os}" + elif all(x in device_name for x in ("RTX", "6000")): + triple = f"turing-rtx6000-{system_os}" + elif all(x in device_name for x in ("RTX", "8000")): + triple = f"turing-rtx8000-{system_os}" + elif all(x in device_name for x in ("TITAN", "RTX")): + triple = f"turing-titanrtx-{system_os}" + elif all(x in device_name for x in ("GTX", "1060")): + triple = f"pascal-gtx1060-{system_os}" + elif all(x in device_name for x in ("GTX", "1070")): + triple = f"pascal-gtx1070-{system_os}" + elif all(x in device_name for x in ("GTX", "1080")): + triple = f"pascal-gtx1080-{system_os}" + + # Amd Targets + # Linux: Radeon RX 7900 XTX + # Windows: AMD Radeon RX 7900 XTX + elif all(x in device_name for x in ("RX", "7900")): + triple = f"rdna3-7900-{system_os}" + elif all(x in device_name for x in ("AMD", "PRO", "W7900")): + triple = f"rdna3-w7900-{system_os}" + elif any(x in device_name for x in ("AMD", "Radeon")): + triple = f"rdna2-unknown-{system_os}" + # Intel Targets + elif any(x in device_name for x in ("A770", "A750")): + triple = f"arc-770-{system_os}" + + # Adreno Targets + elif all(x in device_name for x in ("Adreno", "740")): + triple = f"adreno-a740-{system_os}" + + else: + triple = None + return triple + + +def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): + for flag in extra_args: + if "-iree-metal-target-triple=" in flag: + print(f"Using target triple {flag.split('=')[1]}") + return None + + if device_name == "" or device_name == [] or device_name is None: + metal_device = get_metal_device_name(device_num=device_num) + else: + metal_device = device_name + triple = get_metal_target_triple(metal_device) + if triple is not None: + print( + f"Found metal device {metal_device}. Using metal target triple {triple}" + ) + return f"-iree-metal-target-triple={triple}" + print( + """Optimized kernel for your target device is not added yet. + Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] + or pull up an issue.""" + ) + print(f"Target : {metal_device}") + return None + +def get_iree_metal_args(device_num=0, extra_args=[]): + # res_metal_flag = ["--iree-flow-demote-i64-to-i32"] + + res_metal_flag = [] + metal_triple_flag = None + for arg in extra_args: + if "-iree-metal-target-triple=" in arg: + print(f"Using target triple {arg} from command line args") + meatal_triple_flag = arg + break + + if metal_triple_flag is None: + meatal_triple_flag = get_metal_triple_flag( + device_num=device_num, extra_args=extra_args + ) + + if meatal_triple_flag is not None: + vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag) + res_metal_flag.append(vulkan_target_env) + return res_metal_flag + + +def set_iree_metal_runtime_flags(flags): + for flag in flags: + ireert.flags.parse_flags(flag) + return From 9efacaa50cee866685b20f2fc7d58dc679f98274 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 20 Jun 2023 22:34:17 +0530 Subject: [PATCH 2/9] Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559) -- It also modifies the mega_test.py script Signed-off-by: Abhishek Varma --- apps/stable_diffusion/src/utils/utils.py | 3 ++- shark/examples/shark_inference/mega_test.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 54199f76a1..85d943d93c 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -116,6 +116,7 @@ def compile_through_fx( model_name=None, precision=None, return_mlir=False, + device=None, ): if not return_mlir and model_name is not None: vmfb_path = get_vmfb_path_name(extended_model_name) @@ -157,7 +158,7 @@ def compile_through_fx( shark_module = SharkInference( mlir_module, - device=args.device, + device=args.device if device is None else device, mlir_dialect="tm_tensor", ) if generate_vmfb: diff --git a/shark/examples/shark_inference/mega_test.py b/shark/examples/shark_inference/mega_test.py index de0cc2425b..efc5e70b79 100644 --- a/shark/examples/shark_inference/mega_test.py +++ b/shark/examples/shark_inference/mega_test.py @@ -52,14 +52,22 @@ def forward(self, input): base_model_id=None, model_name="mega_shark", precision=None, - return_mlir=False, + return_mlir=True, device="cuda", ) # logits = model(x) + +def print_output_info(output, msg): + print("\n", msg) + print("\n\t", output.shape) + + ans = shark_module("forward", input) -print(type(ans)) -print("Logits : ", ans.shape) +print_output_info(torch.from_numpy(ans), "SHARK's output") + +ans = megaModel.forward(*input) +print_output_info(ans, "ORIGINAL Model's output") # and sample from the logits accordingly # or you can use the generate function From d49b15716eb4bdcafdc05ee1d951980a8387fa4b Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 20 Jun 2023 12:26:36 -0500 Subject: [PATCH 3/9] [SD] Update unet in_channels API and add PIL metadata to spec. (#1560) * Fix deprecation warning for unet config. * Include PIL metadata instead of hidden imports in SD spec. --- apps/stable_diffusion/shark_sd.spec | 2 +- apps/stable_diffusion/src/models/model_wrappers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec index ba089f7ff0..e51ca30553 100644 --- a/apps/stable_diffusion/shark_sd.spec +++ b/apps/stable_diffusion/shark_sd.spec @@ -19,6 +19,7 @@ datas += copy_metadata('importlib_metadata') datas += copy_metadata('torch-mlir') datas += copy_metadata('omegaconf') datas += copy_metadata('safetensors') +datas += copy_metadata('Pillow') datas += collect_data_files('diffusers') datas += collect_data_files('transformers') datas += collect_data_files('pytorch_lightning') @@ -48,7 +49,6 @@ block_cipher = None hiddenimports = ['shark', 'shark.shark_inference', 'apps'] hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("PIL") if "tests" not in x] a = Analysis( ['web/index.py'], diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 752c9fe1e3..93e416d740 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -426,7 +426,7 @@ def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=sel ) if use_lora != "": update_lora_weight(self.unet, use_lora, "unet") - self.in_channels = self.unet.in_channels + self.in_channels = self.unet.config.in_channels self.train(False) if(args.attention_slicing is not None and args.attention_slicing != "none"): if(args.attention_slicing.isdigit()): From 948cef51c70e5d1c1e3a6b4abc736d9b05e1689c Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 11:51:36 -0700 Subject: [PATCH 4/9] Fixing iree-metal-target-platform --- shark/iree_utils/metal_utils.py | 63 ++------------------------------- 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index fb2bf86023..0bd27d71dc 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Nod Team. All rights reserved. +# Copyright 2023 The Nod Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,69 +56,12 @@ def get_metal_target_triple(device_name): Returns: str or None: target triple or None if no match found for given name """ - system_os = get_os_name() # Apple Targets if all(x in device_name for x in ("Apple", "M1")): triple = "m1-moltenvk-macos" elif all(x in device_name for x in ("Apple", "M2")): triple = "m1-moltenvk-macos" - # Nvidia Targets - elif all(x in device_name for x in ("RTX", "2080")): - triple = f"turing-rtx2080-{system_os}" - elif all(x in device_name for x in ("A100", "SXM4")): - triple = f"ampere-a100-{system_os}" - elif all(x in device_name for x in ("RTX", "3090")): - triple = f"ampere-rtx3090-{system_os}" - elif all(x in device_name for x in ("RTX", "3080")): - triple = f"ampere-rtx3080-{system_os}" - elif all(x in device_name for x in ("RTX", "3070")): - triple = f"ampere-rtx3070-{system_os}" - elif all(x in device_name for x in ("RTX", "3060")): - triple = f"ampere-rtx3060-{system_os}" - elif all(x in device_name for x in ("RTX", "3050")): - triple = f"ampere-rtx3050-{system_os}" - # We use ampere until lovelace target triples are plumbed in. - elif all(x in device_name for x in ("RTX", "4090")): - triple = f"ampere-rtx4090-{system_os}" - elif all(x in device_name for x in ("RTX", "4080")): - triple = f"ampere-rtx4080-{system_os}" - elif all(x in device_name for x in ("RTX", "4070")): - triple = f"ampere-rtx4070-{system_os}" - elif all(x in device_name for x in ("RTX", "4000")): - triple = f"turing-rtx4000-{system_os}" - elif all(x in device_name for x in ("RTX", "5000")): - triple = f"turing-rtx5000-{system_os}" - elif all(x in device_name for x in ("RTX", "6000")): - triple = f"turing-rtx6000-{system_os}" - elif all(x in device_name for x in ("RTX", "8000")): - triple = f"turing-rtx8000-{system_os}" - elif all(x in device_name for x in ("TITAN", "RTX")): - triple = f"turing-titanrtx-{system_os}" - elif all(x in device_name for x in ("GTX", "1060")): - triple = f"pascal-gtx1060-{system_os}" - elif all(x in device_name for x in ("GTX", "1070")): - triple = f"pascal-gtx1070-{system_os}" - elif all(x in device_name for x in ("GTX", "1080")): - triple = f"pascal-gtx1080-{system_os}" - - # Amd Targets - # Linux: Radeon RX 7900 XTX - # Windows: AMD Radeon RX 7900 XTX - elif all(x in device_name for x in ("RX", "7900")): - triple = f"rdna3-7900-{system_os}" - elif all(x in device_name for x in ("AMD", "PRO", "W7900")): - triple = f"rdna3-w7900-{system_os}" - elif any(x in device_name for x in ("AMD", "Radeon")): - triple = f"rdna2-unknown-{system_os}" - # Intel Targets - elif any(x in device_name for x in ("A770", "A750")): - triple = f"arc-770-{system_os}" - - # Adreno Targets - elif all(x in device_name for x in ("Adreno", "740")): - triple = f"adreno-a740-{system_os}" - else: triple = None return triple @@ -126,7 +69,7 @@ def get_metal_target_triple(device_name): def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): for flag in extra_args: - if "-iree-metal-target-triple=" in flag: + if "-iree-metal-target-platform=" in flag: print(f"Using target triple {flag.split('=')[1]}") return None @@ -154,7 +97,7 @@ def get_iree_metal_args(device_num=0, extra_args=[]): res_metal_flag = [] metal_triple_flag = None for arg in extra_args: - if "-iree-metal-target-triple=" in arg: + if "-iree-metal-target-platform=" in arg: print(f"Using target triple {arg} from command line args") meatal_triple_flag = arg break From 51f98b953f9e5d7b9f570fe70d2fc9a6e41d2569 Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 15:34:07 -0700 Subject: [PATCH 5/9] adding metal to txt2img pipeline --- apps/stable_diffusion/src/utils/stable_args.py | 7 +++++++ apps/stable_diffusion/src/utils/utils.py | 13 +++++++++++++ apps/stable_diffusion/web/ui/txt2img_ui.py | 2 ++ 3 files changed, 22 insertions(+) diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index bb45f92fce..00c3faa2a7 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -379,6 +379,13 @@ def is_valid_file(arg): help="Specify target triple for vulkan", ) +p.add_argument( + "--iree_metal_target_platfrom", + type=str, + default="", + help="Specify target triple for metal", +) + p.add_argument( "--vulkan_debug_utils", default=False, diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 85d943d93c..3d08b530a6 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -18,6 +18,7 @@ set_iree_vulkan_runtime_flags, get_vulkan_target_triple, ) +from shark.iree_utils.metal_utils import get_metal_target_triple from shark.iree_utils.gpu_utils import get_cuda_sm_cc from apps.stable_diffusion.src.utils.stable_args import args from apps.stable_diffusion.src.utils.resources import opt_flags @@ -274,6 +275,16 @@ def set_init_device_flags(): ) elif "cuda" in args.device: args.device = "cuda" + elif "metal" in args.device: + print("\n\n yaha \n\n") + device_name, args.device = map_device_to_name_path(args.device) + if not args.iree_metal_target_platfrom: + triple = get_metal_target_triple(device_name) + if triple is not None: + args.iree_metal_target_platfrom = triple + print( + f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}." + ) elif "cpu" in args.device: args.device = "cpu" @@ -426,6 +437,8 @@ def get_devices_by_name(driver_name): available_devices = [] vulkan_devices = get_devices_by_name("vulkan") available_devices.extend(vulkan_devices) + metal_devices = get_devices_by_name("metal") + available_devices.extend(metal_devices) cuda_devices = get_devices_by_name("cuda") available_devices.extend(cuda_devices) available_devices.append("device => cpu") diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index a42e74b311..7d2086a398 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -34,6 +34,7 @@ # set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. init_iree_vulkan_target_triple = args.iree_vulkan_target_triple +init_iree_metal_target_platfrom = args.iree_metal_target_platfrom init_use_tuned = args.use_tuned init_import_mlir = args.import_mlir @@ -137,6 +138,7 @@ def txt2img_inf( args.width = width args.device = device.split("=>", 1)[1].strip() args.iree_vulkan_target_triple = init_iree_vulkan_target_triple + args.iree_metal_target_platfrom = init_iree_metal_target_platfrom args.use_tuned = init_use_tuned args.import_mlir = init_import_mlir args.img_path = None From 2cf12eaae90966bda1ee8878765e34c430afc7dd Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 15:50:37 -0700 Subject: [PATCH 6/9] Fixing Copyright date --- shark/iree_utils/_common.py | 2 +- shark/iree_utils/compile_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index 97c4a3f633..2729af8088 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Nod Team. All rights reserved. +# Copyright 2023 The Nod Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index dd66641843..9d6fc658f6 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Nod Team. All rights reserved. +# Copyright 2023 The Nod Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 89d3c17c587929abd5fdb19086c157cfd75e7ce0 Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 17:35:40 -0700 Subject: [PATCH 7/9] removing debug prints --- apps/stable_diffusion/src/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 3d08b530a6..d53a3105bf 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -276,7 +276,6 @@ def set_init_device_flags(): elif "cuda" in args.device: args.device = "cuda" elif "metal" in args.device: - print("\n\n yaha \n\n") device_name, args.device = map_device_to_name_path(args.device) if not args.iree_metal_target_platfrom: triple = get_metal_target_triple(device_name) From 6fa8404aa6534b7d9f84221e122251322e2aebbd Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 21:45:31 -0700 Subject: [PATCH 8/9] black lint formating --- shark/iree_utils/compile_utils.py | 2 +- shark/iree_utils/metal_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 9d6fc658f6..8ddb530b38 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -51,7 +51,7 @@ def get_iree_device_args(device, extra_args=[]): ) if device_uri[0] == "metal": from shark.iree_utils.metal_utils import get_iree_metal_args - + return get_iree_metal_args( device_num=device_num, extra_args=extra_args ) diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index 0bd27d71dc..dbc400edf4 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -91,6 +91,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): print(f"Target : {metal_device}") return None + def get_iree_metal_args(device_num=0, extra_args=[]): # res_metal_flag = ["--iree-flow-demote-i64-to-i32"] @@ -101,7 +102,7 @@ def get_iree_metal_args(device_num=0, extra_args=[]): print(f"Using target triple {arg} from command line args") meatal_triple_flag = arg break - + if metal_triple_flag is None: meatal_triple_flag = get_metal_triple_flag( device_num=device_num, extra_args=extra_args From 88753b975a24d6837f00031a22d77cff9b72879d Mon Sep 17 00:00:00 2001 From: ranivrsv Date: Tue, 20 Jun 2023 23:14:51 -0700 Subject: [PATCH 9/9] fixing device dump --- shark/iree_utils/metal_utils.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index dbc400edf4..7f65488f7b 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -14,7 +14,6 @@ # All the iree_vulkan related functionalities go here. -from os import linesep from shark.iree_utils._common import run_cmd import iree.runtime as ireert from sys import platform @@ -22,17 +21,19 @@ def get_metal_device_name(device_num=0): - vulkaninfo_dump, _ = run_cmd("vulkaninfo") - vulkaninfo_dump = vulkaninfo_dump.split(linesep) - vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s] - if len(vulkaninfo_list) == 0: - raise ValueError("No device name found in VulkanInfo!") - if len(vulkaninfo_list) > 1: + iree_device_dump = run_cmd("iree-run-module --dump_devices") + iree_device_dump = iree_device_dump[0].split("\n\n") + metal_device_list = [ + s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s + ] + if len(metal_device_list) == 0: + raise ValueError("No device name found in device dump!") + if len(metal_device_list) > 1: print("Following devices found:") - for i, dname in enumerate(vulkaninfo_list): + for i, dname in enumerate(metal_device_list): print(f"{i}. {dname}") - print(f"Choosing device: {vulkaninfo_list[device_num]}") - return vulkaninfo_list[device_num] + print(f"Choosing device: {metal_device_list[device_num]}") + return metal_device_list[device_num] def get_os_name(): @@ -82,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]): print( f"Found metal device {metal_device}. Using metal target triple {triple}" ) - return f"-iree-metal-target-triple={triple}" + return f"-iree-metal-target-platfrom={triple}" print( """Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]