From 25312cd791b78a426f55dae3ecc62c0e13c055c9 Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Thu, 18 Jan 2024 19:01:07 -0600
Subject: [PATCH] Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM
* Update precision and add gpu support
* (studio2) Separate weights generation for quantization support
* Adapt prompt changes to studio flow
* Remove outdated flag from llm compile flags.
* (studio2) use turbine vmfbRunner
* tweaks to prompts
* Update CPU path and llm api test.
* Change device in test to cpu.
* Fixes to runner, device names, vmfb mgmt
* Use small test without external weights.
---
 apps/shark_studio/api/llm.py        | 210 +++++++++++++++++++++-------
 apps/shark_studio/tests/api_test.py |  17 ++-
 apps/shark_studio/web/ui/chat.py    | 142 ++++++-------------
 apps/shark_studio/web/utils.py      |  12 ++
 4 files changed, 218 insertions(+), 163 deletions(-)
 create mode 100644 apps/shark_studio/web/utils.py

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 852f5eff58..a9d39f8e7b 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -1,10 +1,9 @@
 from turbine_models.custom_models import stateless_llama
+from turbine_models.model_runner import vmfbRunner
+from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
-from shark.iree_utils.compile_utils import (
-    get_iree_compiled_module,
-    load_vmfb_using_mmap,
-)
-from apps.shark_studio.web.utils.file_utils import get_resource_path
+from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
+from apps.shark_studio.web.utils import get_resource_path
 import iree.runtime as ireert
 from itertools import chain
 import gc
@@ -16,6 +15,7 @@
     "llama2_7b": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
         "max_tokens": 4096,
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
     },
@@ -23,12 +23,34 @@
     "Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
         "initializer": stateless_llama.export_transformer_model,
         "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=False"],
         "stop_token": 2,
         "max_tokens": 4096,
         "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
     },
+    "TinyPixel/small-llama2": {
+        "initializer": stateless_llama.export_transformer_model,
+        "hf_model_name": "TinyPixel/small-llama2",
+        "compile_flags": ["--iree-opt-const-expr-hoisting=True"],
+        "stop_token": 2,
+        "max_tokens": 1024,
+        "system_prompt": """[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+<</SYS>>""",
+    },
 }
 
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<s>", "</s>"
+
+DEFAULT_CHAT_SYS_PROMPT = """[INST] <<SYS>>
+Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
+"""
+
+
+def append_user_prompt(history, input_prompt):
+    user_prompt = f"{B_INST} {input_prompt} {E_INST}"
+    history += user_prompt
+    return history
+
 
 class LanguageModel:
     def __init__(
@@ -36,41 +58,85 @@ def __init__(
         model_name,
         hf_auth_token=None,
         device=None,
-        precision="fp32",
+        quantization="int4",
+        precision="",
         external_weights=None,
         use_system_prompt=True,
+        streaming_llm=False,
     ):
-        print(llm_model_map[model_name])
         self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
-        self.tempfile_name = get_resource_path("llm.torch.tempfile")
-        self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
-        self.device = device
-        self.precision = precision
-        self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
-        self.max_tokens = llm_model_map[model_name]["max_tokens"]
-        self.iree_module_dict = None
+        self.device = device.split("=>")[-1].strip()
+        self.backend = self.device.split("://")[0]
+        self.driver = self.backend
+        if "cpu" in device:
+            self.device = "cpu"
+            self.backend = "llvm-cpu"
+            self.driver = "local-task"
+
+        print(f"Selected {self.backend} as IREE target backend.")
+        self.precision = "f32" if "cpu" in device else "f16"
+        self.quantization = quantization
+        self.safe_name = self.hf_model_name.replace("/", "_").replace("-", "_")
         self.external_weight_file = None
+        # TODO: find a programmatic solution for model arch spec instead of hardcoding llama2
+        self.file_spec = "_".join(
+            [
+                self.safe_name,
+                self.precision,
+            ]
+        )
+        if self.quantization != "None":
+            self.file_spec += "_" + self.quantization
+
         if external_weights is not None:
             self.external_weight_file = get_resource_path(
-                self.safe_name + "." + external_weights
+                self.file_spec + "." + external_weights
            )
+
+        if streaming_llm:
+            # Add streaming suffix to file spec after setting external weights filename.
+            self.file_spec += "_streaming"
+        self.streaming_llm = streaming_llm
+
+        self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile")
+        # TODO: Tag vmfb with target triple of device instead of HAL backend
+        self.vmfb_name = get_resource_path(
+            f"{self.file_spec}_{self.backend}.vmfb.tempfile"
+        )
+        self.max_tokens = llm_model_map[model_name]["max_tokens"]
+        self.iree_module_dict = None
         self.use_system_prompt = use_system_prompt
         self.global_iter = 0
+        self.prev_token_len = 0
+        self.first_input = True
+        if self.external_weight_file is not None:
+            if not os.path.exists(self.external_weight_file):
+                print(
+                    f"External weight file {self.external_weight_file} does not exist. Generating..."
+ ) + gen_external_params( + hf_model_name=self.hf_model_name, + quantization=self.quantization, + weight_path=self.external_weight_file, + hf_auth_token=hf_auth_token, + precision=self.precision, + ) + else: + print( + f"External weight file {self.external_weight_file} found for {self.vmfb_name}" + ) if os.path.exists(self.vmfb_name) and ( external_weights is None or os.path.exists(str(self.external_weight_file)) ): - self.iree_module_dict = dict() - ( - self.iree_module_dict["vmfb"], - self.iree_module_dict["config"], - self.iree_module_dict["temp_file_to_unlink"], - ) = load_vmfb_using_mmap( - self.vmfb_name, - device, - device_idx=0, - rt_flags=[], - external_weight_file=self.external_weight_file, + self.runner = vmfbRunner( + device=self.driver, + vmfb_path=self.vmfb_name, + external_weight_path=self.external_weight_file, ) + if self.streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update self.tokenizer = AutoTokenizer.from_pretrained( self.hf_model_name, use_fast=False, @@ -82,7 +148,9 @@ def __init__( hf_auth_token, compile_to="torch", external_weights=external_weights, - external_weight_file=self.external_weight_file, + precision=self.precision, + quantization=self.quantization, + streaming_llm=self.streaming_llm, ) with open(self.tempfile_name, "w+") as f: f.write(self.torch_ir) @@ -99,19 +167,37 @@ def __init__( def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". - self.iree_module_dict = get_iree_compiled_module( + # ONLY architecture/api-specific compile-time flags for each backend, if needed. + # hf_model_id-specific global flags currently in model map. + flags = [] + if "cpu" in self.backend: + flags.extend( + [ + "--iree-global-opt-enable-quantized-matmul-reassociation", + ] + ) + elif self.backend == "vulkan": + flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"]) + flags.extend(llm_model_map[self.hf_model_name]["compile_flags"]) + flatbuffer_blob = compile_module_to_flatbuffer( self.tempfile_name, device=self.device, - mmap=True, frontend="torch", - external_weight_file=self.external_weight_file, + model_config_path=None, + extra_args=flags, write_to=self.vmfb_name, - extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"], ) - # TODO: delete the temp file + self.runner = vmfbRunner( + device=self.driver, + vmfb_path=self.vmfb_name, + external_weight_path=self.external_weight_file, + ) + if self.streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update def sanitize_prompt(self, prompt): - print(prompt) if isinstance(prompt, list): prompt = list(chain.from_iterable(prompt)) prompt = " ".join([x for x in prompt if isinstance(x, str)]) @@ -119,10 +205,12 @@ def sanitize_prompt(self, prompt): prompt = prompt.replace("\t", " ") prompt = prompt.replace("\r", " ") if self.use_system_prompt and self.global_iter == 0: - prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt - prompt += " [/INST]" - print(prompt) - return prompt + prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt) + print(prompt) + return prompt + else: + print(prompt) + return f"{B_INST} {prompt} {E_INST}" def chat(self, prompt): prompt = self.sanitize_prompt(prompt) @@ -134,26 +222,40 @@ def format_out(results): history = [] for iter in range(self.max_tokens): - st_time = time.time() - if iter == 0: - device_inputs = [ - ireert.asdevicearray( - 
self.iree_module_dict["config"].device, input_tensor - ) - ] - token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) + if self.streaming_llm: + token_slice = max(self.prev_token_len - 1, 0) + input_tensor = input_tensor[:, token_slice:] + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token_len = input_tensor.shape[-1] + device_inputs = [ + ireert.asdevicearray(self.runner.config.device, input_tensor) + ] + if self.first_input or not self.streaming_llm: + st_time = time.time() + token = self.model["run_initialize"](*device_inputs) + total_time = time.time() - st_time + token_len += 1 + self.first_input = False else: - device_inputs = [ - ireert.asdevicearray( - self.iree_module_dict["config"].device, - token, - ) - ] - token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) + st_time = time.time() + token = self.model["run_cached_initialize"](*device_inputs) + total_time = time.time() - st_time + token_len += 1 - total_time = time.time() - st_time history.append(format_out(token)) - yield self.tokenizer.decode(history), total_time + while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: + dec_time = time.time() + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + token = self.model["run_forward"](token) + history.append(format_out(token)) + total_time = time.time() - dec_time + yield self.tokenizer.decode(history), total_time + + self.prev_token_len = token_len + len(history) if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index f9fa23df4f..bbaa813c06 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -7,6 +7,8 @@ import logging import unittest import json +from apps.shark_studio.api.llm import LanguageModel +import gc from apps.shark_studio.api.llm import LanguageModel from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file @@ -28,12 +30,13 @@ def testSDSimple(self): print(i) class LLMAPITest(unittest.TestCase): - def testLLMSimple(self): + def test01_LLMSmall(self): lm = LanguageModel( - "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "TinyPixel/small-llama2", hf_auth_token=None, - device="cpu-task", - external_weights="safetensors", + device="cpu", + precision="fp32", + quantization="None", ) count = 0 for msg, _ in lm.chat("hi, what are you?"): @@ -42,9 +45,11 @@ def testLLMSimple(self): count += 1 continue assert ( - msg.strip(" ") == "Hello" - ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + msg.strip(" ") == "Turkish Turkish Turkish" + ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}" break + del lm + gc.collect() if __name__ == "__main__": diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 917ac870bf..f34f89bc78 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -11,17 +11,21 @@ ) import apps.shark_studio.web.utils.globals as global_obj +B_SYS, E_SYS = "", "" + def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] -language_model = None +def append_bot_prompt(history, input_prompt): + user_prompt = f"{input_prompt} {E_SYS} {E_SYS}" + history += user_prompt + return history 
-def create_prompt(model_name, history, prompt_prefix): - return "" +language_model = None def get_default_config(): @@ -39,9 +43,13 @@ def chat_fn( precision, download_vmfb, config_file, + streaming_llm, cli=False, ): global language_model + if streaming_llm and prompt_prefix == "Clear": + language_model = None + return "Clearing history...", "" if language_model is None: history[-1][-1] = "Getting the model ready..." yield history, "" @@ -50,8 +58,8 @@ def chat_fn( device=device, precision=precision, external_weights="safetensors", - external_weight_file="llama2_7b.safetensors", use_system_prompt=prompt_prefix, + streaming_llm=streaming_llm, ) history[-1][-1] = "Getting the model ready... Done" yield history, "" @@ -61,7 +69,7 @@ def chat_fn( prefill_time = 0 is_first = True for text, exec_time in language_model.chat(history): - history[-1][-1] = text + history[-1][-1] = f"{text}{E_SYS}" if is_first: prefill_time = exec_time is_first = False @@ -73,101 +81,6 @@ def chat_fn( yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" -def llm_chat_api(InputData: dict): - return None - print(f"Input keys : {InputData.keys()}") - # print(f"model : {InputData['model']}") - is_chat_completion_api = ( - "messages" in InputData.keys() - ) # else it is the legacy `completion` api - # For Debugging input data from API - # if is_chat_completion_api: - # print(f"message -> role : {InputData['messages'][0]['role']}") - # print(f"message -> content : {InputData['messages'][0]['content']}") - # else: - # print(f"prompt : {InputData['prompt']}") - # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now - global vicuna_model - model_name = InputData["model"] if "model" in InputData.keys() else "codegen" - model_path = llm_model_map[model_name] - device = "cpu-task" - precision = "fp16" - max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] - if max_toks is None: - max_toks = 128 if model_name == "codegen" else 512 - - # make it working for codegen first - from apps.language_models.scripts.vicuna import ( - UnshardedVicuna, - ) - - device_id = None - if vicuna_model == 0: - if "cuda" in device: - device = "cuda" - elif "sync" in device: - device = "cpu-sync" - elif "task" in device: - device = "cpu-task" - elif "vulkan" in device: - device_id = int(device.split("://")[1]) - device = "vulkan" - else: - print("unrecognized device") - - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=True, - load_mlir_from_shark_tank=True, - device_id=device_id, - ) - - # TODO: add role dict for different models - if is_chat_completion_api: - # TODO: add funtionality for multiple messages - prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) - else: - prompt = InputData["prompt"] - print("prompt = ", prompt) - - res = vicuna_model.generate(prompt) - res_op = None - for op in res: - res_op = op - - if is_chat_completion_api: - choices = [ - { - "index": 0, - "message": { - "role": "assistant", - "content": res_op, # since we are yeilding the result - }, - "finish_reason": "stop", # or length - } - ] - else: - choices = [ - { - "text": res_op, - "index": 0, - "logprobs": None, - "finish_reason": "stop", # or length - } - ] - end_time = dt.now().strftime("%Y%m%d%H%M%S%f") - return { - "id": end_time, - "object": "chat.completion" if is_chat_completion_api else "text_completion", - "created": int(end_time), 
- "choices": choices, - } - - def view_json_file(file_obj): content = "" with open(file_obj.name, "r") as fopen: @@ -198,7 +111,7 @@ def view_json_file(file_obj): ) precision = gr.Radio( label="Precision", - value="int4", + value="fp32", choices=[ # "int4", # "int8", @@ -211,12 +124,18 @@ def view_json_file(file_obj): with gr.Column(): download_vmfb = gr.Checkbox( label="Download vmfb from Shark tank if available", + value=False, + interactive=True, + visible=False, + ) + streaming_llm = gr.Checkbox( + label="Run in streaming mode (requires recompilation)", value=True, interactive=True, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", - value=False, + value=True, interactive=True, ) @@ -260,6 +179,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, @@ -281,6 +201,7 @@ def view_json_file(file_obj): precision, download_vmfb, config_file, + streaming_llm, ], outputs=[chatbot, tokens_time], show_progress=False, @@ -293,4 +214,19 @@ def view_json_file(file_obj): cancels=[submit_event, submit_click_event], queue=False, ) - clear.click(lambda: None, None, [chatbot], queue=False) + clear.click( + fn=chat_fn, + inputs=[ + clear, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + streaming_llm, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ).then(lambda: None, None, [chatbot], queue=False) diff --git a/apps/shark_studio/web/utils.py b/apps/shark_studio/web/utils.py new file mode 100644 index 0000000000..4072491cbf --- /dev/null +++ b/apps/shark_studio/web/utils.py @@ -0,0 +1,12 @@ +import os +import sys + + +def get_available_devices(): + return ["cpu-task"] + + +def get_resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_path, relative_path)
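
Reviewer note (not part of the diff): the sketch below shows one way the reworked LanguageModel API from this patch might be exercised end to end. The constructor arguments mirror apps/shark_studio/tests/api_test.py and the model map entry added above; the streaming_llm flag and the chat() yield shape are taken from the patch, while the surrounding script scaffolding and printed formatting are illustrative assumptions.

# Illustrative usage sketch for the LanguageModel API introduced in this patch.
# Device string, model choice, and output formatting are assumptions, not part of the diff.
from apps.shark_studio.api.llm import LanguageModel

lm = LanguageModel(
    "TinyPixel/small-llama2",  # small model map entry added by this patch
    hf_auth_token=None,
    device="cpu",              # mapped to the llvm-cpu backend / local-task driver
    precision="fp32",
    quantization="None",       # no external weight generation needed
    streaming_llm=True,        # use the streaming_state_update module
)

# chat() yields (decoded_history, seconds) pairs: the first yield reports
# prefill time, subsequent yields report per-step decode time.
for text, exec_time in lm.chat("hi, what are you?"):
    print(f"{exec_time:.2f}s: {text}")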