From 9025b7db65db8fa3d98367285ee3dd2968d255a7 Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Mon, 29 Apr 2024 12:45:30 -0400
Subject: [PATCH] Fix formatting

---
 apps/shark_studio/api/llm.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 337b606fce..6ee80ae49e 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -258,7 +258,8 @@ def format_out(results):
         history.append(format_out(token))
 
         while (
-            format_out(token) != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+            format_out(token)
+            != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
             and len(history) < self.max_tokens
         ):
             dec_time = time.time()
@@ -272,7 +273,10 @@ def format_out(results):
 
             self.prev_token_len = token_len + len(history)
 
-            if format_out(token) == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
+            if (
+                format_out(token)
+                == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
+            ):
                 break
 
         for i in range(len(history)):
@@ -347,7 +351,11 @@ def llm_chat_api(InputData: dict):
     else:
         print(f"prompt : {InputData['prompt']}")
 
-    model_name = InputData["model"] if "model" in InputData.keys() else "meta-llama/Llama-2-7b-chat-hf"
+    model_name = (
+        InputData["model"]
+        if "model" in InputData.keys()
+        else "meta-llama/Llama-2-7b-chat-hf"
+    )
     model_path = llm_model_map[model_name]
     device = InputData["device"] if "device" in InputData.keys() else "cpu"
     precision = "fp16"
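
Note: for context, a minimal sketch of how a caller might invoke llm_chat_api, built only from the keys visible in the hunks above ("prompt", "model", "device"); the import path and payload values are assumptions, not confirmed by this patch, and the real endpoint may accept additional fields.

    # Hypothetical usage sketch (not part of the patch). Per the hunks above,
    # "model" and "device" are optional and default to
    # "meta-llama/Llama-2-7b-chat-hf" and "cpu" respectively.
    from apps.shark_studio.api.llm import llm_chat_api  # import path assumed

    request = {
        "prompt": "What is the capital of France?",
        "model": "meta-llama/Llama-2-7b-chat-hf",  # optional; shown default
        "device": "cpu",  # optional; shown default
    }

    response = llm_chat_api(request)
    print(response)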