[vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <[email protected]>
Shukla-Gaurav authored Jun 27, 2023
1 parent 75672c0 commit 1d6a1f9
Showing 2 changed files with 18 additions and 17 deletions.
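
The commit turns Vicuna generation into a streaming generator: generate() yields the first detokenized token as soon as it is produced, then a cumulative decode of all accumulated tokens after every third token, and finally the complete response. The standalone sketch below illustrates that step=3 cumulative-decode pattern; ToyTokenizer and streaming_generate are illustrative stand-ins, not code from this repository.

    # Illustrates the step=3 cumulative-decode streaming pattern added by this commit (sketch only).
    class ToyTokenizer:
        def decode(self, tokens):
            # Stand-in for self.tokenizer.decode(res_tokens) in the pipeline.
            return " ".join(f"tok{t}" for t in tokens)

    def streaming_generate(token_source, tokenizer):
        res_tokens = []
        for token in token_source:
            res_tokens.append(token)
            if len(res_tokens) % 3 == 0:
                # Every third token, decode everything accumulated so far and yield the partial string.
                yield tokenizer.decode(res_tokens)
        # Finally, yield the complete decoded response.
        yield tokenizer.decode(res_tokens)

    for partial in streaming_generate(range(1, 8), ToyTokenizer()):
        print(partial)
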
26 changes: 16 additions & 10 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -439,6 +439,14 @@ def compile(self):
         # return tuple of shark_modules once mem is supported
         # return fvic_shark_model, svic_shark_model

+    def decode_tokens(self, res_tokens):
+        for i in range(len(res_tokens)):
+            if type(res_tokens[i]) != int:
+                res_tokens[i] = int(res_tokens[i][0])
+
+        res_str = self.tokenizer.decode(res_tokens)
+        return res_str
+
     def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
         import gc
@@ -448,7 +456,6 @@ def generate(self, prompt, cli=False):
                 self.first_vic = self.compile_first_vicuna()
             if self.second_vic == None:
                 self.second_vic = self.compile_second_vicuna()
-        res = []
         res_tokens = []
         params = {
             "prompt": prompt,
@@ -464,8 +471,8 @@ def generate(self, prompt, cli=False):
         logits = generated_token_op["logits"]
         pkv = generated_token_op["pkv"]
         detok = generated_token_op["detok"]
+        yield detok

-        res.append(detok)
         res_tokens.append(token)
         if cli:
             print(f"Assistant: {detok}", end=" ", flush=True)
@@ -498,25 +505,24 @@ def generate(self, prompt, cli=False):
                     break
             res_tokens.append(token)
             if detok == "<0x0A>":
-                res.append("\n")
                 if cli:
                     print("\n", end="", flush=True)
             else:
-                res.append(detok)
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
+
+            if len(res_tokens) % 3 == 0:
+                part_str = self.decode_tokens(res_tokens)
+                yield part_str
+
         if self.device == "cuda":
             del sec_vic, pkv, logits
             torch.cuda.empty_cache()
             gc.collect()

-        for i in range(len(res_tokens)):
-            if type(res_tokens[i]) != int:
-                res_tokens[i] = int(res_tokens[i][0])
-
-        res_str = self.tokenizer.decode(res_tokens)
+        res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        return res_str
+        yield res_str

     def generate_new_token(self, params, debug=False):
         def forward_first(first_vic, prompt, cache_outputs=False):
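
With this change generate() no longer returns a single string; callers iterate over it. Each yield after the first detokenized token is a cumulative partial decode (the whole response so far), not a delta, so the last yielded string is the complete answer. A minimal consumption sketch (vic stands for an already constructed Vicuna pipeline object and the prompt is arbitrary; neither is part of this diff):

    # Hypothetical caller of the streaming generate() API.
    prompt = "What is the capital of Canada?"
    final_text = ""
    for partial in vic.generate(prompt, cli=False):
        final_text = partial   # keep the latest cumulative decode
        print(partial)
    # final_text now holds the complete response, matching the final yield.

The Gradio handler in the next file consumes the generator in exactly this way.
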
9 changes: 2 additions & 7 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -65,16 +65,11 @@ def chat(curr_system_message, history, model, device, precision):
         )
         prompt = messages.strip()
         print("prompt = ", prompt)
-        sentence = vicuna_model.generate(prompt)
-
-        partial_text = ""
-        for new_text in sentence.split(" "):
-            # print(new_text)
-            partial_text += new_text + " "
+        for partial_text in vicuna_model.generate(prompt):
             history[-1][1] = partial_text
             # Yield an empty string to cleanup the message textbox and the updated conversation history
             yield history
-        history[-1][1] = sentence
+
         return history

     # else Model is StableLM
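
The UI side leans on Gradio's support for generator event handlers: every yield history re-renders the Chatbot with the latest partial response. Below is a self-contained sketch of that pattern; fake_stream stands in for vicuna_model.generate, and the Blocks wiring is illustrative rather than taken from stablelm_ui.py (it assumes a Gradio version where queued generator handlers stream intermediate yields):

    import time

    import gradio as gr

    def fake_stream(prompt):
        # Stand-in for vicuna_model.generate(prompt): yields growing partial strings.
        partial = ""
        for word in "This reply is streamed a few words at a time".split():
            partial += word + " "
            time.sleep(0.1)
            yield partial

    def chat(message, history):
        history = (history or []) + [[message, ""]]
        for partial_text in fake_stream(message):
            # Overwrite the bot side of the last turn with the newest partial text.
            history[-1][1] = partial_text
            yield history

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        msg.submit(chat, inputs=[msg, chatbot], outputs=chatbot)

    # Queuing is what lets the intermediate yields stream to the browser.
    demo.queue().launch()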
