From f6b249ad1137218036de5a93fdff8dd196c7613e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 8 Jan 2024 23:25:09 -0600 Subject: [PATCH 01/23] Streaming LLM (WIP) --- apps/shark_studio/api/llm.py | 32 +++++++++++++++++++++++++++++--- apps/shark_studio/web/ui/chat.py | 17 +++++++++++++---- apps/shark_studio/web/utils.py | 12 ++++++++++++ 3 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 apps/shark_studio/web/utils.py diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index a209d8d1ba..2db83658b5 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -4,7 +4,7 @@ get_iree_compiled_module, load_vmfb_using_mmap, ) -from apps.shark_studio.api.utils import get_resource_path +from apps.shark_studio.web.utils import get_resource_path import iree.runtime as ireert from itertools import chain import gc @@ -39,6 +39,7 @@ def __init__( precision="fp32", external_weights=None, use_system_prompt=True, + streaming_llm=False, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] @@ -50,12 +51,15 @@ def __init__( self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None self.external_weight_file = None + self.streaming_llm = streaming_llm if external_weights is not None: self.external_weight_file = get_resource_path( self.safe_name + "." + external_weights ) self.use_system_prompt = use_system_prompt self.global_iter = 0 + self.prev_token_len = 0 + if os.path.exists(self.vmfb_name) and ( external_weights is None or os.path.exists(str(self.external_weight_file)) ): @@ -83,6 +87,7 @@ def __init__( compile_to="torch", external_weights=external_weights, external_weight_file=self.external_weight_file, + streaming_llm=self.streaming_llm, ) with open(self.tempfile_name, "w+") as f: f.write(self.torch_ir) @@ -106,7 +111,7 @@ def compile(self) -> None: frontend="torch", external_weight_file=self.external_weight_file, write_to=self.vmfb_name, - extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"], + extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"] if "cpu" in self.device else [], ) # TODO: delete the temp file @@ -129,20 +134,40 @@ def chat(self, prompt): input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids + if self.streaming_llm: + token_slice = max(self.prev_token_len - 1, 0) + input_tensor = input_tensor[:, token_slice:] + def format_out(results): return torch.tensor(results.to_host()[0][0]) history = [] for iter in range(self.max_tokens): + if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600: + print("Evicting cache space!") + self.iree_module_dict["vmfb"]["evict_kvcache_space"]() st_time = time.time() - if iter == 0: + token_len = input_tensor.shape[-1] + if iter == 0 and not self.streaming_llm: device_inputs = [ ireert.asdevicearray( self.iree_module_dict["config"].device, input_tensor ) ] token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) + token_len += 1 + elif iter == 0: + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, input_tensor + ) + ] + token = self.iree_module_dict["vmfb"]["run_cached_initialize"](*device_inputs) + token_len += 1 else: + if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() device_inputs = [ ireert.asdevicearray( self.iree_module_dict["config"].device, @@ -153,6 +178,7 @@ def format_out(results): 
total_time = time.time() - st_time history.append(format_out(token)) + self.prev_token_len = token_len + len(history) yield self.tokenizer.decode(history), total_time if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 4726eef6e8..65d70fb0e9 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -41,6 +41,7 @@ def chat_fn( precision, download_vmfb, config_file, + streaming_llm, cli=False, ): global language_model @@ -52,8 +53,8 @@ def chat_fn( device=device, precision=precision, external_weights="safetensors", - external_weight_file="llama2_7b.safetensors", use_system_prompt=prompt_prefix, + streaming_llm=streaming_llm, ) history[-1][-1] = "Getting the model ready... Done" yield history, "" @@ -213,12 +214,18 @@ def view_json_file(file_obj): with gr.Column(): download_vmfb = gr.Checkbox( label="Download vmfb from Shark tank if available", + value=False, + interactive=True, + visible=False, + ) + streaming_llm = gr.Checkbox( + label="Run in streaming mode (requires recompilation)", value=True, interactive=True, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", - value=False, + value=True, interactive=True, ) @@ -241,8 +248,8 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): config_file = gr.File(label="Upload sharding configuration", visible=False) - json_view_button = gr.Button(label="View as JSON", visible=False) - json_view = gr.JSON(interactive=True, visible=False) + json_view_button = gr.Button("View as JSON", visible=False) + json_view = gr.JSON(visible=False) json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) @@ -262,6 +269,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, @@ -283,6 +291,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, diff --git a/apps/shark_studio/web/utils.py b/apps/shark_studio/web/utils.py new file mode 100644 index 0000000000..4072491cbf --- /dev/null +++ b/apps/shark_studio/web/utils.py @@ -0,0 +1,12 @@ +import os +import sys + + +def get_available_devices(): + return ["cpu-task"] + + +def get_resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_path, relative_path) From 353d50f5185d7a2fce7ae0a136fd272794d2cbea Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 01:17:06 -0600 Subject: [PATCH 02/23] Update precision and add gpu support --- apps/shark_studio/api/llm.py | 57 +++++++- apps/shark_studio/api/utils.py | 223 ++++++++++++++++++++++++++++++- apps/shark_studio/web/ui/chat.py | 2 +- 3 files changed, 267 insertions(+), 15 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 2db83658b5..e984b2e698 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -28,7 +28,19 @@ "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. 
<>""", }, } +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "", "" +def append_user_prompt(history, input_prompt): + user_prompt = f"{B_INST} {input_prompt} {E_INST}" + history += user_prompt + return history + + +def append_bot_prompt(history, input_prompt): + user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" + history += user_prompt + return history class LanguageModel: def __init__( @@ -36,17 +48,22 @@ def __init__( model_name, hf_auth_token=None, device=None, - precision="fp32", + quantization="int4", + precision="", external_weights=None, use_system_prompt=True, streaming_llm=False, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.tempfile_name = get_resource_path("llm.torch.tempfile") - self.vmfb_name = get_resource_path("llm.vmfb.tempfile") - self.device = device - self.precision = precision + self.device = device.split("=>")[-1].strip() + self.driver = self.device.split("://")[0] + print(f" Selected {self.driver} as device driver") + self.precision = "fp32" if "cpu" in self.driver else "fp16" + self.quantization = quantization + self.tempfile_name = get_resource_path(f"llm_{self.precision}_{self.quantization}.tempfile") + #TODO: Tag vmfb with target triple of device instead of HAL backend + self.vmfb_name = get_resource_path(f"llm_{self.precision}_{self.quantization}_{self.driver}.vmfb.tempfile") self.safe_name = self.hf_model_name.strip("/").replace("/", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None @@ -70,7 +87,7 @@ def __init__( self.iree_module_dict["temp_file_to_unlink"], ) = load_vmfb_using_mmap( self.vmfb_name, - device, + self.driver, device_idx=0, rt_flags=[], external_weight_file=self.external_weight_file, @@ -87,6 +104,8 @@ def __init__( compile_to="torch", external_weights=external_weights, external_weight_file=self.external_weight_file, + precision=self.precision, + quantization=self.quantization, streaming_llm=self.streaming_llm, ) with open(self.tempfile_name, "w+") as f: @@ -104,6 +123,30 @@ def __init__( def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". 
+ flags = [ + "--iree-input-type=torch", + "--mlir-print-debuginfo", + "--mlir-print-op-on-diagnostic=false", + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-target-triple=x86_64-linux-gnu", + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-opt-const-expr-hoisting=False", + ] + if "cpu" in self.driver: + flags.extend( + [ + "--iree-global-opt-enable-quantized-matmul-reassociation", + "--iree-llvmcpu-enable-ukernels=all" + ] + ) + elif self.driver == "vulkan": + flags.extend( + [ + "--iree-stream-resource-max-allocation-size=4294967296" + ] + ) self.iree_module_dict = get_iree_compiled_module( self.tempfile_name, device=self.device, @@ -111,7 +154,7 @@ def compile(self) -> None: frontend="torch", external_weight_file=self.external_weight_file, write_to=self.vmfb_name, - extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"] if "cpu" in self.device else [], + extra_args=flags, ) # TODO: delete the temp file diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 4072491cbf..1258b66424 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -1,12 +1,221 @@ -import os -import sys +import numpy as np +import json +from random import ( + randint, + seed as seed_random, + getstate as random_getstate, + setstate as random_setstate, +) + +from pathlib import Path +#from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from cpuinfo import get_cpu_info + +# TODO: migrate these utils to studio +from shark.iree_utils.vulkan_utils import ( + set_iree_vulkan_runtime_flags, + get_vulkan_target_triple, + get_iree_vulkan_runtime_flags, +) def get_available_devices(): - return ["cpu-task"] + def get_devices_by_name(driver_name): + from shark.iree_utils._common import iree_device_map + + device_list = [] + try: + driver_name = iree_device_map(driver_name) + device_list_dict = get_all_devices(driver_name) + print(f"{driver_name} devices are available.") + except: + print(f"{driver_name} devices are not available.") + else: + cpu_name = get_cpu_info()["brand_raw"] + for i, device in enumerate(device_list_dict): + device_name = ( + cpu_name if device["name"] == "default" else device["name"] + ) + if "local" in driver_name: + device_list.append( + f"{device_name} => {driver_name.replace('local', 'cpu')}" + ) + else: + # for drivers with single devices + # let the default device be selected without any indexing + if len(device_list_dict) == 1: + device_list.append(f"{device_name} => {driver_name}") + else: + device_list.append(f"{device_name} => {driver_name}://{i}") + return device_list + + set_iree_runtime_flags() + + available_devices = [] + from shark.iree_utils.vulkan_utils import ( + get_all_vulkan_devices, + ) + + vulkaninfo_list = get_all_vulkan_devices() + vulkan_devices = [] + id = 0 + for device in vulkaninfo_list: + vulkan_devices.append(f"{device.strip()} => vulkan://{id}") + id += 1 + if id != 0: + print(f"vulkan devices are available.") + available_devices.extend(vulkan_devices) + metal_devices = get_devices_by_name("metal") + available_devices.extend(metal_devices) + cuda_devices = get_devices_by_name("cuda") + available_devices.extend(cuda_devices) + rocm_devices = get_devices_by_name("rocm") + available_devices.extend(rocm_devices) + cpu_device = get_devices_by_name("cpu-sync") + available_devices.extend(cpu_device) + cpu_device = get_devices_by_name("cpu-task") + available_devices.extend(cpu_device) + 
return available_devices + +def set_iree_runtime_flags(): + # TODO: This function should be device-agnostic and piped properly + # to general runtime driver init. + vulkan_runtime_flags = get_iree_vulkan_runtime_flags() + + set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) + + +def get_all_devices(driver_name): + """ + Inputs: driver_name + Returns a list of all the available devices for a given driver sorted by + the iree path names of the device as in --list_devices option in iree. + """ + from iree.runtime import get_driver + + driver = get_driver(driver_name) + device_list_src = driver.query_available_devices() + device_list_src.sort(key=lambda d: d["path"]) + return device_list_src + + +def get_device_mapping(driver, key_combination=3): + """This method ensures consistent device ordering when choosing + specific devices for execution + Args: + driver (str): execution driver (vulkan, cuda, rocm, etc) + key_combination (int, optional): choice for mapping value for + device name. + 1 : path + 2 : name + 3 : (name, path) + Defaults to 3. + Returns: + dict: map to possible device names user can input mapped to desired + combination of name/path. + """ + from shark.iree_utils._common import iree_device_map + + driver = iree_device_map(driver) + device_list = get_all_devices(driver) + device_map = dict() + + def get_output_value(dev_dict): + if key_combination == 1: + return f"{driver}://{dev_dict['path']}" + if key_combination == 2: + return dev_dict["name"] + if key_combination == 3: + return dev_dict["name"], f"{driver}://{dev_dict['path']}" + + # mapping driver name to default device (driver://0) + device_map[f"{driver}"] = get_output_value(device_list[0]) + for i, device in enumerate(device_list): + # mapping with index + device_map[f"{driver}://{i}"] = get_output_value(device) + # mapping with full path + device_map[f"{driver}://{device['path']}"] = get_output_value(device) + return device_map + + +def get_opt_flags(model, precision="fp16"): + iree_flags = [] + if len(cmd_opts.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" + ) + if "rocm" in cmd_opts.device: + from shark.iree_utils.gpu_utils import get_iree_rocm_args + + rocm_args = get_iree_rocm_args() + iree_flags.extend(rocm_args) + if cmd_opts.iree_constant_folding == False: + iree_flags.append("--iree-opt-const-expr-hoisting=False") + iree_flags.append( + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" + ) + if cmd_opts.data_tiling == False: + iree_flags.append("--iree-opt-data-tiling=False") + + if "vae" not in model: + # Due to lack of support for multi-reduce, we always collapse reduction + # dims before dispatch formation right now. + iree_flags += ["--iree-flow-collapse-reduction-dims"] + return iree_flags + + +def map_device_to_name_path(device, key_combination=3): + """Gives the appropriate device data (supported name/path) for user + selected execution device + Args: + device (str): user + key_combination (int, optional): choice for mapping value for + device name. + 1 : path + 2 : name + 3 : (name, path) + Defaults to 3. 
+ Raises: + ValueError: + Returns: + str / tuple: returns the mapping str or tuple of mapping str for + the device depending on key_combination value + """ + driver = device.split("://")[0] + device_map = get_device_mapping(driver, key_combination) + try: + device_mapping = device_map[device] + except KeyError: + raise ValueError(f"Device '{device}' is not a valid device.") + return device_mapping + + +# Generate and return a new seed if the provided one is not in the +# supported range (including -1) +def sanitize_seed(seed: int | str): + seed = int(seed) + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + return seed + + +# take a seed expression in an input format and convert it to +# a list of integers, where possible +def parse_seed_input(seed_input: str | list | int): + if isinstance(seed_input, str): + try: + seed_input = json.loads(seed_input) + except (ValueError, TypeError): + seed_input = None + + if isinstance(seed_input, int): + return [seed_input] + if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): + return seed_input -def get_resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) - return os.path.join(base_path, relative_path) + raise TypeError( + "Seed input must be an integer or an array of integers in JSON format" + ) \ No newline at end of file diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 65d70fb0e9..cb038d4a78 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -201,7 +201,7 @@ def view_json_file(file_obj): ) precision = gr.Radio( label="Precision", - value="int4", + value="fp32", choices=[ # "int4", # "int8", From ee539719f4f17ac9f9e8f6ac57e3b110ae7e30db Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 10:54:59 -0600 Subject: [PATCH 03/23] (studio2) Separate weights generation for quantization support --- apps/shark_studio/api/llm.py | 40 +++++++++++--- apps/shark_studio/api/utils.py | 32 +++++------ apps/shark_studio/web/ui/chat.py | 95 -------------------------------- 3 files changed, 49 insertions(+), 118 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index e984b2e698..f4b1eaab4c 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -1,4 +1,5 @@ from turbine_models.custom_models import stateless_llama +from turbine_models.gen_external_params.gen_external_params import gen_external_params import time from shark.iree_utils.compile_utils import ( get_iree_compiled_module, @@ -28,6 +29,7 @@ "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. 
<>""", }, } + B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "", "" @@ -36,7 +38,6 @@ def append_user_prompt(history, input_prompt): history += user_prompt return history - def append_bot_prompt(history, input_prompt): user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" history += user_prompt @@ -58,12 +59,19 @@ def __init__( self.hf_model_name = llm_model_map[model_name]["hf_model_name"] self.device = device.split("=>")[-1].strip() self.driver = self.device.split("://")[0] - print(f" Selected {self.driver} as device driver") - self.precision = "fp32" if "cpu" in self.driver else "fp16" + print(f"Selected {self.driver} as device driver") + self.precision = "f32" if "cpu" in self.driver else "f16" self.quantization = quantization - self.tempfile_name = get_resource_path(f"llm_{self.precision}_{self.quantization}.tempfile") + #TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 + self.file_spec = "_".join([ + "llama2", + "streaming" if streaming_llm else "chat", + self.precision, + self.quantization, + ]) + self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") #TODO: Tag vmfb with target triple of device instead of HAL backend - self.vmfb_name = get_resource_path(f"llm_{self.precision}_{self.quantization}_{self.driver}.vmfb.tempfile") + self.vmfb_name = get_resource_path(f"{self.file_spec}_{self.driver}.vmfb.tempfile") self.safe_name = self.hf_model_name.strip("/").replace("/", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None @@ -71,12 +79,30 @@ def __init__( self.streaming_llm = streaming_llm if external_weights is not None: self.external_weight_file = get_resource_path( - self.safe_name + "." + external_weights + self.safe_name + + "_" + self.precision + + "_" + self.quantization + + "." + external_weights ) self.use_system_prompt = use_system_prompt self.global_iter = 0 self.prev_token_len = 0 - + if self.external_weight_file is not None: + if not os.path.exists(self.external_weight_file): + print( + f"External weight file {self.external_weight_file} does not exist. Generating..." 
+ ) + gen_external_params( + hf_model_name=self.hf_model_name, + quantization=self.quantization, + weight_path=self.external_weight_file, + hf_auth_token=hf_auth_token, + precision=self.precision, + ) + else: + print( + f"External weight file {self.external_weight_file} found for {self.vmfb_name}" + ) if os.path.exists(self.vmfb_name) and ( external_weights is None or os.path.exists(str(self.external_weight_file)) ): diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 1258b66424..90f5c52c0d 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -140,22 +140,22 @@ def get_output_value(dev_dict): def get_opt_flags(model, precision="fp16"): iree_flags = [] - if len(cmd_opts.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" - ) - if "rocm" in cmd_opts.device: - from shark.iree_utils.gpu_utils import get_iree_rocm_args - - rocm_args = get_iree_rocm_args() - iree_flags.extend(rocm_args) - if cmd_opts.iree_constant_folding == False: - iree_flags.append("--iree-opt-const-expr-hoisting=False") - iree_flags.append( - "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" - ) - if cmd_opts.data_tiling == False: - iree_flags.append("--iree-opt-data-tiling=False") + # if len(cmd_opts.iree_vulkan_target_triple) > 0: + # iree_flags.append( + # f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" + # ) + # if "rocm" in cmd_opts.device: + # from shark.iree_utils.gpu_utils import get_iree_rocm_args + + # rocm_args = get_iree_rocm_args() + # iree_flags.extend(rocm_args) + # if cmd_opts.iree_constant_folding == False: + # iree_flags.append("--iree-opt-const-expr-hoisting=False") + # iree_flags.append( + # "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" + # ) + # if cmd_opts.data_tiling == False: + # iree_flags.append("--iree-opt-data-tiling=False") if "vae" not in model: # Due to lack of support for multi-reduce, we always collapse reduction diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index cb038d4a78..79d4085add 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -76,101 +76,6 @@ def chat_fn( yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" -def llm_chat_api(InputData: dict): - return None - print(f"Input keys : {InputData.keys()}") - # print(f"model : {InputData['model']}") - is_chat_completion_api = ( - "messages" in InputData.keys() - ) # else it is the legacy `completion` api - # For Debugging input data from API - # if is_chat_completion_api: - # print(f"message -> role : {InputData['messages'][0]['role']}") - # print(f"message -> content : {InputData['messages'][0]['content']}") - # else: - # print(f"prompt : {InputData['prompt']}") - # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now - global vicuna_model - model_name = InputData["model"] if "model" in InputData.keys() else "codegen" - model_path = llm_model_map[model_name] - device = "cpu-task" - precision = "fp16" - max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] - if max_toks is None: - max_toks = 128 if model_name == "codegen" else 512 - - # make it working for codegen first - from apps.language_models.scripts.vicuna import ( - UnshardedVicuna, - ) - - device_id = None - if vicuna_model == 0: - if "cuda" in device: - device = "cuda" - elif "sync" in device: - device = "cpu-sync" - elif 
"task" in device: - device = "cpu-task" - elif "vulkan" in device: - device_id = int(device.split("://")[1]) - device = "vulkan" - else: - print("unrecognized device") - - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=True, - load_mlir_from_shark_tank=True, - device_id=device_id, - ) - - # TODO: add role dict for different models - if is_chat_completion_api: - # TODO: add funtionality for multiple messages - prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) - else: - prompt = InputData["prompt"] - print("prompt = ", prompt) - - res = vicuna_model.generate(prompt) - res_op = None - for op in res: - res_op = op - - if is_chat_completion_api: - choices = [ - { - "index": 0, - "message": { - "role": "assistant", - "content": res_op, # since we are yeilding the result - }, - "finish_reason": "stop", # or length - } - ] - else: - choices = [ - { - "text": res_op, - "index": 0, - "logprobs": None, - "finish_reason": "stop", # or length - } - ] - end_time = dt.now().strftime("%Y%m%d%H%M%S%f") - return { - "id": end_time, - "object": "chat.completion" if is_chat_completion_api else "text_completion", - "created": int(end_time), - "choices": choices, - } - - def view_json_file(file_obj): content = "" with open(file_obj.name, "r") as fopen: From 79460ef9f7bc3cc94e14391bc7570ccb0f2eae7f Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 11:15:14 -0600 Subject: [PATCH 04/23] Small fixes to prompts, weights gen --- apps/shark_studio/api/llm.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index f4b1eaab4c..d650309a64 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -95,7 +95,7 @@ def __init__( gen_external_params( hf_model_name=self.hf_model_name, quantization=self.quantization, - weight_path=self.external_weight_file, + weight_path="", hf_auth_token=hf_auth_token, precision=self.precision, ) @@ -200,6 +200,8 @@ def sanitize_prompt(self, prompt): def chat(self, prompt): prompt = self.sanitize_prompt(prompt) + user_prompt = input("User prompt: ") + prompt = append_user_prompt(prompt, user_prompt) input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids @@ -236,7 +238,7 @@ def format_out(results): else: if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600: print("Evicting cache space!") - self.model["evict_kvcache_space"]() + self.iree_module_dict["vmfb"]["evict_kvcache_space"]() device_inputs = [ ireert.asdevicearray( self.iree_module_dict["config"].device, @@ -248,7 +250,9 @@ def format_out(results): total_time = time.time() - st_time history.append(format_out(token)) self.prev_token_len = token_len + len(history) - yield self.tokenizer.decode(history), total_time + res = self.tokenizer.decode(history, skip_special_tokens=True) + prompt = append_bot_prompt(prompt, res) + yield prompt, total_time if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break From 82830543e02a0f44a3d6845702e1a290cad543b9 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 11:35:19 -0600 Subject: [PATCH 05/23] Adapt prompt changes to studio flow --- apps/shark_studio/api/llm.py | 6 ++---- apps/shark_studio/web/ui/chat.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index d650309a64..59722f11a8 100644 --- 
a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -72,7 +72,7 @@ def __init__( self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") #TODO: Tag vmfb with target triple of device instead of HAL backend self.vmfb_name = get_resource_path(f"{self.file_spec}_{self.driver}.vmfb.tempfile") - self.safe_name = self.hf_model_name.strip("/").replace("/", "_") + self.safe_name = self.hf_model_name.split("/")[-1].replace("-", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None self.external_weight_file = None @@ -95,7 +95,7 @@ def __init__( gen_external_params( hf_model_name=self.hf_model_name, quantization=self.quantization, - weight_path="", + weight_path=self.external_weight_file, hf_auth_token=hf_auth_token, precision=self.precision, ) @@ -200,8 +200,6 @@ def sanitize_prompt(self, prompt): def chat(self, prompt): prompt = self.sanitize_prompt(prompt) - user_prompt = input("User prompt: ") - prompt = append_user_prompt(prompt, user_prompt) input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 79d4085add..e82c5f94e7 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -13,9 +13,9 @@ LanguageModel, ) - def user(message, history): # Append the user's message to the conversation history + #message = f"{B_INST} {message} {E_INST}" return "", history + [[message, ""]] From a7b2231a7b8e63b9e4dd000424626c084bb48b4e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 11:44:49 -0600 Subject: [PATCH 06/23] Final commit --- apps/shark_studio/api/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 59722f11a8..6f5474afe8 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -249,8 +249,8 @@ def format_out(results): history.append(format_out(token)) self.prev_token_len = token_len + len(history) res = self.tokenizer.decode(history, skip_special_tokens=True) - prompt = append_bot_prompt(prompt, res) - yield prompt, total_time + #prompt = append_bot_prompt(prompt, res) + yield res, total_time if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break @@ -258,7 +258,7 @@ def format_out(results): for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) - result_output = self.tokenizer.decode(history) + result_output = self.tokenizer.decode(history, skip_special_tokens=True) self.global_iter += 1 return result_output, total_time From f456e878d4f2141621b90109f2078fe182eb973c Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:45:39 -0600 Subject: [PATCH 07/23] Update utils.py --- apps/shark_studio/api/utils.py | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 90f5c52c0d..27a622f608 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -136,33 +136,7 @@ def get_output_value(dev_dict): # mapping with full path device_map[f"{driver}://{device['path']}"] = get_output_value(device) return device_map - - -def get_opt_flags(model, precision="fp16"): - iree_flags = [] - # if len(cmd_opts.iree_vulkan_target_triple) > 0: - # iree_flags.append( - # f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" - # ) - # if "rocm" in 
cmd_opts.device: - # from shark.iree_utils.gpu_utils import get_iree_rocm_args - - # rocm_args = get_iree_rocm_args() - # iree_flags.extend(rocm_args) - # if cmd_opts.iree_constant_folding == False: - # iree_flags.append("--iree-opt-const-expr-hoisting=False") - # iree_flags.append( - # "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" - # ) - # if cmd_opts.data_tiling == False: - # iree_flags.append("--iree-opt-data-tiling=False") - - if "vae" not in model: - # Due to lack of support for multi-reduce, we always collapse reduction - # dims before dispatch formation right now. - iree_flags += ["--iree-flow-collapse-reduction-dims"] - return iree_flags - + def map_device_to_name_path(device, key_combination=3): """Gives the appropriate device data (supported name/path) for user @@ -218,4 +192,4 @@ def parse_seed_input(seed_input: str | list | int): raise TypeError( "Seed input must be an integer or an array of integers in JSON format" - ) \ No newline at end of file + ) From df78dec98869e795b694a846f134f460e4a8fd0d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 13:05:40 -0600 Subject: [PATCH 08/23] Remove outdated flag from llm compile flags. --- apps/shark_studio/api/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 6f5474afe8..0f5520257d 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -157,7 +157,6 @@ def compile(self) -> None: "--iree-llvmcpu-target-triple=x86_64-linux-gnu", "--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64", - "--iree-codegen-check-ir-before-llvm-conversion=false", "--iree-opt-const-expr-hoisting=False", ] if "cpu" in self.driver: From 163241a6477c8b51a5ca9a1bb028c8c92d62f69e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 17:45:21 -0600 Subject: [PATCH 09/23] (studio2) use turbine vmfbRunner --- apps/shark_studio/api/llm.py | 55 ++++++++++++++++++-------------- apps/shark_studio/web/ui/chat.py | 46 +++++++++++++------------- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 0f5520257d..a867dc8b21 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -1,4 +1,5 @@ from turbine_models.custom_models import stateless_llama +from turbine_models.model_runner import vmfbRunner from turbine_models.gen_external_params.gen_external_params import gen_external_params import time from shark.iree_utils.compile_utils import ( @@ -106,18 +107,15 @@ def __init__( if os.path.exists(self.vmfb_name) and ( external_weights is None or os.path.exists(str(self.external_weight_file)) ): - self.iree_module_dict = dict() - ( - self.iree_module_dict["vmfb"], - self.iree_module_dict["config"], - self.iree_module_dict["temp_file_to_unlink"], - ) = load_vmfb_using_mmap( - self.vmfb_name, - self.driver, - device_idx=0, - rt_flags=[], - external_weight_file=self.external_weight_file, + self.runner = vmfbRunner( + device = self.driver, + vmfb_path=self.vmfb_name, + external_weight_path=self.external_weight_file, ) + if self.streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update self.tokenizer = AutoTokenizer.from_pretrained( self.hf_model_name, use_fast=False, @@ -181,6 +179,17 @@ def compile(self) -> None: write_to=self.vmfb_name, extra_args=flags, ) + del self.iree_module_dict + gc.collect() + self.runner = vmfbRunner( + device = 
self.driver, + vmfb_path=self.vmfb_name, + external_weight_path=self.external_weight_file, + ) + if self.streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update # TODO: delete the temp file def sanitize_prompt(self, prompt): @@ -211,45 +220,43 @@ def format_out(results): history = [] for iter in range(self.max_tokens): - if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600: + if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") - self.iree_module_dict["vmfb"]["evict_kvcache_space"]() + self.model["evict_kvcache_space"]() st_time = time.time() token_len = input_tensor.shape[-1] if iter == 0 and not self.streaming_llm: device_inputs = [ ireert.asdevicearray( - self.iree_module_dict["config"].device, input_tensor + self.runner.config.device, input_tensor ) ] - token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) + token = self.model["run_initialize"](*device_inputs) token_len += 1 elif iter == 0: device_inputs = [ ireert.asdevicearray( - self.iree_module_dict["config"].device, input_tensor + self.runner.config.device, input_tensor ) ] - token = self.iree_module_dict["vmfb"]["run_cached_initialize"](*device_inputs) + token = self.model["run_cached_initialize"](*device_inputs) token_len += 1 else: - if self.streaming_llm and self.iree_module_dict["vmfb"]["get_seq_step"]() > 600: + if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") - self.iree_module_dict["vmfb"]["evict_kvcache_space"]() + self.model["evict_kvcache_space"]() device_inputs = [ ireert.asdevicearray( - self.iree_module_dict["config"].device, + self.runner.config.device, token, ) ] - token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) + token = self.model["run_forward"](*device_inputs) total_time = time.time() - st_time history.append(format_out(token)) self.prev_token_len = token_len + len(history) - res = self.tokenizer.decode(history, skip_special_tokens=True) - #prompt = append_bot_prompt(prompt, res) - yield res, total_time + yield self.tokenizer.decode(history, skip_special_tokens=True), total_time if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index e82c5f94e7..b1d12aaec7 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -158,28 +158,28 @@ def view_json_file(file_obj): json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) - submit_event = msg.submit( - fn=user, - inputs=[msg, chatbot], - outputs=[msg, chatbot], - show_progress=False, - queue=False, - ).then( - fn=chat_fn, - inputs=[ - prompt_prefix, - chatbot, - model, - device, - precision, - download_vmfb, - config_file, - streaming_llm, - ], - outputs=[chatbot, tokens_time], - show_progress=False, - queue=True, - ) + # submit_event = msg.submit( + # fn=user, + # inputs=[msg, chatbot], + # outputs=[msg, chatbot], + # show_progress=False, + # queue=False, + # ).then( + # fn=chat_fn, + # inputs=[ + # prompt_prefix, + # chatbot, + # model, + # device, + # precision, + # download_vmfb, + # config_file, + # streaming_llm, + # ], + # outputs=[chatbot, tokens_time], + # show_progress=False, + # queue=True, + # ) submit_click_event = submit.click( fn=user, inputs=[msg, chatbot], @@ -206,7 +206,7 @@ def view_json_file(file_obj): fn=None, inputs=None, outputs=None, - 
cancels=[submit_event, submit_click_event], + cancels=[submit_click_event], #[submit_event, submit_click_event], queue=False, ) clear.click(lambda: None, None, [chatbot], queue=False) From 3ac56adb63a567da8171716308a193f20ab93857 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 19:51:32 -0600 Subject: [PATCH 10/23] fix streaming --- apps/shark_studio/api/llm.py | 64 +++++++++++++--------------- apps/shark_studio/web/ui/chat.py | 72 +++++++++++++++++++------------- 2 files changed, 74 insertions(+), 62 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index a867dc8b21..69057f0d2d 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -33,6 +33,9 @@ B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "", "" +DEFAULT_CHAT_SYS_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n +""" def append_user_prompt(history, input_prompt): user_prompt = f"{B_INST} {input_prompt} {E_INST}" @@ -88,6 +91,7 @@ def __init__( self.use_system_prompt = use_system_prompt self.global_iter = 0 self.prev_token_len = 0 + self.first_input=True if self.external_weight_file is not None: if not os.path.exists(self.external_weight_file): print( @@ -193,7 +197,6 @@ def compile(self) -> None: # TODO: delete the temp file def sanitize_prompt(self, prompt): - print(prompt) if isinstance(prompt, list): prompt = list(chain.from_iterable(prompt)) prompt = " ".join([x for x in prompt if isinstance(x, str)]) @@ -201,70 +204,63 @@ def sanitize_prompt(self, prompt): prompt = prompt.replace("\t", " ") prompt = prompt.replace("\r", " ") if self.use_system_prompt and self.global_iter == 0: - prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt - prompt += " [/INST]" + prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt) print(prompt) return prompt def chat(self, prompt): prompt = self.sanitize_prompt(prompt) + print(f"sanitized: {prompt}") input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids - if self.streaming_llm: - token_slice = max(self.prev_token_len - 1, 0) - input_tensor = input_tensor[:, token_slice:] - def format_out(results): return torch.tensor(results.to_host()[0][0]) history = [] for iter in range(self.max_tokens): + if self.streaming_llm: + token_slice = max(self.prev_token_len - 1, 0) + input_tensor = input_tensor[:, token_slice:] if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") self.model["evict_kvcache_space"]() - st_time = time.time() token_len = input_tensor.shape[-1] - if iter == 0 and not self.streaming_llm: - device_inputs = [ + device_inputs = [ ireert.asdevicearray( self.runner.config.device, input_tensor ) ] + if self.first_input or not self.streaming_llm: + st_time = time.time() token = self.model["run_initialize"](*device_inputs) + total_time = time.time() - st_time token_len += 1 - elif iter == 0: - device_inputs = [ - ireert.asdevicearray( - self.runner.config.device, input_tensor - ) - ] + self.first_input=False + else: + st_time = time.time() token = self.model["run_cached_initialize"](*device_inputs) + total_time = time.time() - st_time token_len += 1 - else: - if self.streaming_llm and self.model["get_seq_step"]() > 600: - print("Evicting cache space!") - self.model["evict_kvcache_space"]() - device_inputs 
= [ - ireert.asdevicearray( - self.runner.config.device, - token, - ) - ] - token = self.model["run_forward"](*device_inputs) - - total_time = time.time() - st_time - history.append(format_out(token)) - self.prev_token_len = token_len + len(history) - yield self.tokenizer.decode(history, skip_special_tokens=True), total_time - + + + while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: + dec_time=time.time() + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token = self.model["run_forward"](token) + history.append(format_out(token)) + total_time = time.time() - dec_time + self.prev_token_len = token_len + len(history) + yield self.tokenizer.decode(history), total_time if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) - result_output = self.tokenizer.decode(history, skip_special_tokens=True) + result_output = append_bot_prompt(history, self.tokenizer.decode(history)) self.global_iter += 1 return result_output, total_time diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index b1d12aaec7..49127b0af2 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -13,19 +13,17 @@ LanguageModel, ) + def user(message, history): + if message == "": + message = "Hello!" # Append the user's message to the conversation history - #message = f"{B_INST} {message} {E_INST}" return "", history + [[message, ""]] language_model = None -def create_prompt(model_name, history, prompt_prefix): - return "" - - def get_default_config(): return False @@ -45,6 +43,9 @@ def chat_fn( cli=False, ): global language_model + if streaming_llm and prompt_prefix=="Clear": + language_model = None + return "Clearing history...", "" if language_model is None: history[-1][-1] = "Getting the model ready..." 
yield history, "" @@ -158,28 +159,28 @@ def view_json_file(file_obj): json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) - # submit_event = msg.submit( - # fn=user, - # inputs=[msg, chatbot], - # outputs=[msg, chatbot], - # show_progress=False, - # queue=False, - # ).then( - # fn=chat_fn, - # inputs=[ - # prompt_prefix, - # chatbot, - # model, - # device, - # precision, - # download_vmfb, - # config_file, - # streaming_llm, - # ], - # outputs=[chatbot, tokens_time], - # show_progress=False, - # queue=True, - # ) + submit_event = msg.submit( + fn=user, + inputs=[msg, chatbot], + outputs=[msg, chatbot], + show_progress=False, + queue=False, + ).then( + fn=chat_fn, + inputs=[ + prompt_prefix, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + streaming_llm, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ) submit_click_event = submit.click( fn=user, inputs=[msg, chatbot], @@ -209,4 +210,19 @@ def view_json_file(file_obj): cancels=[submit_click_event], #[submit_event, submit_click_event], queue=False, ) - clear.click(lambda: None, None, [chatbot], queue=False) + clear.click( + fn=chat_fn, + inputs=[ + clear, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + streaming_llm, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ).then(lambda: None, None, [chatbot], queue=False) From a099d7611e8a94c34532145c2608a98ca173ef91 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 20:23:25 -0600 Subject: [PATCH 11/23] bugfixes --- apps/shark_studio/api/llm.py | 34 ++++++++++++++------------------ apps/shark_studio/web/ui/chat.py | 7 ++++++- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 69057f0d2d..bbdba683bb 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -32,7 +32,7 @@ } B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "", "" + DEFAULT_CHAT_SYS_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\n <>\n\n """ @@ -42,11 +42,6 @@ def append_user_prompt(history, input_prompt): history += user_prompt return history -def append_bot_prompt(history, input_prompt): - user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" - history += user_prompt - return history - class LanguageModel: def __init__( self, @@ -210,7 +205,6 @@ def sanitize_prompt(self, prompt): def chat(self, prompt): prompt = self.sanitize_prompt(prompt) - print(f"sanitized: {prompt}") input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids @@ -243,24 +237,26 @@ def format_out(results): total_time = time.time() - st_time token_len += 1 - - while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: - dec_time=time.time() - if self.streaming_llm and self.model["get_seq_step"]() > 600: - print("Evicting cache space!") - self.model["evict_kvcache_space"]() - token = self.model["run_forward"](token) - history.append(format_out(token)) - total_time = time.time() - dec_time - self.prev_token_len = token_len + len(history) - yield self.tokenizer.decode(history), total_time + history.append(format_out(token)) + while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: + dec_time=time.time() + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token = self.model["run_forward"](token) + history.append(format_out(token)) + total_time = time.time() - dec_time + yield self.tokenizer.decode(history), total_time + + self.prev_token_len = token_len + len(history) + if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) - result_output = append_bot_prompt(history, self.tokenizer.decode(history)) + result_output = self.tokenizer.decode(history) self.global_iter += 1 return result_output, total_time diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 49127b0af2..88c1db4a88 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -13,6 +13,7 @@ LanguageModel, ) +B_SYS, E_SYS = "", "" def user(message, history): if message == "": @@ -20,6 +21,10 @@ def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] +def append_bot_prompt(history, input_prompt): + user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" + history += user_prompt + return history language_model = None @@ -65,7 +70,7 @@ def chat_fn( prefill_time = 0 is_first = True for text, exec_time in language_model.chat(history): - history[-1][-1] = text + history[-1][-1] = f"{text}{E_SYS} {E_SYS}" if is_first: prefill_time = exec_time is_first = False From 32e712b48e0fc3b30f9f58d0367f1c83be950542 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 20:30:22 -0600 Subject: [PATCH 12/23] Fix stopping --- apps/shark_studio/web/ui/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 88c1db4a88..46e12dfba3 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -212,7 +212,7 @@ def view_json_file(file_obj): fn=None, inputs=None, outputs=None, - cancels=[submit_click_event], #[submit_event, submit_click_event], + cancels=[submit_event, submit_click_event], queue=False, ) clear.click( From ead171512ff99a2650f288c4570da2761eeeaacf Mon Sep 
17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jan 2024 21:49:42 -0600 Subject: [PATCH 13/23] tweaks to prompts --- apps/shark_studio/api/llm.py | 13 +++++++++++-- apps/shark_studio/web/ui/chat.py | 7 +++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index bbdba683bb..8a45ed1873 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -32,11 +32,17 @@ } B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "", "" DEFAULT_CHAT_SYS_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n """ +def append_bot_prompt(history, input_prompt): + user_prompt = f" {input_prompt} {E_SYS}" + history += user_prompt + return history + def append_user_prompt(history, input_prompt): user_prompt = f"{B_INST} {input_prompt} {E_INST}" history += user_prompt @@ -200,8 +206,11 @@ def sanitize_prompt(self, prompt): prompt = prompt.replace("\r", " ") if self.use_system_prompt and self.global_iter == 0: prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt) - print(prompt) - return prompt + print(prompt) + return prompt + else: + print(prompt) + return f"{B_INST} {prompt} {E_INST}" def chat(self, prompt): prompt = self.sanitize_prompt(prompt) diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 46e12dfba3..fdeb897af3 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -16,13 +16,11 @@ B_SYS, E_SYS = "", "" def user(message, history): - if message == "": - message = "Hello!" # Append the user's message to the conversation history return "", history + [[message, ""]] def append_bot_prompt(history, input_prompt): - user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" + user_prompt = f"{input_prompt} {E_SYS} {E_SYS}" history += user_prompt return history @@ -48,6 +46,7 @@ def chat_fn( cli=False, ): global language_model + print("Prompt prefix: ", prompt_prefix) if streaming_llm and prompt_prefix=="Clear": language_model = None return "Clearing history...", "" @@ -70,7 +69,7 @@ def chat_fn( prefill_time = 0 is_first = True for text, exec_time in language_model.chat(history): - history[-1][-1] = f"{text}{E_SYS} {E_SYS}" + history[-1][-1] = f"{text}{E_SYS}" if is_first: prefill_time = exec_time is_first = False From c195dd6cf1edfe799a8f07c0f9980d2fd3392d0e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 11:35:28 -0600 Subject: [PATCH 14/23] Update CPU path and llm api test. 
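Roughly, the device handling this patch settles on can be distilled into a small standalone helper. The sketch below is illustrative rather than a copy of LanguageModel.__init__: the example device strings are made up, but the parsing rules mirror the code in this diff, where a UI selection of the form "<device name> => <driver>://<index>" keeps only the right-hand side, any CPU selection collapses to the local-task runtime driver with the llvm-cpu backend, and the default precision follows from the chosen backend.

    def parse_ui_device(ui_device: str) -> tuple[str, str, str]:
        """Map a Studio UI device string to (runtime device, compiler backend, precision)."""
        if "cpu" in ui_device:
            # CPU selections run through the local-task HAL driver and the llvm-cpu backend.
            device, driver = "local-task", "llvm-cpu"
        else:
            # UI strings look like "<device name> => <driver>://<index>"; keep the target part.
            device = ui_device.split("=>")[-1].strip()
            driver = device.split("://")[0]
        precision = "f32" if "cpu" in driver else "f16"
        return device, driver, precision

    # Hypothetical inputs, for illustration only:
    print(parse_ui_device("cpu-task"))                      # ('local-task', 'llvm-cpu', 'f32')
    print(parse_ui_device("Radeon RX 7900 => vulkan://0"))  # ('vulkan://0', 'vulkan', 'f16')
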
--- apps/shark_studio/api/llm.py | 20 +++++++++----------- apps/shark_studio/tests/api_test.py | 8 +++++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 8a45ed1873..28f6c4f125 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -62,31 +62,30 @@ def __init__( ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.device = device.split("=>")[-1].strip() - self.driver = self.device.split("://")[0] - print(f"Selected {self.driver} as device driver") + self.device = device.split("=>")[-1].strip() if "cpu" not in device else "local-task" + self.driver = self.device.split("://")[0] if not any(x in self.device for x in ["cpu", "local-task"]) else "llvm-cpu" + print(f"Selected {self.driver} as IREE target backend.") self.precision = "f32" if "cpu" in self.driver else "f16" self.quantization = quantization + self.safe_name = self.hf_model_name.replace("/","_").replace("-", "_") #TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 self.file_spec = "_".join([ - "llama2", - "streaming" if streaming_llm else "chat", + self.safe_name, self.precision, self.quantization, ]) + if streaming_llm: + self.file_spec += "_streaming" self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") #TODO: Tag vmfb with target triple of device instead of HAL backend self.vmfb_name = get_resource_path(f"{self.file_spec}_{self.driver}.vmfb.tempfile") - self.safe_name = self.hf_model_name.split("/")[-1].replace("-", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None self.external_weight_file = None self.streaming_llm = streaming_llm if external_weights is not None: self.external_weight_file = get_resource_path( - self.safe_name - + "_" + self.precision - + "_" + self.quantization + self.file_spec + "." + external_weights ) self.use_system_prompt = use_system_prompt @@ -113,7 +112,7 @@ def __init__( external_weights is None or os.path.exists(str(self.external_weight_file)) ): self.runner = vmfbRunner( - device = self.driver, + device = self.device, vmfb_path=self.vmfb_name, external_weight_path=self.external_weight_file, ) @@ -132,7 +131,6 @@ def __init__( hf_auth_token, compile_to="torch", external_weights=external_weights, - external_weight_file=self.external_weight_file, precision=self.precision, quantization=self.quantization, streaming_llm=self.streaming_llm, diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index c88a1e70cb..158d02ca7f 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -14,8 +14,10 @@ def testLLMSimple(self): lm = LanguageModel( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, - device="cpu-task", + device="local-task", external_weights="safetensors", + precision="fp32", + quantization="int4" ) count = 0 for msg, _ in lm.chat("hi, what are you?"): @@ -24,8 +26,8 @@ def testLLMSimple(self): count += 1 continue assert ( - msg.strip(" ") == "Hello" - ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + msg.strip(" ") == "Hello!" + ), f"LLM API failed to return correct response, expected 'Hello!', received {msg}" break From a788e12fc82726e5024735f93aac258b565d22ce Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 11:38:02 -0600 Subject: [PATCH 15/23] Fix formatting. 
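Alongside the mechanical reformat, it is worth summarizing the generation flow that llm.py has converged on by this point in the series. The generator below is a condensed, illustrative rendering of LanguageModel.chat(), not a drop-in replacement: the exported entry points (run_initialize, run_cached_initialize, run_forward, get_seq_step, evict_kvcache_space) and the 600-step cache watermark come from the code in this series, while the helper's name and argument list are invented for the sketch, and timing bookkeeping, prompt trimming, and the max-token cap are omitted.

    import iree.runtime as ireert

    def generate_tokens(model, runner, tokenizer, input_ids, streaming, first_input, stop_token):
        # Prefill: a cold start runs the full prefill; in streaming mode, later turns
        # reuse the cached KV state through run_cached_initialize.
        inputs = [ireert.asdevicearray(runner.config.device, input_ids)]
        if first_input or not streaming:
            token = model["run_initialize"](*inputs)
        else:
            token = model["run_cached_initialize"](*inputs)
        history = [int(token.to_host()[0][0])]

        # Decode one token at a time until the model emits the stop token, evicting
        # old KV-cache entries whenever the tracked sequence step grows past 600.
        while history[-1] != stop_token:
            if streaming and model["get_seq_step"]() > 600:
                model["evict_kvcache_space"]()
            token = model["run_forward"](token)
            history.append(int(token.to_host()[0][0]))
            yield tokenizer.decode(history)
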
--- apps/shark_studio/api/llm.py | 72 ++++++++++++++++------------- apps/shark_studio/api/utils.py | 6 ++- apps/shark_studio/tests/api_test.py | 2 +- apps/shark_studio/web/ui/chat.py | 5 +- 4 files changed, 48 insertions(+), 37 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 28f6c4f125..9618951c95 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -38,16 +38,19 @@ Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n """ + def append_bot_prompt(history, input_prompt): user_prompt = f" {input_prompt} {E_SYS}" history += user_prompt return history + def append_user_prompt(history, input_prompt): user_prompt = f"{B_INST} {input_prompt} {E_INST}" history += user_prompt return history + class LanguageModel: def __init__( self, @@ -62,36 +65,45 @@ def __init__( ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.device = device.split("=>")[-1].strip() if "cpu" not in device else "local-task" - self.driver = self.device.split("://")[0] if not any(x in self.device for x in ["cpu", "local-task"]) else "llvm-cpu" + self.device = ( + device.split("=>")[-1].strip() if "cpu" not in device else "local-task" + ) + self.driver = ( + self.device.split("://")[0] + if not any(x in self.device for x in ["cpu", "local-task"]) + else "llvm-cpu" + ) print(f"Selected {self.driver} as IREE target backend.") self.precision = "f32" if "cpu" in self.driver else "f16" self.quantization = quantization - self.safe_name = self.hf_model_name.replace("/","_").replace("-", "_") - #TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 - self.file_spec = "_".join([ - self.safe_name, - self.precision, - self.quantization, - ]) + self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_") + # TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 + self.file_spec = "_".join( + [ + self.safe_name, + self.precision, + self.quantization, + ] + ) if streaming_llm: self.file_spec += "_streaming" self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") - #TODO: Tag vmfb with target triple of device instead of HAL backend - self.vmfb_name = get_resource_path(f"{self.file_spec}_{self.driver}.vmfb.tempfile") + # TODO: Tag vmfb with target triple of device instead of HAL backend + self.vmfb_name = get_resource_path( + f"{self.file_spec}_{self.driver}.vmfb.tempfile" + ) self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None self.external_weight_file = None self.streaming_llm = streaming_llm if external_weights is not None: self.external_weight_file = get_resource_path( - self.file_spec - + "." + external_weights + self.file_spec + "." 
+ external_weights ) self.use_system_prompt = use_system_prompt self.global_iter = 0 self.prev_token_len = 0 - self.first_input=True + self.first_input = True if self.external_weight_file is not None: if not os.path.exists(self.external_weight_file): print( @@ -112,7 +124,7 @@ def __init__( external_weights is None or os.path.exists(str(self.external_weight_file)) ): self.runner = vmfbRunner( - device = self.device, + device=self.device, vmfb_path=self.vmfb_name, external_weight_path=self.external_weight_file, ) @@ -163,16 +175,12 @@ def compile(self) -> None: if "cpu" in self.driver: flags.extend( [ - "--iree-global-opt-enable-quantized-matmul-reassociation", - "--iree-llvmcpu-enable-ukernels=all" + "--iree-global-opt-enable-quantized-matmul-reassociation", + "--iree-llvmcpu-enable-ukernels=all", ] ) elif self.driver == "vulkan": - flags.extend( - [ - "--iree-stream-resource-max-allocation-size=4294967296" - ] - ) + flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"]) self.iree_module_dict = get_iree_compiled_module( self.tempfile_name, device=self.device, @@ -185,10 +193,10 @@ def compile(self) -> None: del self.iree_module_dict gc.collect() self.runner = vmfbRunner( - device = self.driver, - vmfb_path=self.vmfb_name, - external_weight_path=self.external_weight_file, - ) + device=self.driver, + vmfb_path=self.vmfb_name, + external_weight_path=self.external_weight_file, + ) if self.streaming_llm: self.model = self.runner.ctx.modules.streaming_state_update else: @@ -228,25 +236,23 @@ def format_out(results): self.model["evict_kvcache_space"]() token_len = input_tensor.shape[-1] device_inputs = [ - ireert.asdevicearray( - self.runner.config.device, input_tensor - ) - ] + ireert.asdevicearray(self.runner.config.device, input_tensor) + ] if self.first_input or not self.streaming_llm: st_time = time.time() token = self.model["run_initialize"](*device_inputs) total_time = time.time() - st_time token_len += 1 - self.first_input=False + self.first_input = False else: st_time = time.time() token = self.model["run_cached_initialize"](*device_inputs) total_time = time.time() - st_time token_len += 1 - + history.append(format_out(token)) while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: - dec_time=time.time() + dec_time = time.time() if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") self.model["evict_kvcache_space"]() diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 27a622f608..7a6e9bb4b7 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,7 +8,8 @@ ) from pathlib import Path -#from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + +# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from cpuinfo import get_cpu_info # TODO: migrate these utils to studio @@ -77,6 +78,7 @@ def get_devices_by_name(driver_name): available_devices.extend(cpu_device) return available_devices + def set_iree_runtime_flags(): # TODO: This function should be device-agnostic and piped properly # to general runtime driver init. 
@@ -136,7 +138,7 @@ def get_output_value(dev_dict): # mapping with full path device_map[f"{driver}://{device['path']}"] = get_output_value(device) return device_map - + def map_device_to_name_path(device, key_combination=3): """Gives the appropriate device data (supported name/path) for user diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 158d02ca7f..27e8c3c26e 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -17,7 +17,7 @@ def testLLMSimple(self): device="local-task", external_weights="safetensors", precision="fp32", - quantization="int4" + quantization="int4", ) count = 0 for msg, _ in lm.chat("hi, what are you?"): diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index fdeb897af3..db5b40a469 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -15,15 +15,18 @@ B_SYS, E_SYS = "", "" + def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] + def append_bot_prompt(history, input_prompt): user_prompt = f"{input_prompt} {E_SYS} {E_SYS}" history += user_prompt return history + language_model = None @@ -47,7 +50,7 @@ def chat_fn( ): global language_model print("Prompt prefix: ", prompt_prefix) - if streaming_llm and prompt_prefix=="Clear": + if streaming_llm and prompt_prefix == "Clear": language_model = None return "Clearing history...", "" if language_model is None: From c48c5624df1fa260049d13ce8e0840940deb874b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 11:44:16 -0600 Subject: [PATCH 16/23] Remove unused function. --- apps/shark_studio/api/llm.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 9618951c95..2273135c5c 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -38,13 +38,6 @@ Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n """ - -def append_bot_prompt(history, input_prompt): - user_prompt = f" {input_prompt} {E_SYS}" - history += user_prompt - return history - - def append_user_prompt(history, input_prompt): user_prompt = f"{B_INST} {input_prompt} {E_INST}" history += user_prompt From 643214da49b37ecf7a57bf1574baaffd998ce64a Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 12:02:39 -0600 Subject: [PATCH 17/23] Change device in test to cpu. 
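
The LanguageModel constructor already maps a plain "cpu" device string to
the local-task driver, so the API test can pass the simpler name instead of
"local-task". A minimal usage sketch mirroring the updated test follows;
the model name, prompt and keyword arguments come from the existing test,
and the loop variable names are illustrative.

    # Minimal usage sketch mirroring apps/shark_studio/tests/api_test.py.
    from apps.shark_studio.api.llm import LanguageModel

    lm = LanguageModel(
        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        hf_auth_token=None,
        device="cpu",  # normalized to the local-task driver internally
        external_weights="safetensors",
        precision="fp32",
        quantization="int4",
    )
    for message, step_time in lm.chat("hi, what are you?"):
        # chat() yields (decoded text so far, seconds for the last step).
        print(message, step_time)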
--- apps/shark_studio/tests/api_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 27e8c3c26e..4eb2f37161 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -14,7 +14,7 @@ def testLLMSimple(self): lm = LanguageModel( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, - device="local-task", + device="cpu", external_weights="safetensors", precision="fp32", quantization="int4", From dccd5857d87bea33410be89a9bec6156cdbbdacf Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 18:33:17 -0600 Subject: [PATCH 18/23] Fixes to runner, device names, vmfb mgmt --- apps/shark_studio/api/llm.py | 84 +++++++++++++++++------------ apps/shark_studio/tests/api_test.py | 30 ++++++++++- 2 files changed, 78 insertions(+), 36 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 2273135c5c..38c1b65c43 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -2,10 +2,7 @@ from turbine_models.model_runner import vmfbRunner from turbine_models.gen_external_params.gen_external_params import gen_external_params import time -from shark.iree_utils.compile_utils import ( - get_iree_compiled_module, - load_vmfb_using_mmap, -) +from shark.iree_utils.compile_utils import compile_module_to_flatbuffer from apps.shark_studio.web.utils import get_resource_path import iree.runtime as ireert from itertools import chain @@ -18,6 +15,7 @@ "llama2_7b": { "initializer": stateless_llama.export_transformer_model, "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", + "compile_flags": ["--iree-opt-const-expr-hoisting=False"], "stop_token": 2, "max_tokens": 4096, "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", @@ -25,6 +23,23 @@ "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { "initializer": stateless_llama.export_transformer_model, "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "compile_flags": ["--iree-opt-const-expr-hoisting=False"], + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "TinyPixel/small-llama2": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "TinyPixel/small-llama2", + "compile_flags": ["--iree-opt-const-expr-hoisting=True"], + "stop_token": 2, + "max_tokens": 1024, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "anushehchaudry/llama-2-tiny-random": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "anushehchaudry/llama-2-tiny-random", + "compile_flags": ["--iree-opt-const-expr-hoisting=True"], "stop_token": 2, "max_tokens": 4096, "system_prompt": """[INST] <>Be concise. 
You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", @@ -38,6 +53,7 @@ Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n """ + def append_user_prompt(history, input_prompt): user_prompt = f"{B_INST} {input_prompt} {E_INST}" history += user_prompt @@ -56,43 +72,47 @@ def __init__( use_system_prompt=True, streaming_llm=False, ): - print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.device = ( - device.split("=>")[-1].strip() if "cpu" not in device else "local-task" - ) - self.driver = ( - self.device.split("://")[0] - if not any(x in self.device for x in ["cpu", "local-task"]) - else "llvm-cpu" - ) - print(f"Selected {self.driver} as IREE target backend.") - self.precision = "f32" if "cpu" in self.driver else "f16" + self.device = device.split("=>")[-1].strip() + self.backend = self.device.split("://")[0] + self.driver = self.backend + if "cpu" in device: + self.device = "cpu" + self.backend = "llvm-cpu" + self.driver = "local-task" + + print(f"Selected {self.backend} as IREE target backend.") + self.precision = "f32" if "cpu" in device else "f16" self.quantization = quantization self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_") + self.external_weight_file = None # TODO: find a programmatic solution for model arch spec instead of hardcoding llama2 self.file_spec = "_".join( [ self.safe_name, self.precision, - self.quantization, ] ) + if self.quantization != "None": + self.file_spec += "_" + self.quantization + + if external_weights is not None: + self.external_weight_file = get_resource_path( + self.file_spec + "." + external_weights + ) + if streaming_llm: + # Add streaming suffix to file spec after setting external weights filename. self.file_spec += "_streaming" + self.streaming_llm = streaming_llm + self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") # TODO: Tag vmfb with target triple of device instead of HAL backend self.vmfb_name = get_resource_path( - f"{self.file_spec}_{self.driver}.vmfb.tempfile" + f"{self.file_spec}_{self.backend}.vmfb.tempfile" ) self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None - self.external_weight_file = None - self.streaming_llm = streaming_llm - if external_weights is not None: - self.external_weight_file = get_resource_path( - self.file_spec + "." 
+ external_weights - ) self.use_system_prompt = use_system_prompt self.global_iter = 0 self.prev_token_len = 0 @@ -117,7 +137,7 @@ def __init__( external_weights is None or os.path.exists(str(self.external_weight_file)) ): self.runner = vmfbRunner( - device=self.device, + device=self.driver, vmfb_path=self.vmfb_name, external_weight_path=self.external_weight_file, ) @@ -163,28 +183,25 @@ def compile(self) -> None: "--iree-llvmcpu-target-triple=x86_64-linux-gnu", "--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64", - "--iree-opt-const-expr-hoisting=False", ] - if "cpu" in self.driver: + if "cpu" in self.backend: flags.extend( [ "--iree-global-opt-enable-quantized-matmul-reassociation", "--iree-llvmcpu-enable-ukernels=all", ] ) - elif self.driver == "vulkan": + elif self.backend == "vulkan": flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"]) - self.iree_module_dict = get_iree_compiled_module( + flags.extend(llm_model_map[self.hf_model_name]["compile_flags"]) + flatbuffer_blob = compile_module_to_flatbuffer( self.tempfile_name, device=self.device, - mmap=True, frontend="torch", - external_weight_file=self.external_weight_file, - write_to=self.vmfb_name, + model_config_path=None, extra_args=flags, + write_to=self.vmfb_name, ) - del self.iree_module_dict - gc.collect() self.runner = vmfbRunner( device=self.driver, vmfb_path=self.vmfb_name, @@ -194,7 +211,6 @@ def compile(self) -> None: self.model = self.runner.ctx.modules.streaming_state_update else: self.model = self.runner.ctx.modules.state_update - # TODO: delete the temp file def sanitize_prompt(self, prompt): if isinstance(prompt, list): diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 4eb2f37161..cf6813c28d 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -7,20 +7,44 @@ import logging import unittest from apps.shark_studio.api.llm import LanguageModel +import gc class LLMAPITest(unittest.TestCase): - def testLLMSimple(self): + def test01_LLMSmall(self): lm = LanguageModel( + "TinyPixel/small-llama2", + hf_auth_token=None, + device="cpu", + external_weights="safetensors", + precision="fp32", + quantization="None", + ) + count = 0 + for msg, _ in lm.chat("hi, what are you?"): + # skip first token output + if count == 0: + count += 1 + continue + assert ( + msg.strip(" ") == "Turkish Turkish Turkish" + ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}" + break + del lm + gc.collect() + + def test02_stream(self): + llama2 = LanguageModel( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, device="cpu", external_weights="safetensors", precision="fp32", quantization="int4", + streaming_llm=True, ) count = 0 - for msg, _ in lm.chat("hi, what are you?"): + for msg, _ in llama2.chat("hi, what are you?"): # skip first token output if count == 0: count += 1 @@ -29,6 +53,8 @@ def testLLMSimple(self): msg.strip(" ") == "Hello!" ), f"LLM API failed to return correct response, expected 'Hello!', received {msg}" break + del llama2 + gc.collect() if __name__ == "__main__": From 1ed5543746090f4ceb9b30e9b457e9740cb9b8cf Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 16 Jan 2024 18:45:12 -0600 Subject: [PATCH 19/23] Only run small tests. 
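
Drop the Llama-2-7b int4 streaming test added in the previous commit and
keep only the TinyPixel/small-llama2 smoke test; the 7b path downloads,
quantizes and compiles a full checkpoint, which is presumably too heavy for
routine CI runs. The streaming behavior it exercised is still guarded at
runtime in LanguageModel.chat(); a condensed sketch of that guard is below,
using the module accessors from llm.py. The helper name and the explicit
threshold parameter are illustrative; the value 600 matches the hard-coded
check in the chat loop.

    # Condensed sketch of the KV-cache eviction guard in LanguageModel.chat().
    # "model" is the loaded streaming_state_update vmfb module, as in llm.py.
    def maybe_evict_kvcache(model, threshold=600):
        if model["get_seq_step"]() > threshold:
            print("Evicting cache space!")
            model["evict_kvcache_space"]()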
--- apps/shark_studio/tests/api_test.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index cf6813c28d..d341826b62 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -33,29 +33,6 @@ def test01_LLMSmall(self): del lm gc.collect() - def test02_stream(self): - llama2 = LanguageModel( - "Trelis/Llama-2-7b-chat-hf-function-calling-v2", - hf_auth_token=None, - device="cpu", - external_weights="safetensors", - precision="fp32", - quantization="int4", - streaming_llm=True, - ) - count = 0 - for msg, _ in llama2.chat("hi, what are you?"): - # skip first token output - if count == 0: - count += 1 - continue - assert ( - msg.strip(" ") == "Hello!" - ), f"LLM API failed to return correct response, expected 'Hello!', received {msg}" - break - del llama2 - gc.collect() - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) From 9ff8fa71c3ee0943563e749c03c27996a6502c1c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 17 Jan 2024 11:03:57 -0600 Subject: [PATCH 20/23] Use small test without external weights. --- apps/shark_studio/tests/api_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index d341826b62..203edd0821 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -16,7 +16,6 @@ def test01_LLMSmall(self): "TinyPixel/small-llama2", hf_auth_token=None, device="cpu", - external_weights="safetensors", precision="fp32", quantization="None", ) From ed780b058c8d60e96d3535aaa56aac69f8cd8120 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:15:41 -0600 Subject: [PATCH 21/23] remove redundant compile flags --- apps/shark_studio/api/llm.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 38c1b65c43..b925eb8470 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -36,14 +36,6 @@ "max_tokens": 1024, "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", }, - "anushehchaudry/llama-2-tiny-random": { - "initializer": stateless_llama.export_transformer_model, - "hf_model_name": "anushehchaudry/llama-2-tiny-random", - "compile_flags": ["--iree-opt-const-expr-hoisting=True"], - "stop_token": 2, - "max_tokens": 4096, - "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", - }, } B_INST, E_INST = "[INST]", "[/INST]" @@ -175,20 +167,12 @@ def __init__( def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". 
- flags = [ - "--iree-input-type=torch", - "--mlir-print-debuginfo", - "--mlir-print-op-on-diagnostic=false", - "--iree-llvmcpu-target-cpu-features=host", - "--iree-llvmcpu-target-triple=x86_64-linux-gnu", - "--iree-stream-resource-index-bits=64", - "--iree-vm-target-index-bits=64", - ] + # ONLY architecture/api-specific compile-time flags for each backend, if needed. + # hf_model_id-specific global flags currently in model map. if "cpu" in self.backend: flags.extend( [ "--iree-global-opt-enable-quantized-matmul-reassociation", - "--iree-llvmcpu-enable-ukernels=all", ] ) elif self.backend == "vulkan": From 29c40f435cea7a4970984bd0990e4cde8fc7ccac Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:17:10 -0600 Subject: [PATCH 22/23] remove print of prompt prefix --- apps/shark_studio/web/ui/chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index db5b40a469..6e10cfaf6b 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -49,7 +49,6 @@ def chat_fn( cli=False, ): global language_model - print("Prompt prefix: ", prompt_prefix) if streaming_llm and prompt_prefix == "Clear": language_model = None return "Clearing history...", "" From 0c6e5cad9a018f5f238037c6e3c8042122024eb1 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 17 Jan 2024 13:21:03 -0600 Subject: [PATCH 23/23] Fix flags. --- apps/shark_studio/api/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index b925eb8470..a9d39f8e7b 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -169,6 +169,7 @@ def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". # ONLY architecture/api-specific compile-time flags for each backend, if needed. # hf_model_id-specific global flags currently in model map. + flags = [] if "cpu" in self.backend: flags.extend( [