Skip to content

Commit

Permalink
add tests and edit deps
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Dec 12, 2023
1 parent ba77cf8 commit 0a89c63
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 10 deletions.
84 changes: 84 additions & 0 deletions .github/workflows/test-studio.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Validate Shark Studio

on:
push:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
pull_request:
branches: [ main ]
paths-ignore:
- '**.md'
- 'shark/examples/**'
workflow_dispatch:

# Ensure that only a single job or workflow using the same
# concurrency group will run at a time. This would cancel
# any in-progress jobs in the same github workflow and github
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
build-validate:
strategy:
fail-fast: true
matrix:
os: [nodai-ubuntu-builder-large]
suite: [cpu] #,cuda,vulkan]
python-version: ["3.11"]
include:
- os: nodai-ubuntu-builder-large
suite: lint

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v3

- name: Set Environment Variables
run: |
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
- name: Set up Python Version File ${{ matrix.python-version }}
run: |
echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: '${{ matrix.python-version }}'

- name: Install dependencies
if: matrix.suite == 'lint'
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml black
- name: Lint with flake8
if: matrix.suite == 'lint'
run: |
# black format check
black --version
black --check apps/shark_studio
# stop the build if there are Python syntax errors or undefined names
flake8 . --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
--statistics --exclude lit.cfg.py
- name: Validate Models on CPU
if: matrix.suite == 'cpu'
run: |
cd $GITHUB_WORKSPACE
python${{ matrix.python-version }} -m venv shark.venv
shark.venv/bin/activate
pip install -r requirements.txt
pip install -e .
python apps/shark_studio/tests/api_tests.py
34 changes: 34 additions & 0 deletions apps/shark_studio/tests/api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import unittest
from apps.shark_studio.api.llm import LanguageModel


class LLMAPITest(unittest.TestCase):
def testLLMSimple(self):
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
)
count = 0
for msg, _ in lm.chat("hi, what are you?"):
# skip first token output
if count == 0:
count+=1
continue
assert msg.strip(" ") == "Hello", f"LLM API failed to return correct response, expected 'Hello', received {msg}"
break




if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
16 changes: 13 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ requires = [
"packaging",

"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]
Expand All @@ -14,5 +13,16 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79
include = '\.pyi?$'
exclude = "apps/language_models/scripts/vicuna.py"
extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py"
exclude = '''
(
/(
| apps/shark_studio
| apps/language_models/scripts/vicuna.py
| apps/language_models/src/pipelines/minigpt4_pipeline.py
| build
| generated_imgs
| shark.venv
)/
| setup.py
)
'''
7 changes: 0 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@

PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]

setup(
name="nodai-SHARK",
Expand All @@ -39,7 +34,5 @@
install_requires=[
"numpy",
"PyYAML",
"torch-mlir",
]
+ backend_deps,
)

0 comments on commit 0a89c63

Please sign in to comment.