diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 73c0417bc7..8aacf18df5 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -110,6 +110,110 @@ def chat_fn( history[-1][-1] = "Getting the model ready... Done" yield history, "" history[-1][-1] = "" + if "cuda" in device: + device = "cuda" + elif "sync" in device: + device = "cpu-sync" + elif "task" in device: + device = "cpu-task" + elif "vulkan" in device: + device_id = int(device.split("://")[1]) + device = "vulkan" + elif "rocm" in device: + device = "rocm" + else: + print("unrecognized device") + + from apps.language_models.scripts.vicuna import ShardedVicuna + from apps.language_models.scripts.vicuna import UnshardedVicuna + from apps.stable_diffusion.src import args + + new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" + if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: + model_vmfb_key = new_model_vmfb_key + max_toks = 128 if model_name == "codegen" else 512 + + # get iree flags that need to be overridden, from commandline args + _extra_args = [] + # vulkan target triple + vulkan_target_triple = args.iree_vulkan_target_triple + from shark.iree_utils.vulkan_utils import ( + get_all_vulkan_devices, + get_vulkan_target_triple, + ) + + if device == "vulkan": + vulkaninfo_list = get_all_vulkan_devices() + if vulkan_target_triple == "": + # We already have the device_id extracted via WebUI, so we directly use + # that to find the target triple. + vulkan_target_triple = get_vulkan_target_triple( + vulkaninfo_list[device_id] + ) + _extra_args.append( + f"-iree-vulkan-target-triple={vulkan_target_triple}" + ) + if "rdna" in vulkan_target_triple: + flags_to_add = [ + "--iree-spirv-index-bits=64", + ] + _extra_args = _extra_args + flags_to_add + + if device_id is None: + id = 0 + for device in vulkaninfo_list: + target_triple = get_vulkan_target_triple( + vulkaninfo_list[id] + ) + if target_triple == vulkan_target_triple: + device_id = id + break + id += 1 + + assert ( + device_id + ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" + print(f"Will use vulkan target triple : {vulkan_target_triple}") + + elif "rocm" in device: + # add iree rocm flags + _extra_args.append( + f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" + ) + print(f"extra args = {_extra_args}") + + if model_name == "vicuna4": + vicuna_model = ShardedVicuna( + model_name, + hf_model_path=model_path, + device=device, + precision=precision, + max_num_tokens=max_toks, + compressed=True, + extra_args_cmd=_extra_args, + ) + else: + # if config_file is None: + vicuna_model = UnshardedVicuna( + model_name, + hf_model_path=model_path, + hf_auth_token=args.hf_auth_token, + device=device, + vulkan_target_triple=vulkan_target_triple, + precision=precision, + max_num_tokens=max_toks, + download_vmfb=download_vmfb, + load_mlir_from_shark_tank=True, + extra_args_cmd=_extra_args, + device_id=device_id, + ) + + if vicuna_model is None: + sys.exit("Unable to instantiate the model object, exiting.") + + prompt = create_prompt(model_name, history, prompt_prefix) + + partial_text = "" token_count = 0 total_time = 0.001 # In order to avoid divide by zero error prefill_time = 0