Skip to content

Commit

Permalink
add tests and edit deps
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Dec 14, 2023
1 parent f0d3d42 commit f6d1a9d
Show file tree
Hide file tree
Showing 15 changed files with 206 additions and 369 deletions.
164 changes: 0 additions & 164 deletions .github/workflows/test-models.yml

This file was deleted.

86 changes: 86 additions & 0 deletions .github/workflows/test-studio.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Validate Shark Studio

on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:

# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v3

- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'

- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
source shark.venv/bin/activate
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python apps/shark_studio/tests/api_test.py
43 changes: 25 additions & 18 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
}
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
}


Expand All @@ -31,7 +38,6 @@ def __init__(
device=None,
precision="fp32",
external_weights=None,
external_weight_file=None,
use_system_prompt=True,
):
print(llm_model_map[model_name])
Expand All @@ -40,12 +46,19 @@ def __init__(
self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.safe_name = self.hf_model_name.strip("/").replace("/", "_")
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.external_weight_file = external_weight_file
self.external_weight_file = None
if external_weights is not None:
self.external_weight_file = get_resource_path(
self.safe_name + "." + external_weights
)
self.use_system_prompt = use_system_prompt
self.global_iter = 0
if os.path.exists(self.vmfb_name):
if os.path.exists(self.vmfb_name) and (
external_weights is None or os.path.exists(str(self.external_weight_file))
):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
Expand All @@ -56,22 +69,20 @@ def __init__(
device,
device_idx=0,
rt_flags=[],
external_weight_file=external_weight_file,
external_weight_file=self.external_weight_file,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
external_weight_file=external_weight_file,
external_weight_file=self.external_weight_file,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
Expand Down Expand Up @@ -129,19 +140,15 @@ def format_out(results):
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)

total_time = time.time() - st_time
history.append(format_out(token))
Expand All @@ -160,12 +167,12 @@ def format_out(results):

if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
external_weight_file="llama2_7b.safetensors",
)

print("model loaded")
for i in lm.chat("hi, what are you?"):
print(i)
4 changes: 1 addition & 3 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ def get_available_devices():

def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
Loading

0 comments on commit f6d1a9d

Please sign in to comment.