Fix test error message + Only initialize HF torch model when used.
Stanley Winata committed Feb 1, 2024
1 parent 32843c3 commit ea44e47
Showing 2 changed files with 15 additions and 9 deletions.
19 changes: 12 additions & 7 deletions apps/shark_studio/api/llm.py
@@ -109,6 +109,7 @@ def __init__(
         self.global_iter = 0
         self.prev_token_len = 0
         self.first_input = True
+        self.hf_auth_token = hf_auth_token
         if self.external_weight_file is not None:
             if not os.path.exists(self.external_weight_file):
                 print(
@@ -164,11 +165,8 @@ def __init__(
             use_auth_token=hf_auth_token,
         )
         self.compile()
-        self.mod = AutoModelForCausalLM.from_pretrained(
-            self.hf_model_name,
-            torch_dtype=torch.float,
-            token=hf_auth_token,
-        )
+        # Reserved for running HF torch model as reference.
+        self.hf_mod = None

     def compile(self) -> None:
         # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
@@ -272,7 +270,14 @@ def format_out(results):
         self.global_iter += 1
         return result_output, total_time

+    # Reference HF model function for sanity checks.
     def chat_hf(self, prompt):
+        if self.hf_mod is None:
+            self.hf_mod = AutoModelForCausalLM.from_pretrained(
+                self.hf_model_name,
+                torch_dtype=torch.float,
+                token=self.hf_auth_token,
+            )
         prompt = self.sanitize_prompt(prompt)

         input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -281,7 +286,7 @@ def chat_hf(self, prompt):
         token_len = input_tensor.shape[-1]
         if self.first_input:
             st_time = time.time()
-            result = self.mod(input_tensor)
+            result = self.hf_mod(input_tensor)
             token = torch.argmax(result.logits[:, -1, :], dim=1)
             total_time = time.time() - st_time
             token_len += 1
@@ -291,7 +296,7 @@
         history.append(int(token))
         while token != llm_model_map["llama2_7b"]["stop_token"]:
             dec_time = time.time()
-            result = self.mod(token.reshape([1, 1]), past_key_values=pkv)
+            result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
             history.append(int(token))
             total_time = time.time() - dec_time
             token = torch.argmax(result.logits[:, -1, :], dim=1)
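The llm.py change above is a standard lazy-initialization pattern: rather than eagerly loading the HF torch model in __init__ (costing download time and memory even when only the compiled module is exercised), the handle starts as None and chat_hf loads it on first use. A minimal sketch of the pattern, assuming a simplified class — the class name LLMRef and constructor shape here are illustrative, not taken from the repo; only the hf_mod / chat_hf logic mirrors the diff:

import torch
from transformers import AutoModelForCausalLM

class LLMRef:
    def __init__(self, hf_model_name, hf_auth_token=None):
        self.hf_model_name = hf_model_name
        self.hf_auth_token = hf_auth_token
        # Deferred: no HF weights are downloaded or loaded here.
        self.hf_mod = None

    def chat_hf(self, prompt):
        # First call pays the one-time load cost; later calls reuse the cached model.
        if self.hf_mod is None:
            self.hf_mod = AutoModelForCausalLM.from_pretrained(
                self.hf_model_name,
                torch_dtype=torch.float,
                token=self.hf_auth_token,
            )
        # ... tokenize `prompt` and run self.hf_mod as in the diff above ...

Callers that never invoke chat_hf now skip the HF model load entirely, which is the point of the commit title.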
5 changes: 3 additions & 2 deletions apps/shark_studio/tests/api_test.py
@@ -20,14 +20,15 @@ def test01_LLMSmall(self):
             quantization="None",
         )
         count = 0
+        label = "Turkishoure Turkish"
         for msg, _ in lm.chat("hi, what are you?"):
             # skip first token output
             if count == 0:
                 count += 1
                 continue
             assert (
-                msg.strip(" ") == "Turkishoure Turkish"
-            ), f"LLM API failed to return correct response, expected 'Turkish Turkish Turkish', received {msg}"
+                msg.strip(" ") == label
+            ), f"LLM API failed to return correct response, expected '{label}', received {msg}"
             break
         del lm
         gc.collect()
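The test fix applies a single-source-of-truth idea: the expected string previously appeared twice, once in the assertion and once — incorrectly, as 'Turkish Turkish Turkish' — in the failure message. Hoisting it into a label variable keeps the two in sync. A self-contained sketch of the pattern (the msg value is a stand-in, not real model output):

label = "Turkishoure Turkish"
msg = "Turkishoure Turkish"  # stand-in for a streamed model response
assert (
    msg.strip(" ") == label
), f"LLM API failed to return correct response, expected '{label}', received {msg}"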
