[vicuna] Integrate sharded vicuna in web (#1717)
Signed-off-by: Gaurav Shukla <[email protected]>
Shukla-Gaurav authored Aug 4, 2023
1 parent bd30044 commit 51ec1a1
Showing 1 changed file with 38 additions and 12 deletions.
apps/stable_diffusion/web/ui/stablelm_ui.py — 50 changes: 38 additions & 12 deletions
@@ -7,6 +7,7 @@
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
+import json


 def user(message, history):
@@ -106,7 +107,15 @@ def set_vicuna_model(model):


 # TODO: Make chat reusable for UI and API
-def chat(curr_system_message, history, model, device, precision, cli=True):
+def chat(
+    curr_system_message,
+    history,
+    model,
+    devices,
+    precision,
+    config_file,
+    cli=True,
+):
     global past_key_values

     global vicuna_model
@@ -121,10 +130,12 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
     ]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
+            ShardedVicuna,
         )
         from apps.stable_diffusion.src import args

         if vicuna_model == 0:
+            device = devices[0]
             if "cuda" in device:
                 device = "cuda"
             elif "sync" in device:
@@ -137,14 +148,28 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
                 print("unrecognized device")

             max_toks = 128 if model_name == "codegen" else 512
-            vicuna_model = UnshardedVicuna(
-                model_name,
-                hf_model_path=model_path,
-                hf_auth_token=args.hf_auth_token,
-                device=device,
-                precision=precision,
-                max_num_tokens=max_toks,
-            )
+            if len(devices) == 1 and config_file is None:
+                vicuna_model = UnshardedVicuna(
+                    model_name,
+                    hf_model_path=model_path,
+                    hf_auth_token=args.hf_auth_token,
+                    device=device,
+                    precision=precision,
+                    max_num_tokens=max_toks,
+                )
+            else:
+                if config_file is not None:
+                    config_file = open(config_file)
+                    config_json = json.load(config_file)
+                    config_file.close()
+                else:
+                    config_json = None
+                vicuna_model = ShardedVicuna(
+                    model_name,
+                    device=device,
+                    precision=precision,
+                    config_json=config_json,
+                )
     prompt = create_prompt(model_name, history)

     for partial_text in vicuna_model.generate(prompt, cli=cli):
@@ -307,13 +332,14 @@ def view_json_file(file_obj):
             supported_devices = supported_devices[-1:] + supported_devices[:-1]
             supported_devices = [x for x in supported_devices if "sync" not in x]
             print(supported_devices)
-            device = gr.Dropdown(
+            devices = gr.Dropdown(
                 label="Device",
                 value=supported_devices[0]
                 if enabled
                 else "Only CUDA Supported for now",
                 choices=supported_devices,
                 interactive=enabled,
+                multiselect=True,
             )
             precision = gr.Radio(
                 label="Precision",
@@ -357,15 +383,15 @@ def view_json_file(file_obj):
                 fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
             ).then(
                 fn=chat,
-                inputs=[system_msg, chatbot, model, device, precision],
+                inputs=[system_msg, chatbot, model, devices, precision, config_file],
                 outputs=[chatbot],
                 queue=True,
             )
             submit_click_event = submit.click(
                 fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
             ).then(
                 fn=chat,
-                inputs=[system_msg, chatbot, model, device, precision],
+                inputs=[system_msg, chatbot, model, devices, precision, config_file],
                 outputs=[chatbot],
                 queue=True,
             )
