From 644c3ce56a0e976db273c4307ce39e965bfe274d Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Fri, 23 Jun 2023 20:06:44 +0530
Subject: [PATCH] [vicuna] Add streaming of tokens

Signed-Off-by: Gaurav Shukla
---
 apps/language_models/src/pipelines/vicuna_pipeline.py | 3 +++
 apps/stable_diffusion/web/ui/stablelm_ui.py           | 8 +++-----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index 3c9c123c67..0e418c65f0 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -461,6 +461,7 @@ def generate(self, prompt, cli=False):
             logits = generated_token_op["logits"]
             pkv = generated_token_op["pkv"]
             detok = generated_token_op["detok"]
+            yield detok
 
             res.append(detok)
             res_tokens.append(token)
@@ -502,6 +503,8 @@ def generate(self, prompt, cli=False):
                 res.append(detok)
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
+                yield detok
+
         if self.device == "cuda":
             del sec_vic, pkv, logits
             torch.cuda.empty_cache()
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 0e5cf4092d..84d5921fcc 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -65,16 +65,14 @@ def chat(curr_system_message, history, model, device, precision):
         )
         prompt = messages.strip()
         print("prompt = ", prompt)
-        sentence = vicuna_model.generate(prompt)
         partial_text = ""
-        for new_text in sentence.split(" "):
+        for new_text in vicuna_model.generate(prompt):
             # print(new_text)
             partial_text += new_text + " "
             history[-1][1] = partial_text
             # Yield an empty string to cleanup the message textbox and the updated conversation history
             yield history
-        history[-1][1] = sentence
         return history
 
     # else Model is StableLM
@@ -124,13 +122,13 @@ def chat(curr_system_message, history, model, device, precision):
                 "TheBloke/vicuna-7B-1.1-HF",
             ],
         )
-        supported_devices = available_devices
+        supported_devices = available_devices + ["AMD-AIE"]
         enabled = len(supported_devices) > 0
         device = gr.Dropdown(
             label="Device",
             value=supported_devices[0]
             if enabled
-            else "Only CUDA Supported for now",
+            else "No devices supported for now",
             choices=supported_devices,
             interactive=enabled,
         )
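Note for reviewers: the change works because Python turns any function containing `yield` into a generator, so `generate` now emits each detokenized piece ("detok") as soon as the model produces it, and the Gradio `chat` callback re-renders the conversation history after every piece. Below is a minimal, runnable sketch of that producer/consumer contract, not part of the patch itself; `FakeVicuna` and its canned token list are hypothetical stand-ins for the compiled Vicuna pipeline.

```python
from typing import Iterator


class FakeVicuna:
    """Hypothetical stand-in for VicunaPipeline; only the generator contract matters."""

    def generate(self, prompt: str, cli: bool = False) -> Iterator[str]:
        # As in the patch: yield each detokenized piece immediately
        # instead of returning the full sentence at the end.
        for detok in ["Streaming", "tokens", "to", "the", "UI"]:
            if cli:
                print(detok, end=" ", flush=True)
            yield detok


def chat(history):
    """Mirrors the updated stablelm_ui.chat loop."""
    vicuna_model = FakeVicuna()
    partial_text = ""
    for new_text in vicuna_model.generate("hello"):
        partial_text += new_text + " "
        history[-1][1] = partial_text
        yield history  # Gradio re-renders the chatbot on every yield
    return history


if __name__ == "__main__":
    for snapshot in chat([["hello", ""]]):
        print(snapshot[-1][1])
```

One consequence worth noting: any caller that previously treated the return value of `generate` as a finished string (e.g. the old `sentence.split(" ")` path) must now iterate the generator instead, which is exactly the change made in `stablelm_ui.py`.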