From 6bf51f1f1d78eb2b058e500fa4335faf69787657 Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Thu, 1 Feb 2024 09:46:22 -0800
Subject: [PATCH] HF-Reference LLM mode + Update test result to match latest
 Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.
---
 apps/shark_studio/api/llm.py        | 49 ++++++++++++++++++++++++++++-
 apps/shark_studio/tests/api_test.py |  5 +--
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index a9d39f8e7b..647d6a5af1 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -9,7 +9,7 @@
 import gc
 import os
 import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 llm_model_map = {
     "llama2_7b": {
@@ -109,6 +109,7 @@ def __init__(
         self.global_iter = 0
         self.prev_token_len = 0
         self.first_input = True
+        self.hf_auth_token = hf_auth_token
         if self.external_weight_file is not None:
             if not os.path.exists(self.external_weight_file):
                 print(
@@ -164,6 +165,8 @@ def __init__(
                 use_auth_token=hf_auth_token,
             )
         self.compile()
+        # Reserved for running HF torch model as reference.
+        self.hf_mod = None
 
     def compile(self) -> None:
         # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
@@ -267,6 +270,50 @@ def format_out(results):
         self.global_iter += 1
         return result_output, total_time
 
+    # Reference HF model function for sanity checks.
+    def chat_hf(self, prompt):
+        if self.hf_mod is None:
+            self.hf_mod = AutoModelForCausalLM.from_pretrained(
+                self.hf_model_name,
+                torch_dtype=torch.float,
+                token=self.hf_auth_token,
+            )
+        prompt = self.sanitize_prompt(prompt)
+
+        input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
+        history = []
+        for iter in range(self.max_tokens):
+            token_len = input_tensor.shape[-1]
+            if self.first_input:
+                st_time = time.time()
+                result = self.hf_mod(input_tensor)
+                token = torch.argmax(result.logits[:, -1, :], dim=1)
+                total_time = time.time() - st_time
+                token_len += 1
+                pkv = result.past_key_values
+                self.first_input = False
+
+            history.append(int(token))
+            while token != llm_model_map["llama2_7b"]["stop_token"]:
+                dec_time = time.time()
+                result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
+                history.append(int(token))
+                total_time = time.time() - dec_time
+                token = torch.argmax(result.logits[:, -1, :], dim=1)
+                pkv = result.past_key_values
+                yield self.tokenizer.decode(history), total_time
+
+            self.prev_token_len = token_len + len(history)
+
+            if token == llm_model_map["llama2_7b"]["stop_token"]:
+                break
+        for i in range(len(history)):
+            if type(history[i]) != int:
+                history[i] = int(history[i])
+        result_output = self.tokenizer.decode(history)
+        self.global_iter += 1
+        return result_output, total_time
+
 
 if __name__ == "__main__":
     lm = LanguageModel(
diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py
index 203edd0821..d07bb05b90 100644
--- a/apps/shark_studio/tests/api_test.py
+++ b/apps/shark_studio/tests/api_test.py
@@ -20,14 +20,15 @@ def test01_LLMSmall(self):
             quantization="None",
         )
         count = 0
+        label = "Turkishoure Turkish"
         for msg, _ in lm.chat("hi, what are you?"):
             # skip first token output
             if count == 0:
                 count += 1
                 continue
             assert (
-                msg.strip(" ") == "Turkish Turkish Turkish"
-            ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
+                msg.strip(" ") == label
+            ), f"LLM API failed to return correct response, expected '{label}', received {msg}"
             break
         del lm
         gc.collect()
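
Usage note (editor's sketch, not part of the patch): chat_hf mirrors the
compiled chat path as a generator that yields (decoded_text, step_time)
tuples, so both paths can be drained the same way for a side-by-side sanity
check. The constructor arguments below are assumptions modeled on the test
above; "llama2_7b" is the only model key visible in this patch, and the
first_input reset is a workaround for the prefill flag being shared between
the two code paths.

    # Hypothetical comparison driver; LanguageModel kwargs follow api_test.py
    # and are assumptions, not confirmed by this patch.
    from apps.shark_studio.api.llm import LanguageModel

    lm = LanguageModel(
        "llama2_7b",          # assumed model key, per llm_model_map above
        hf_auth_token=None,
        quantization="None",
    )

    prompt = "hi, what are you?"

    # Compiled (Turbine/IREE) path: each yield decodes the full history,
    # so the last yield holds the complete response.
    for text, _ in lm.chat(prompt):
        turbine_out = text

    # chat_hf only runs its prefill while self.first_input is True, and the
    # compiled path may have flipped it, so reset the shared flag first.
    lm.first_input = True

    # HF torch reference path added by this patch.
    for text, _ in lm.chat_hf(prompt):
        reference_out = text

    print("turbine  :", turbine_out)
    print("reference:", reference_out)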