(Studio2) Refactors SD pipeline to rely on turbine-models pipeline, fixes to LLM, gitignore #2129

Merged: 20 commits, May 28, 2024

2 changes: 0 additions & 2 deletions .github/workflows/test-studio.yml
@@ -81,6 +81,4 @@ jobs:
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py
8 changes: 7 additions & 1 deletion .gitignore
@@ -164,14 +164,15 @@ cython_debug/
# vscode related
.vscode

# Shark related artefacts
# Shark related artifacts
*venv/
shark_tmp/
*.vmfb
.use-iree
tank/dict_configs.py
*.csv
reproducers/
apps/shark_studio/web/configs

# ORT related artefacts
cache_models/
@@ -188,6 +189,11 @@ variants.json
# models folder
apps/stable_diffusion/web/models/

# model artifacts (SHARK)
*.tempfile
*.mlir
*.vmfb

# Stencil annotators.
stencil_annotator/

16 changes: 7 additions & 9 deletions apps/shark_studio/api/llm.py
@@ -155,7 +155,9 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
@@ -258,8 +260,7 @@ def format_out(results):

history.append(format_out(token))
while (
format_out(token)
!= llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
format_out(token) != llm_model_map[self.hf_model_name]["stop_token"]
and len(history) < self.max_tokens
):
dec_time = time.time()
@@ -273,10 +274,7 @@

self.prev_token_len = token_len + len(history)

if (
format_out(token)
== llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
):
if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
break

for i in range(len(history)):
Expand Down Expand Up @@ -310,7 +308,7 @@ def chat_hf(self, prompt):
self.first_input = False

history.append(int(token))
while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
while token != llm_model_map[self.hf_model_name]["stop_token"]:
dec_time = time.time()
result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv)
history.append(int(token))
Expand All @@ -321,7 +319,7 @@ def chat_hf(self, prompt):

self.prev_token_len = token_len + len(history)

if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]:
if token == llm_model_map[self.hf_model_name]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
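
Taken together, the llm.py changes above apply one pattern: every lookup into llm_model_map (the "initializer" used to import the model and the "stop_token" used to end generation) is now keyed by self.hf_model_name rather than the hard-coded "meta-llama/Llama-2-7b-chat-hf" entry, so the stop condition follows whichever model the instance was constructed with. A minimal sketch of that pattern follows; the map entries, stop-token values, and class name are illustrative assumptions, not the real SHARK definitions.

```python
# Illustrative sketch only: the real llm_model_map and language model class live in
# apps/shark_studio/api/llm.py; the entries and stop-token values here are placeholders.
llm_model_map = {
    "meta-llama/Llama-2-7b-chat-hf": {"stop_token": 2},  # hypothetical stop id
    "some-org/other-chat-model": {"stop_token": 0},      # hypothetical second entry
}


class LanguageModelSketch:
    def __init__(self, hf_model_name, max_tokens=512):
        self.hf_model_name = hf_model_name
        self.max_tokens = max_tokens

    def collect(self, token_stream):
        # The stop token is looked up through the instance's model name, not a
        # fixed Llama-2 key, which is the generalization this PR makes.
        stop_token = llm_model_map[self.hf_model_name]["stop_token"]
        history = []
        for token in token_stream:
            history.append(token)
            if token == stop_token or len(history) >= self.max_tokens:
                break
        return history
```

With this shape, constructing the class with a different hf_model_name picks up that model's stop token automatically, which the previous hard-coded Llama-2 lookups prevented.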