From 05b498267ebfa8f2caad1a06b3591fb6ab28d3c3 Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Thu, 18 Jan 2024 19:01:07 -0600
Subject: [PATCH] Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.
---
 apps/shark_studio/api/llm.py        | 210 +++++++++++++++++++++-------
 apps/shark_studio/api/utils.py      | 199 +++++++++++++++++++++++++-
 apps/shark_studio/tests/api_test.py |  16 ++-
 apps/shark_studio/web/ui/chat.py    | 146 ++++++-------------
 apps/shark_studio/web/utils.py      |  12 ++
 5 files changed, 411 insertions(+), 172 deletions(-)
 create mode 100644 apps/shark_studio/web/utils.py

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index a209d8d1ba..a9d39f8e7b 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -1,10 +1,9 @@
 from turbine_models.custom_models import stateless_llama
+from turbine_models.model_runner import vmfbRunner
+from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
-from shark.iree_utils.compile_utils import (
-    get_iree_compiled_module,
-    load_vmfb_using_mmap,
-)
-from apps.shark_studio.api.utils import get_resource_path
+from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
+from apps.shark_studio.web.utils import get_resource_path
 import iree.runtime as ireert
 from itertools import chain
 import gc
@@ -16,6 +15,7 @@
     "llama2_7b": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
         "max_tokens": 4096,
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
@@ -23,12 +23,34 @@
     "Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
         "max_tokens": 4096,
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
     },
+    "TinyPixel/small-llama2": {
+        "initializer": stateless_llama.export_transformer_model,
+        "hf_model_name": "TinyPixel/small-llama2",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=True"],
+        "stop_token": 2,
+        "max_tokens": 1024,
<>""", + }, } +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "", "" + +DEFAULT_CHAT_SYS_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n +""" + + +def append_user_prompt(history, input_prompt): + user_prompt = f"{B_INST} {input_prompt} {E_INST}" + history += user_prompt + return history + class LanguageModel: def __init__( @@ -36,41 +58,85 @@ def __init__( model_name, hf_auth_token=None, device=None, - precision="fp32", + quantization="int4", + precision="", external_weights=None, use_system_prompt=True, + streaming_llm=False, ): - print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.tempfile_name = get_resource_path("llm.torch.tempfile") - self.vmfb_name = get_resource_path("llm.vmfb.tempfile") - self.device = device - self.precision = precision - self.safe_name = self.hf_model_name.strip("/").replace("/", "_") - self.max_tokens = llm_model_map[model_name]["max_tokens"] - self.iree_module_dict = None + self.device = device.split("=>")[-1].strip() + self.backend = self.device.split("://")[0] + self.driver = self.backend + if "cpu" in device: + self.device = "cpu" + self.backend = "llvm-cpu" + self.driver = "local-task" + + print(f"Selected {self.backend} as IREE target backend.") + self.precision = "f32" if "cpu" in device else "f16" + self.quantization = quantization + self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_") self.external_weight_file = None + # TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 + self.file_spec = "_".join( + [ + self.safe_name, + self.precision, + ] + ) + if self.quantization != "None": + self.file_spec += "_" + self.quantization + if external_weights is not None: self.external_weight_file = get_resource_path( - self.safe_name + "." + external_weights + self.file_spec + "." + external_weights ) + + if streaming_llm: + # Add streaming suffix to file spec after setting external weights filename. + self.file_spec += "_streaming" + self.streaming_llm = streaming_llm + + self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") + # TODO: Tag vmfb with target triple of device instead of HAL backend + self.vmfb_name = get_resource_path( + f"{self.file_spec}_{self.backend}.vmfb.tempfile" + ) + self.max_tokens = llm_model_map[model_name]["max_tokens"] + self.iree_module_dict = None self.use_system_prompt = use_system_prompt self.global_iter = 0 + self.prev_token_len = 0 + self.first_input = True + if self.external_weight_file is not None: + if not os.path.exists(self.external_weight_file): + print( + f"External weight file {self.external_weight_file} does not exist. Generating..." 
+                )
+                gen_external_params(
+                    hf_model_name=self.hf_model_name,
+                    quantization=self.quantization,
+                    weight_path=self.external_weight_file,
+                    hf_auth_token=hf_auth_token,
+                    precision=self.precision,
+                )
+            else:
+                print(
+                    f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
+                )
         if os.path.exists(self.vmfb_name) and (
             external_weights is None or os.path.exists(str(self.external_weight_file))
         ):
-            self.iree_module_dict = dict()
-            (
-                self.iree_module_dict["vmfb"],
-                self.iree_module_dict["config"],
-                self.iree_module_dict["temp_file_to_unlink"],
-            ) = load_vmfb_using_mmap(
-                self.vmfb_name,
-                device,
-                device_idx=0,
-                rt_flags=[],
-                external_weight_file=self.external_weight_file,
+            self.runner = vmfbRunner(
+                device=self.driver,
+                vmfb_path=self.vmfb_name,
+                external_weight_path=self.external_weight_file,
             )
+            if self.streaming_llm:
+                self.model = self.runner.ctx.modules.streaming_state_update
+            else:
+                self.model = self.runner.ctx.modules.state_update
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.hf_model_name,
                 use_fast=False,
@@ -82,7 +148,9 @@ def __init__(
                 hf_auth_token,
                 compile_to="torch",
                 external_weights=external_weights,
-                external_weight_file=self.external_weight_file,
+                precision=self.precision,
+                quantization=self.quantization,
+                streaming_llm=self.streaming_llm,
             )
             with open(self.tempfile_name, "w+") as f:
                 f.write(self.torch_ir)
@@ -99,19 +167,37 @@ def __init__(
     def compile(self) -> None:
         # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
-        self.iree_module_dict = get_iree_compiled_module(
+        # ONLY architecture/api-specific compile-time flags for each backend, if needed.
+        # hf_model_id-specific global flags currently in model map.
+        flags = []
+        if "cpu" in self.backend:
+            flags.extend(
+                [
+                    "--iree-global-opt-enable-quantized-matmul-reassociation",
+                ]
+            )
+        elif self.backend == "vulkan":
+            flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
+        flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
+        flatbuffer_blob = compile_module_to_flatbuffer(
             self.tempfile_name,
             device=self.device,
-            mmap=True,
             frontend="torch",
-            external_weight_file=self.external_weight_file,
+            model_config_path=None,
+            extra_args=flags,
             write_to=self.vmfb_name,
-            extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
         )
-        # TODO: delete the temp file
+        self.runner = vmfbRunner(
+            device=self.driver,
+            vmfb_path=self.vmfb_name,
+            external_weight_path=self.external_weight_file,
+        )
+        if self.streaming_llm:
+            self.model = self.runner.ctx.modules.streaming_state_update
+        else:
+            self.model = self.runner.ctx.modules.state_update
 
     def sanitize_prompt(self, prompt):
-        print(prompt)
         if isinstance(prompt, list):
             prompt = list(chain.from_iterable(prompt))
             prompt = " ".join([x for x in prompt if isinstance(x, str)])
@@ -119,10 +205,12 @@ def sanitize_prompt(self, prompt):
         prompt = prompt.replace("\t", " ")
         prompt = prompt.replace("\r", " ")
         if self.use_system_prompt and self.global_iter == 0:
-            prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
-            prompt += " [/INST]"
-            print(prompt)
-            return prompt
+            prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt)
+            print(prompt)
+            return prompt
+        else:
+            print(prompt)
+            return f"{B_INST} {prompt} {E_INST}"
 
     def chat(self, prompt):
         prompt = self.sanitize_prompt(prompt)
@@ -134,26 +222,40 @@ def format_out(results):
         history = []
         for iter in range(self.max_tokens):
-            st_time = time.time()
-            if iter == 0:
-                device_inputs = [
-                    ireert.asdevicearray(
-                        self.iree_module_dict["config"].device, input_tensor
-                    )
-                ]
-                token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
+            if self.streaming_llm:
+                token_slice = max(self.prev_token_len - 1, 0)
+                input_tensor = input_tensor[:, token_slice:]
+            if self.streaming_llm and self.model["get_seq_step"]() > 600:
+                print("Evicting cache space!")
+                self.model["evict_kvcache_space"]()
+            token_len = input_tensor.shape[-1]
+            device_inputs = [
+                ireert.asdevicearray(self.runner.config.device, input_tensor)
+            ]
+            if self.first_input or not self.streaming_llm:
+                st_time = time.time()
+                token = self.model["run_initialize"](*device_inputs)
+                total_time = time.time() - st_time
+                token_len += 1
+                self.first_input = False
             else:
-                device_inputs = [
-                    ireert.asdevicearray(
-                        self.iree_module_dict["config"].device,
-                        token,
-                    )
-                ]
-                token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
+                st_time = time.time()
+                token = self.model["run_cached_initialize"](*device_inputs)
+                total_time = time.time() - st_time
+                token_len += 1
 
-            total_time = time.time() - st_time
             history.append(format_out(token))
-            yield self.tokenizer.decode(history), total_time
+            while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]:
+                dec_time = time.time()
+                if self.streaming_llm and self.model["get_seq_step"]() > 600:
+                    print("Evicting cache space!")
+                    self.model["evict_kvcache_space"]()
+                token = self.model["run_forward"](token)
+                history.append(format_out(token))
+                total_time = time.time() - dec_time
+                yield self.tokenizer.decode(history), total_time
+
+            self.prev_token_len = token_len + len(history)
 
             if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
                 break
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index 4072491cbf..7a6e9bb4b7 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -1,12 +1,197 @@
-import os
-import sys
+import numpy as np
+import json
+from random import (
+    randint,
+    seed as seed_random,
+    getstate as random_getstate,
+    setstate as random_setstate,
+)
+
+from pathlib import Path
+
+# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from cpuinfo import get_cpu_info
+
+# TODO: migrate these utils to studio
+from shark.iree_utils.vulkan_utils import (
+    set_iree_vulkan_runtime_flags,
+    get_vulkan_target_triple,
+    get_iree_vulkan_runtime_flags,
+)
 
 
 def get_available_devices():
-    return ["cpu-task"]
+    def get_devices_by_name(driver_name):
+        from shark.iree_utils._common import iree_device_map
+
+        device_list = []
+        try:
+            driver_name = iree_device_map(driver_name)
+            device_list_dict = get_all_devices(driver_name)
+            print(f"{driver_name} devices are available.")
+        except:
+            print(f"{driver_name} devices are not available.")
+        else:
+            cpu_name = get_cpu_info()["brand_raw"]
+            for i, device in enumerate(device_list_dict):
+                device_name = (
+                    cpu_name if device["name"] == "default" else device["name"]
+                )
+                if "local" in driver_name:
+                    device_list.append(
+                        f"{device_name} => {driver_name.replace('local', 'cpu')}"
+                    )
+                else:
+                    # for drivers with single devices
+                    # let the default device be selected without any indexing
+                    if len(device_list_dict) == 1:
+                        device_list.append(f"{device_name} => {driver_name}")
+                    else:
+                        device_list.append(f"{device_name} => {driver_name}://{i}")
+        return device_list
+
+    set_iree_runtime_flags()
+
+    available_devices = []
+    from shark.iree_utils.vulkan_utils import (
+        get_all_vulkan_devices,
+    )
+
+    vulkaninfo_list = get_all_vulkan_devices()
+    vulkan_devices = []
+    id = 0
+    for device in vulkaninfo_list:
+        vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
+        id += 1
+    if id != 0:
+        print(f"vulkan devices are available.")
+    available_devices.extend(vulkan_devices)
+    metal_devices = get_devices_by_name("metal")
+    available_devices.extend(metal_devices)
+    cuda_devices = get_devices_by_name("cuda")
+    available_devices.extend(cuda_devices)
+    rocm_devices = get_devices_by_name("rocm")
+    available_devices.extend(rocm_devices)
+    cpu_device = get_devices_by_name("cpu-sync")
+    available_devices.extend(cpu_device)
+    cpu_device = get_devices_by_name("cpu-task")
+    available_devices.extend(cpu_device)
+    return available_devices
+
+
+def set_iree_runtime_flags():
+    # TODO: This function should be device-agnostic and piped properly
+    # to general runtime driver init.
+    vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
+
+    set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
+
+
+def get_all_devices(driver_name):
+    """
+    Inputs: driver_name
+    Returns a list of all the available devices for a given driver sorted by
+    the iree path names of the device as in --list_devices option in iree.
+    """
+    from iree.runtime import get_driver
+
+    driver = get_driver(driver_name)
+    device_list_src = driver.query_available_devices()
+    device_list_src.sort(key=lambda d: d["path"])
+    return device_list_src
+
+
+def get_device_mapping(driver, key_combination=3):
+    """This method ensures consistent device ordering when choosing
+    specific devices for execution
+    Args:
+        driver (str): execution driver (vulkan, cuda, rocm, etc)
+        key_combination (int, optional): choice for mapping value for
+            device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Returns:
+        dict: map to possible device names user can input mapped to desired
+            combination of name/path.
+    """
+    from shark.iree_utils._common import iree_device_map
+
+    driver = iree_device_map(driver)
+    device_list = get_all_devices(driver)
+    device_map = dict()
+
+    def get_output_value(dev_dict):
+        if key_combination == 1:
+            return f"{driver}://{dev_dict['path']}"
+        if key_combination == 2:
+            return dev_dict["name"]
+        if key_combination == 3:
+            return dev_dict["name"], f"{driver}://{dev_dict['path']}"
+
+    # mapping driver name to default device (driver://0)
+    device_map[f"{driver}"] = get_output_value(device_list[0])
+    for i, device in enumerate(device_list):
+        # mapping with index
+        device_map[f"{driver}://{i}"] = get_output_value(device)
+        # mapping with full path
+        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
+    return device_map
+
+
+def map_device_to_name_path(device, key_combination=3):
+    """Gives the appropriate device data (supported name/path) for user
+    selected execution device
+    Args:
+        device (str): user
+        key_combination (int, optional): choice for mapping value for
+            device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Raises:
+        ValueError:
+    Returns:
+        str / tuple: returns the mapping str or tuple of mapping str for
+            the device depending on key_combination value
+    """
+    driver = device.split("://")[0]
+    device_map = get_device_mapping(driver, key_combination)
+    try:
+        device_mapping = device_map[device]
+    except KeyError:
+        raise ValueError(f"Device '{device}' is not a valid device.")
+    return device_mapping
+
+
+# Generate and return a new seed if the provided one is not in the
+# supported range (including -1)
+def sanitize_seed(seed: int | str):
+    seed = int(seed)
+    uint32_info = np.iinfo(np.uint32)
+    uint32_min, uint32_max = uint32_info.min, uint32_info.max
+    if seed < uint32_min or seed >= uint32_max:
+        seed = randint(uint32_min, uint32_max)
+    return seed
+
+
+# take a seed expression in an input format and convert it to
+# a list of integers, where possible
+def parse_seed_input(seed_input: str | list | int):
+    if isinstance(seed_input, str):
+        try:
+            seed_input = json.loads(seed_input)
+        except (ValueError, TypeError):
+            seed_input = None
+
+    if isinstance(seed_input, int):
+        return [seed_input]
+
+    if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
+        return seed_input
 
-def get_resource_path(relative_path):
-    """Get absolute path to resource, works for dev and for PyInstaller"""
-    base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
-    return os.path.join(base_path, relative_path)
+    raise TypeError(
+        "Seed input must be an integer or an array of integers in JSON format"
+    )
diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py
index c88a1e70cb..203edd0821 100644
--- a/apps/shark_studio/tests/api_test.py
+++ b/apps/shark_studio/tests/api_test.py
@@ -7,15 +7,17 @@
 import logging
 import unittest
 from apps.shark_studio.api.llm import LanguageModel
+import gc
 
 
 class LLMAPITest(unittest.TestCase):
-    def testLLMSimple(self):
+    def test01_LLMSmall(self):
         lm = LanguageModel(
-            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            "TinyPixel/small-llama2",
             hf_auth_token=None,
-            device="cpu-task",
-            external_weights="safetensors",
+            device="cpu",
+            precision="fp32",
+            quantization="None",
         )
         count = 0
         for msg, _ in lm.chat("hi, what are you?"):
@@ -24,9 +26,11 @@ def testLLMSimple(self):
                 count += 1
                 continue
             assert (
-                msg.strip(" ") == "Hello"
-            ), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
+                msg.strip(" ") == "Turkish Turkish Turkish"
+            ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
             break
+        del lm
+        gc.collect()
 
 
 if __name__ == "__main__":
diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py
index 4726eef6e8..6e10cfaf6b 100644
--- a/apps/shark_studio/web/ui/chat.py
+++ b/apps/shark_studio/web/ui/chat.py
@@ -13,17 +13,21 @@
     LanguageModel,
 )
 
+B_SYS, E_SYS = "<s>", "</s>"
+
 
 def user(message, history):
     # Append the user's message to the conversation history
     return "", history + [[message, ""]]
 
 
-language_model = None
+def append_bot_prompt(history, input_prompt):
+    user_prompt = f"{input_prompt} {E_SYS} {E_SYS}"
+    history += user_prompt
+    return history
 
 
-def create_prompt(model_name, history, prompt_prefix):
-    return ""
+language_model = None
 
 
 def get_default_config():
@@ -41,9 +45,13 @@ def chat_fn(
     precision,
     download_vmfb,
     config_file,
+    streaming_llm,
     cli=False,
 ):
     global language_model
+    if streaming_llm and prompt_prefix == "Clear":
+        language_model = None
+        return "Clearing history...", ""
     if language_model is None:
         history[-1][-1] = "Getting the model ready..."
         yield history, ""
@@ -52,8 +60,8 @@ def chat_fn(
             device=device,
             precision=precision,
             external_weights="safetensors",
-            external_weight_file="llama2_7b.safetensors",
             use_system_prompt=prompt_prefix,
+            streaming_llm=streaming_llm,
         )
         history[-1][-1] = "Getting the model ready... Done"
         yield history, ""
@@ -63,7 +71,7 @@ def chat_fn(
     prefill_time = 0
     is_first = True
     for text, exec_time in language_model.chat(history):
-        history[-1][-1] = text
+        history[-1][-1] = f"{text}{E_SYS}"
         if is_first:
             prefill_time = exec_time
             is_first = False
@@ -75,101 +83,6 @@ def chat_fn(
     yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
 
 
-def llm_chat_api(InputData: dict):
-    return None
-    print(f"Input keys : {InputData.keys()}")
-    # print(f"model : {InputData['model']}")
-    is_chat_completion_api = (
-        "messages" in InputData.keys()
-    )  # else it is the legacy `completion` api
-    # For Debugging input data from API
-    # if is_chat_completion_api:
-    #     print(f"message -> role : {InputData['messages'][0]['role']}")
-    #     print(f"message -> content : {InputData['messages'][0]['content']}")
-    # else:
-    #     print(f"prompt : {InputData['prompt']}")
-    #     print(f"max_tokens : {InputData['max_tokens']}")  # Default to 128 for now
-    global vicuna_model
-    model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
-    model_path = llm_model_map[model_name]
-    device = "cpu-task"
-    precision = "fp16"
-    max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
-    if max_toks is None:
-        max_toks = 128 if model_name == "codegen" else 512
-
-    # make it working for codegen first
-    from apps.language_models.scripts.vicuna import (
-        UnshardedVicuna,
-    )
-
-    device_id = None
-    if vicuna_model == 0:
-        if "cuda" in device:
-            device = "cuda"
-        elif "sync" in device:
-            device = "cpu-sync"
-        elif "task" in device:
-            device = "cpu-task"
-        elif "vulkan" in device:
-            device_id = int(device.split("://")[1])
-            device = "vulkan"
-        else:
-            print("unrecognized device")
-
-        vicuna_model = UnshardedVicuna(
-            model_name,
-            hf_model_path=model_path,
-            device=device,
-            precision=precision,
-            max_num_tokens=max_toks,
-            download_vmfb=True,
-            load_mlir_from_shark_tank=True,
-            device_id=device_id,
-        )
-
-    # TODO: add role dict for different models
-    if is_chat_completion_api:
-        # TODO: add funtionality for multiple messages
-        prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
-    else:
-        prompt = InputData["prompt"]
-    print("prompt = ", prompt)
-
-    res = vicuna_model.generate(prompt)
-    res_op = None
-    for op in res:
-        res_op = op
-
-    if is_chat_completion_api:
-        choices = [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": res_op,  # since we are yeilding the result
-                },
-                "finish_reason": "stop",  # or length
-            }
-        ]
-    else:
-        choices = [
-            {
-                "text": res_op,
-                "index": 0,
-                "logprobs": None,
-                "finish_reason": "stop",  # or length
-            }
-        ]
-    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
-    return {
-        "id": end_time,
-        "object": "chat.completion" if is_chat_completion_api else "text_completion",
-        "created": int(end_time),
-        "choices": choices,
-    }
-
-
 def view_json_file(file_obj):
     content = ""
     with open(file_obj.name, "r") as fopen:
@@ -200,7 +113,7 @@ def view_json_file(file_obj):
                )
                precision = gr.Radio(
                    label="Precision",
-                    value="int4",
+                    value="fp32",
                    choices=[
                        # "int4",
                        # "int8",
@@ -213,12 +126,18 @@ def view_json_file(file_obj):
                with gr.Column():
                    download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available", + value=False, + interactive=True, + visible=False, + ) + streaming_llm = gr.Checkbox( + label="Run in streaming mode (requires recompilation)", value=True, interactive=True, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", - value=False, + value=True, interactive=True, ) @@ -241,8 +160,8 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): config_file = gr.File(label="Upload sharding configuration", visible=False) - json_view_button = gr.Button(label="View as JSON", visible=False) - json_view = gr.JSON(interactive=True, visible=False) + json_view_button = gr.Button("View as JSON", visible=False) + json_view = gr.JSON(visible=False) json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) @@ -262,6 +181,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, @@ -283,6 +203,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, @@ -295,4 +216,19 @@ def view_json_file(file_obj): cancels=[submit_event, submit_click_event], queue=False, ) - clear.click(lambda: None, None, [chatbot], queue=False) + clear.click( + fn=chat_fn, + inputs=[ + clear, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + streaming_llm, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ).then(lambda: None, None, [chatbot], queue=False) diff --git a/apps/shark_studio/web/utils.py b/apps/shark_studio/web/utils.py new file mode 100644 index 0000000000..4072491cbf --- /dev/null +++ b/apps/shark_studio/web/utils.py @@ -0,0 +1,12 @@ +import os +import sys + + +def get_available_devices(): + return ["cpu-task"] + + +def get_resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_path, relative_path)