Skip to content

Commit

Permalink
[Shard] Add sharding generation in shark studio
Browse files Browse the repository at this point in the history
Signed-Off-by: Gaurav Shukla <[email protected]>
  • Loading branch information
Shukla-Gaurav committed Aug 4, 2023
1 parent c9de272 commit bd30044
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 19 deletions.
17 changes: 9 additions & 8 deletions apps/language_models/src/model_wrappers/vicuna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,13 @@ def __init__(
self.second_vicuna = SecondVicuna(second_vicuna_model_path)

def forward(self, input_ids):
# NOTE(review): this span is a diff hunk rendered without +/- markers, so two
# versions of the body appear back to back. The hunk header (-301,12 +301,13)
# and the "9 additions & 8 deletions" summary are consistent with the first
# eight body lines being the removed implementation and the following nine
# being its replacement -- confirm against the repository.
#
# First version: call the first model with use_cache=True, treat element 0 of
# its output as the logits and the remaining elements as `pkv` (presumably the
# past key/value tensors), pick the argmax over the last position as the next
# token, and hand token + pkv to the second model as ONE tuple argument.
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]

token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
# Second version: the first model is still called, but its result is not read
# again below; the second model is fed fixed all-zero placeholder inputs.
first_output = self.first_vicuna(input_ids=input_ids)
# generate second vicuna
# A single int64 token id of shape [1, 1].
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
# 64 float32 zero tensors of shape [1, 32, 19, 128].
# NOTE(review): presumably key/value caches for each layer with a sequence
# length of 19 -- confirm against the model definition.
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
# Unlike the first version, the tuple is unpacked into positional arguments.
second_output = self.second_vicuna(*secondVicunaCompileInput)
return second_output
21 changes: 12 additions & 9 deletions apps/stable_diffusion/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def resource_path(relative_path):
upscaler_sendto_outpaint,
lora_train_web,
model_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
Expand Down Expand Up @@ -221,6 +222,16 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
outpaint_web.render()
with gr.TabItem(label="Upscaler", id=4):
upscaler_web.render()
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="Chat Bot(Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="Generate Sharding Config", id=8):
model_config_web.render()
with gr.TabItem(label="LoRA Training(Experimental)", id=9):
lora_train_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
if args.output_gallery:
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
outputgallery_web.render()
Expand All @@ -236,15 +247,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
upscaler_status,
]
)
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
stablelm_chat.render()
with gr.TabItem(label="MultiModal (Experimental)", id=9):
minigpt4_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
with gr.TabItem(label="DocuChat(Experimental)", id=11):
h2ogpt_web.render()

# send to buttons
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
Expand Down
41 changes: 41 additions & 0 deletions apps/stable_diffusion/web/ui/generate_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile


def get_model_config():
    """Build the layer-split configuration for the Vicuna model.

    Tokenizes a fixed prompt of 17 ``"0"`` characters with the
    ``TheBloke/vicuna-7B-1.1-HF`` tokenizer, uses the resulting ids as the
    sample input of a ``CombinedModel``, and asks ``GenerateConfigFile`` to
    split that model into layers.

    Returns:
        The value returned by ``GenerateConfigFile.split_into_layers()``.
        The click handler in this file sends it to a ``gr.JSON`` component.
    """
    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    compilation_prompt = "0" * 17

    # ``return_tensors="pt"`` already yields a tensor, so it is reshaped
    # directly instead of being wrapped in ``torch.tensor(...)`` again (a
    # redundant copy-construct that PyTorch warns about).
    # The reshape doubles as a size check: it raises unless exactly 19 ids
    # were produced.
    # NOTE(review): 19 is presumably the 17 prompt tokens plus tokens the
    # tokenizer adds itself -- confirm.
    compilation_input_ids = tokenizer(
        compilation_prompt,
        return_tensors="pt",
    ).input_ids.reshape([1, 19])
    first_vicuna_compile_input = (compilation_input_ids,)

    model = CombinedModel()
    config_generator = GenerateConfigFile(
        model, 1, ["gpu_id"], first_vicuna_compile_input
    )
    return config_generator.split_into_layers()


# Gradio page for generating a sharding config. `model_config_web` is imported
# in apps/stable_diffusion/web/ui/__init__.py and rendered by
# apps/stable_diffusion/web/index.py under the "Generate Sharding Config" tab.
# NOTE(review): indentation was lost in this view, so which of the components
# below sit inside the gr.Row() cannot be determined here -- check the
# repository copy before changing the layout.
with gr.Blocks() as model_config_web:
with gr.Row():
# Model selector; "Vicuna" is the only choice and also the default value.
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
# Shows whatever get_model_config() returns.
json_view = gr.JSON()

# The handler is wired with an empty `inputs` list, so the dropdown selection
# is not passed to get_model_config.
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)
30 changes: 28 additions & 2 deletions shark/shark_generate_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def __init__(
def split_into_dispatches(
self,
backend,
fx_tracing_required=True,
fx_tracing_required=False,
f16_model=False,
torch_mlir_tracing=False,
torch_mlir_tracing=True,
):
graph_for_compilation = self.model
if fx_tracing_required:
Expand Down Expand Up @@ -103,3 +103,29 @@ def split_into_layers(self):
def generate_json(self, artifacts):
    """Write ``artifacts`` as JSON to the file at ``self.config_file_path``.

    The file is opened in ``"w"`` mode, so any existing content is replaced.
    """
    destination = self.config_file_path
    with open(destination, "w") as config_handle:
        json.dump(artifacts, fp=config_handle)


if __name__ == "__main__":
    # Manual entry point: build a CombinedModel with a fixed sample input and
    # split it into dispatches for the "vulkan" backend.
    # Imports stay inside the guard so that importing this module does not
    # pull in these packages.
    from transformers import AutoTokenizer

    from apps.language_models.src.model_wrappers.vicuna_model import (
        CombinedModel,
    )

    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    compilation_prompt = "0" * 17

    # ``return_tensors="pt"`` already yields a tensor, so it is reshaped
    # directly rather than re-wrapped with ``torch.tensor(...)``. The reshape
    # raises unless exactly 19 ids were produced.
    compilation_input_ids = tokenizer(
        compilation_prompt,
        return_tensors="pt",
    ).input_ids.reshape([1, 19])
    first_vicuna_compile_input = (compilation_input_ids,)

    model = CombinedModel()
    c = GenerateConfigFile(
        model, 1, ["gpu_id"], first_vicuna_compile_input
    )
    c.split_into_dispatches("vulkan")

0 comments on commit bd30044

Please sign in to comment.