Commit
HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.
raikonenfnu authored Feb 1, 2024
1 parent 05b4982 commit 6bf51f1
Showing 2 changed files with 51 additions and 3 deletions.
49 changes: 48 additions & 1 deletion apps/shark_studio/api/llm.py
@@ -9,7 +9,7 @@
import gc
import os
import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM

llm_model_map = {
    "llama2_7b": {
@@ -109,6 +109,7 @@ def __init__(
        self.global_iter = 0
        self.prev_token_len = 0
        self.first_input = True
        self.hf_auth_token = hf_auth_token
        if self.external_weight_file is not None:
            if not os.path.exists(self.external_weight_file):
                print(
@@ -164,6 +165,8 @@ def __init__(
            use_auth_token=hf_auth_token,
        )
        self.compile()
        # Reserved for running HF torch model as reference.
        self.hf_mod = None

    def compile(self) -> None:
        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
@@ -267,6 +270,50 @@ def format_out(results):
        self.global_iter += 1
        return result_output, total_time

    # Reference HF model function for sanity checks.
    def chat_hf(self, prompt):
        if self.hf_mod is None:
            self.hf_mod = AutoModelForCausalLM.from_pretrained(
                self.hf_model_name,
                torch_dtype=torch.float,
                token=self.hf_auth_token,
            )
        prompt = self.sanitize_prompt(prompt)

        input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
        history = []
        for iter in range(self.max_tokens):
            token_len = input_tensor.shape[-1]
            if self.first_input:
                st_time = time.time()
                result = self.hf_mod(input_tensor)
                token = torch.argmax(result.logits[:, -1, :], dim=1)
                total_time = time.time() - st_time
                token_len += 1
                pkv = result.past_key_values
                self.first_input = False

            history.append(int(token))
            while token != llm_model_map["llama2_7b"]["stop_token"]:
                dec_time = time.time()
                result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
                history.append(int(token))
                total_time = time.time() - dec_time
                token = torch.argmax(result.logits[:, -1, :], dim=1)
                pkv = result.past_key_values
                yield self.tokenizer.decode(history), total_time

            self.prev_token_len = token_len + len(history)

            if token == llm_model_map["llama2_7b"]["stop_token"]:
                break
        for i in range(len(history)):
            if type(history[i]) != int:
                history[i] = int(history[i])
        result_output = self.tokenizer.decode(history)
        self.global_iter += 1
        return result_output, total_time


if __name__ == "__main__":
    lm = LanguageModel(
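For context, chat_hf is intended as an eager-PyTorch cross-check of the compiled Turbine path: the HF model is only downloaded and instantiated on the first call, and the generator streams (decoded_text_so_far, step_time) tuples just like chat. A minimal usage sketch follows; only quantization="None" and the "llama2_7b" map key are taken from this diff, and the remaining constructor arguments are placeholders:

# Illustrative sketch only; constructor arguments are placeholders.
from apps.shark_studio.api.llm import LanguageModel

lm = LanguageModel("llama2_7b", quantization="None")  # placeholder args

# chat_hf() lazily builds the HF torch model on first use, then yields
# (decoded_text_so_far, step_time) for each generated token.
for text, step_time in lm.chat_hf("hi, what are you?"):
    print(f"{step_time:.3f}s  {text}")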
5 changes: 3 additions & 2 deletions apps/shark_studio/tests/api_test.py
@@ -20,14 +20,15 @@ def test01_LLMSmall(self):
            quantization="None",
        )
        count = 0
        label = "Turkishoure Turkish"
        for msg, _ in lm.chat("hi, what are you?"):
            # skip first token output
            if count == 0:
                count += 1
                continue
            assert (
-                msg.strip(" ") == "Turkish Turkish Turkish"
-            ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
+                msg.strip(" ") == label
+            ), f"LLM API failed to return correct response, expected '{label}', received {msg}"
            break
        del lm
        gc.collect()
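The updated test still exercises only the compiled Turbine path; chat_hf is not covered by CI in this commit. A hedged sketch of how a similar smoke check could be written against the reference path, using a fresh instance (chat_hf keys its prefill off self.first_input, so reusing an instance that chat() has already consumed is not advisable); constructor arguments are placeholders as above:

# Illustrative only; not part of the committed test.
ref_lm = LanguageModel("llama2_7b", quantization="None")  # placeholder args
last_msg = ""
for last_msg, _ in ref_lm.chat_hf("hi, what are you?"):
    pass  # keep only the final streamed chunk
print("HF reference output:", last_msg.strip(" "))
del ref_lm
gc.collect()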
