add tests and edit deps
dan-garvey committed Dec 13, 2023
1 parent f0d3d42 commit a92bed8
Showing 14 changed files with 187 additions and 194 deletions.
86 changes: 86 additions & 0 deletions .github/workflows/test-studio.yml
@@ -0,0 +1,86 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Validate Shark Studio

on:
  push:
    branches: [ main ]
    paths-ignore:
      - '**.md'
      - 'shark/examples/**'
  pull_request:
    branches: [ main ]
    paths-ignore:
      - '**.md'
      - 'shark/examples/**'
  workflow_dispatch:

# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-validate:
    strategy:
      fail-fast: true
      matrix:
        os: [nodai-ubuntu-builder-large]
        suite: [cpu] #,cuda,vulkan]
        python-version: ["3.11"]
        include:
          - os: nodai-ubuntu-builder-large
            suite: lint

    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v3

      - name: Set Environment Variables
        run: |
          echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
          echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
      - name: Set up Python Version File ${{ matrix.python-version }}
        run: |
          echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: '${{ matrix.python-version }}'

      - name: Install dependencies
        if: matrix.suite == 'lint'
        run: |
          python -m pip install --upgrade pip
          python -m pip install flake8 pytest toml black
      - name: Lint with flake8
        if: matrix.suite == 'lint'
        run: |
          # black format check
          black --version
          black --check apps/shark_studio
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
            --statistics --exclude lit.cfg.py
      - name: Validate Models on CPU
        if: matrix.suite == 'cpu'
        run: |
          cd $GITHUB_WORKSPACE
          python${{ matrix.python-version }} -m venv shark.venv
          source shark.venv/bin/activate
          pip install -r requirements.txt
          pip install -e .
          pip uninstall -y torch
          pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
          python apps/shark_studio/tests/api_test.py
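The cpu suite above pins torch==2.1.0+cpu before running the new API test. As a minimal local sanity check for that pin (not part of the commit; a sketch run inside shark.venv, assuming the +cpu wheel reports no CUDA support):

# Hypothetical post-install check for the CPU-only torch pin used in the workflow.
import torch

assert torch.__version__.startswith("2.1.0"), f"unexpected torch version: {torch.__version__}"
assert not torch.cuda.is_available(), "expected a CPU-only torch build"
print("torch", torch.__version__, "looks right for the cpu suite")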
16 changes: 6 additions & 10 deletions apps/shark_studio/api/llm.py
@@ -45,7 +45,9 @@ def __init__(
self.external_weight_file = external_weight_file
self.use_system_prompt = use_system_prompt
self.global_iter = 0
if os.path.exists(self.vmfb_name):
if os.path.exists(self.vmfb_name) and (
os.path.exists(self.external_weight_file) or external_weights is None
):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
@@ -64,9 +66,7 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
@@ -129,19 +129,15 @@ def format_out(results):
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)

total_time = time.time() - st_time
history.append(format_out(token))
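The first hunk above tightens the cache-reuse condition: a precompiled vmfb is only loaded when its external weight file is also present (or external weights are not used at all). A standalone sketch of that check, with illustrative argument names and the two sides of the or swapped so the file-system check only runs when external weights are in use:

import os


def can_reuse_cached_module(vmfb_path, external_weight_file, external_weights):
    # Mirrors the condition added in LanguageModel.__init__: reuse the compiled
    # vmfb only if it exists and its external weights are unused or already on disk.
    return os.path.exists(vmfb_path) and (
        external_weights is None or os.path.exists(external_weight_file)
    )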
4 changes: 1 addition & 3 deletions apps/shark_studio/api/utils.py
@@ -8,7 +8,5 @@ def get_available_devices():

def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
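For reference, a usage sketch of get_resource_path as reflowed above; the resource name is hypothetical, and the function resolves against sys._MEIPASS inside a PyInstaller bundle and against the module's own directory otherwise:

from apps.shark_studio.api.utils import get_resource_path

# Hypothetical resource lookup; the relative path is illustrative only.
icon_path = get_resource_path("resources/icon.png")
print(icon_path)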
34 changes: 34 additions & 0 deletions apps/shark_studio/tests/api_test.py
@@ -0,0 +1,34 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel


class LLMAPITest(unittest.TestCase):
    def testLLMSimple(self):
        lm = LanguageModel(
            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
            hf_auth_token=None,
            device="cpu-task",
            external_weights="safetensors",
        )
        count = 0
        for msg, _ in lm.chat("hi, what are you?"):
            # skip first token output
            if count == 0:
                count += 1
                continue
            assert (
                msg.strip(" ") == "Hello"
            ), f"LLM API failed to return correct response, expected 'Hello', received {msg}"
            break


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
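The cpu suite in the workflow above runs this file directly with python apps/shark_studio/tests/api_test.py. A rough local equivalent, assuming the repository root is the current directory and the workflow's dependencies are installed:

# Sketch of the CI invocation for the new test module.
import subprocess
import sys

subprocess.run(
    [sys.executable, "apps/shark_studio/tests/api_test.py"],
    check=True,  # surface a non-zero exit the same way the CI step would
)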
4 changes: 1 addition & 3 deletions apps/shark_studio/web/index.py
@@ -93,9 +93,7 @@ def launch_app(address):

def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)

dark_theme = resource_path("ui/css/sd_dark_theme.css")
76 changes: 5 additions & 71 deletions apps/shark_studio/web/ui/chat.py
@@ -24,63 +24,9 @@ def user(message, history):

def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]

if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg


def get_default_config():
return False
import torch
from transformers import AutoTokenizer

hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile

model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()


# model_vmfb_key = ""

@@ -142,17 +88,11 @@ def llm_chat_api(InputData: dict):
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512

@@ -189,9 +129,7 @@ def llm_chat_api(InputData: dict):
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
@@ -224,9 +162,7 @@ def llm_chat_api(InputData: dict):
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
@@ -302,9 +238,7 @@ def view_json_file(file_obj):

with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
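Based on the keys llm_chat_api reads in the hunks above ("model", "messages" or "prompt", "max_tokens") and the OpenAI-style response it assembles, a hypothetical round trip might look like the following; the "role" field and the printed fields are assumptions about behavior not fully visible in this diff:

from apps.shark_studio.web.ui.chat import llm_chat_api

# "codegen" is the handler's fallback model name and 128 its default
# max_tokens for that model, per the hunk above.
request = {
    "model": "codegen",
    "messages": [{"role": "user", "content": "hi, what are you?"}],
    "max_tokens": 128,
}
response = llm_chat_api(request)
# Expected shape: {"id": ..., "object": "chat.completion", "created": ..., "choices": [...]}
print(response["object"], len(response["choices"]))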
28 changes: 11 additions & 17 deletions build_tools/stable_diffusion_testing.py
@@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
metrics[val] = line.split(" ")[-1].strip("\n")

metrics["Average step"] = metrics["Average step"].strip("ms/it")
metrics["Total image generation"] = metrics[
"Total image generation"
].strip("sec")
metrics["Total image generation"] = metrics["Total image generation"].strip("sec")
metrics["device"] = device
metrics["use_tune"] = use_tune
metrics["model_name"] = model_name
@@ -84,10 +82,14 @@ def test_loop(
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
inpaint_prompt_text = (
"--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
)
if os.name == "nt":
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
inpaint_prompt_text = (
'--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
)
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
@@ -174,9 +176,7 @@ def test_loop(
)
print(command)
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
os.makedirs("./test_images/golden/" + model_name, exist_ok=True)
download_public_file(
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
@@ -191,14 +191,10 @@
)
test_file = glob(test_file_path)[0]

golden_path = (
"./test_images/golden/" + model_name + "/*.png"
)
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
compare_images(test_file, golden_file, upload=upload_bool)
except AssertionError as e:
print(e)
if exit_on_fail == True:
@@ -267,9 +263,7 @@ def prepare_artifacts():
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False)

if __name__ == "__main__":
args = parser.parse_args()
(Diffs for the remaining 7 changed files are not shown.)
