From 9cd6da4593ad3325dd22153a25ca451504498aa1 Mon Sep 17 00:00:00 2001 From: PhaneeshB Date: Fri, 8 Dec 2023 06:23:30 +0530 Subject: [PATCH] vulkan device id fix --- apps/language_models/scripts/vicuna.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 544c3fdea0..548dcb8344 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1616,6 +1616,10 @@ def get_model_path(self, suffix="mlir"): self.vulkan_target_triple.split("-")[:-1] ) differentiator = target_triple + else: + from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag + tt = get_vulkan_triple_flag(device_num=self.device_id) + differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1]) elif "rocm" == self.device: from shark.iree_utils.gpu_utils import get_rocm_device_arch @@ -2355,6 +2359,10 @@ def avg_and_stdev(data): break id += 1 + if "://" in device : + from shark.iree_utils.compile_utils import clean_device_info + _, device_id = clean_device_info(args.device) + assert ( device_id ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"