From 81d6e059acfdcc534bd33f5cddb89049c40d8f5e Mon Sep 17 00:00:00 2001 From: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:18:16 -0400 Subject: [PATCH] Fix Llama2 on CPU (#2133) --- apps/shark_studio/api/llm.py | 20 ++++++++++++++------ apps/shark_studio/web/ui/chat.py | 2 ++ requirements.txt | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index a88aaa9b02..6ee80ae49e 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -13,7 +13,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM llm_model_map = { - "llama2_7b": { + "meta-llama/Llama-2-7b-chat-hf": { "initializer": stateless_llama.export_transformer_model, "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "compile_flags": ["--iree-opt-const-expr-hoisting=False"], @@ -258,7 +258,8 @@ def format_out(results): history.append(format_out(token)) while ( - format_out(token) != llm_model_map["llama2_7b"]["stop_token"] + format_out(token) + != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] and len(history) < self.max_tokens ): dec_time = time.time() @@ -272,7 +273,10 @@ def format_out(results): self.prev_token_len = token_len + len(history) - if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: + if ( + format_out(token) + == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] + ): break for i in range(len(history)): @@ -306,7 +310,7 @@ def chat_hf(self, prompt): self.first_input = False history.append(int(token)) - while token != llm_model_map["llama2_7b"]["stop_token"]: + while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: dec_time = time.time() result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv) history.append(int(token)) @@ -317,7 +321,7 @@ def chat_hf(self, prompt): self.prev_token_len = token_len + len(history) - if token == llm_model_map["llama2_7b"]["stop_token"]: + if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: @@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict): else: print(f"prompt : {InputData['prompt']}") - model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b" + model_name = ( + InputData["model"] + if "model" in InputData.keys() + else "meta-llama/Llama-2-7b-chat-hf" + ) model_path = llm_model_map[model_name] device = InputData["device"] if "device" in InputData.keys() else "cpu" precision = "fp16" diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index f41eaaaba0..418f087548 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -9,6 +9,7 @@ llm_model_map, LanguageModel, ) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj B_SYS, E_SYS = "", "" @@ -64,6 +65,7 @@ def chat_fn( external_weights="safetensors", use_system_prompt=prompt_prefix, streaming_llm=streaming_llm, + hf_auth_token=cmd_opts.hf_auth_token, ) history[-1][-1] = "Getting the model ready... Done" yield history, "" diff --git a/requirements.txt b/requirements.txt index c2a598978d..8de4bd406b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,6 +35,7 @@ safetensors==0.3.1 py-cpuinfo pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions mpmath==1.3.0 +optimum # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors pefile