Skip to content

Commit

Permalink
vulkan device id fix (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhaneeshB committed Dec 9, 2023
1 parent 7159698 commit bf70e80
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,10 @@ def get_model_path(self, suffix="mlir"):
self.vulkan_target_triple.split("-")[:-1]
)
differentiator = target_triple
else:
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
tt = get_vulkan_triple_flag(device_num=self.device_id)
differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1])

elif "rocm" == self.device:
from shark.iree_utils.gpu_utils import get_rocm_device_arch
Expand Down Expand Up @@ -2355,6 +2359,10 @@ def avg_and_stdev(data):
break
id += 1

if "://" in device :
from shark.iree_utils.compile_utils import clean_device_info
_, device_id = clean_device_info(args.device)

assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
Expand Down

0 comments on commit bf70e80

Please sign in to comment.