diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml new file mode 100644 index 0000000000..e12795b2c5 --- /dev/null +++ b/.github/workflows/test-studio.yml @@ -0,0 +1,86 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Validate Shark Studio + +on: + push: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + pull_request: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + workflow_dispatch: + +# Ensure that only a single job or workflow using the same +# concurrency group will run at a time. This would cancel +# any in-progress jobs in the same github workflow and github +# ref (e.g. refs/heads/main or refs/pull//merge). +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-validate: + strategy: + fail-fast: true + matrix: + os: [nodai-ubuntu-builder-large] + suite: [cpu] #,cuda,vulkan] + python-version: ["3.11"] + include: + - os: nodai-ubuntu-builder-large + suite: lint + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v3 + + - name: Set Environment Variables + run: | + echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV + echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + + - name: Set up Python Version File ${{ matrix.python-version }} + run: | + echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: '${{ matrix.python-version }}' + + - name: Install dependencies + if: matrix.suite == 'lint' + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest toml black + + - name: Lint with flake8 + if: matrix.suite == 'lint' + run: | + # black format check + black --version + black --check apps/shark_studio + # stop the build if there are Python syntax errors or undefined names + flake8 . --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --exclude lit.cfg.py + + - name: Validate Models on CPU + if: matrix.suite == 'cpu' + run: | + cd $GITHUB_WORKSPACE + python${{ matrix.python-version }} -m venv shark.venv + source shark.venv/bin/activate + pip install -r requirements.txt + pip install -e . + pip uninstall -y torch + pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python apps/shark_studio/tests/api_test.py diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index dc4b90c872..dcbd938efe 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -45,7 +45,9 @@ def __init__( self.external_weight_file = external_weight_file self.use_system_prompt = use_system_prompt self.global_iter = 0 - if os.path.exists(self.vmfb_name): + if os.path.exists(self.vmfb_name) and ( + os.path.exists(self.external_weight_file) or external_weights is None + ): self.iree_module_dict = dict() ( self.iree_module_dict["vmfb"], @@ -64,9 +66,7 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[model_name][ - "initializer" - ]( + self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( self.hf_model_name, hf_auth_token, compile_to="torch", @@ -129,9 +129,7 @@ def format_out(results): self.iree_module_dict["config"].device, input_tensor ) ] - token = self.iree_module_dict["vmfb"]["run_initialize"]( - *device_inputs - ) + token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) else: device_inputs = [ ireert.asdevicearray( @@ -139,9 +137,7 @@ def format_out(results): token, ) ] - token = self.iree_module_dict["vmfb"]["run_forward"]( - *device_inputs - ) + token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) total_time = time.time() - st_time history.append(format_out(token)) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index bb5e150364..4072491cbf 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,7 +8,5 @@ def get_available_devices(): def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py new file mode 100644 index 0000000000..c88a1e70cb --- /dev/null +++ b/apps/shark_studio/tests/api_test.py @@ -0,0 +1,34 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest +from apps.shark_studio.api.llm import LanguageModel + + +class LLMAPITest(unittest.TestCase): + def testLLMSimple(self): + lm = LanguageModel( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + device="cpu-task", + external_weights="safetensors", + ) + count = 0 + for msg, _ in lm.chat("hi, what are you?"): + # skip first token output + if count == 0: + count += 1 + continue + assert ( + msg.strip(" ") == "Hello" + ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + break + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 28c1573bd9..3ef6bc5739 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -93,9 +93,7 @@ def launch_app(address): def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 31186f4d4d..fc6a2f2cd0 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -48,9 +48,7 @@ def create_prompt(model_name, history, prompt_prefix): msg = system_message + conversation msg = msg.strip() else: - conversation = "".join( - ["".join([item[0], item[1]]) for item in history] - ) + conversation = "".join(["".join([item[0], item[1]]) for item in history]) msg = system_message + conversation msg = msg.strip() return msg @@ -68,9 +66,7 @@ def get_default_config(): compilation_prompt, return_tensors="pt", ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) + compilation_input_ids = torch.tensor(compilation_input_ids).reshape([1, 19]) firstVicunaCompileInput = (compilation_input_ids,) from apps.language_models.src.model_wrappers.vicuna_model import ( CombinedModel, @@ -142,17 +138,11 @@ def llm_chat_api(InputData: dict): # print(f"prompt : {InputData['prompt']}") # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) + model_name = InputData["model"] if "model" in InputData.keys() else "codegen" model_path = llm_model_map[model_name] device = "cpu-task" precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) + max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] if max_toks is None: max_toks = 128 if model_name == "codegen" else 512 @@ -189,9 +179,7 @@ def llm_chat_api(InputData: dict): # TODO: add role dict for different models if is_chat_completion_api: # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) + prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) else: prompt = InputData["prompt"] print("prompt = ", prompt) @@ -224,9 +212,7 @@ def llm_chat_api(InputData: dict): end_time = dt.now().strftime("%Y%m%d%H%M%S%f") return { "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", + "object": "chat.completion" if is_chat_completion_api else "text_completion", "created": int(end_time), "choices": choices, } @@ -302,9 +288,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button(label="View as JSON", visible=False) json_view = gr.JSON(interactive=True, visible=False) json_view_button.click( diff --git a/build_tools/stable_diffusion_testing.py b/build_tools/stable_diffusion_testing.py index ced919732c..8eeb1a7395 100644 --- a/build_tools/stable_diffusion_testing.py +++ b/build_tools/stable_diffusion_testing.py @@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir): metrics[val] = line.split(" ")[-1].strip("\n") metrics["Average step"] = metrics["Average step"].strip("ms/it") - metrics["Total image generation"] = metrics[ - "Total image generation" - ].strip("sec") + metrics["Total image generation"] = metrics["Total image generation"].strip("sec") metrics["device"] = device metrics["use_tune"] = use_tune metrics["model_name"] = model_name @@ -84,10 +82,14 @@ def test_loop( ] import_options = ["--import_mlir", "--no-import_mlir"] prompt_text = "--prompt=cyberpunk forest by Salvador Dali" - inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + inpaint_prompt_text = ( + "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + ) if os.name == "nt": prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' - inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + inpaint_prompt_text = ( + '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + ) if beta: extra_flags.append("--beta_models=True") extra_flags.append("--no-progress_bar") @@ -174,9 +176,7 @@ def test_loop( ) print(command) print("Successfully generated image") - os.makedirs( - "./test_images/golden/" + model_name, exist_ok=True - ) + os.makedirs("./test_images/golden/" + model_name, exist_ok=True) download_public_file( "gs://shark_tank/testdata/golden/" + model_name, "./test_images/golden/" + model_name, @@ -191,14 +191,10 @@ def test_loop( ) test_file = glob(test_file_path)[0] - golden_path = ( - "./test_images/golden/" + model_name + "/*.png" - ) + golden_path = "./test_images/golden/" + model_name + "/*.png" golden_file = glob(golden_path)[0] try: - compare_images( - test_file, golden_file, upload=upload_bool - ) + compare_images(test_file, golden_file, upload=upload_bool) except AssertionError as e: print(e) if exit_on_fail == True: @@ -267,9 +263,7 @@ def prepare_artifacts(): parser.add_argument( "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True ) -parser.add_argument( - "-g", "--gen", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False) if __name__ == "__main__": args = parser.parse_args() diff --git a/dataset/annotation_tool.py b/dataset/annotation_tool.py index edd088229f..6c6e270978 100644 --- a/dataset/annotation_tool.py +++ b/dataset/annotation_tool.py @@ -10,9 +10,7 @@ shark_root = Path(__file__).parent.parent demo_css = shark_root.joinpath("web/demo.css").resolve() -nodlogo_loc = shark_root.joinpath( - "web/models/stable_diffusion/logos/nod-logo.png" -) +nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png") with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web: @@ -75,9 +73,7 @@ def filter_datasets(dataset): with jsonlines.open(dataset_path + "/metadata.jsonl") as reader: for line in reader.iter(type=dict, skip_invalid=True): prompt_data[line["file_name"]] = ( - [line["text"]] - if type(line["text"]) is str - else line["text"] + [line["text"]] if type(line["text"]) is str else line["text"] ) return gr.Dropdown.update(choices=images[dataset]) @@ -103,9 +99,7 @@ def display_image(dataset, image_name): prompt_data[image_name] = [] prompt_choices = ["Add new"] prompt_choices += prompt_data[image_name] - return gr.Image.update(value=img), gr.Dropdown.update( - choices=prompt_choices - ) + return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices) image_name.change( fn=display_image, @@ -122,12 +116,7 @@ def edit_prompt(prompts): prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt) def save_prompt(dataset, image_name, prompts, prompt): - if ( - dataset is None - or image_name is None - or prompts is None - or prompt is None - ): + if dataset is None or image_name is None or prompts is None or prompt is None: return if prompts == "Add new": @@ -136,9 +125,7 @@ def save_prompt(dataset, image_name, prompts, prompt): idx = prompt_data[image_name].index(prompts) prompt_data[image_name][idx] = prompt - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -165,9 +152,7 @@ def delete_prompt(dataset, image_name, prompts): return prompt_data[image_name].remove(prompts) - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -230,9 +215,7 @@ def finish_annotation(dataset): # upload prompt and remove local data dataset_path = str(shark_root) + "/dataset/" + dataset dataset_gs_path = args.gs_url + "/" + dataset + "/" - os.system( - f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"' - ) + os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"') os.system(f'rm -rf "{dataset_path}"') return gr.Dropdown.update(value=None) diff --git a/process_skipfiles.py b/process_skipfiles.py index a846159451..339c7ebec6 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -8,8 +8,7 @@ # Temporary workaround for transformers/__init__.py. path_to_transformers_hook = Path( - get_python_lib() - + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" + get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" ) if path_to_transformers_hook.is_file(): pass @@ -59,9 +58,7 @@ # For getting around timm's packaging. # Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 -path_to_timm_activations = Path( - get_python_lib() + "/timm/layers/activations_jit.py" -) +path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py") for line in fileinput.input(path_to_timm_activations, inplace=True): if "@torch.jit.script" in line: print("@torch.jit._script_if_tracing", end="\n") diff --git a/pyproject.toml b/pyproject.toml index 22e0210c50..876df2f8bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,25 @@ requires = [ "packaging", "numpy>=1.22.4", - "torch-mlir>=20230620.875", "iree-compiler>=20221022.190", "iree-runtime>=20221022.190", ] build-backend = "setuptools.build_meta" [tool.black] -line-length = 79 include = '\.pyi?$' -exclude = "apps/language_models/scripts/vicuna.py" -extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py" +exclude = ''' +( + /( + | apps/stable_diffusion + | apps/language_models + | shark + | benchmarks + | tank + | build + | generated_imgs + | shark.venv + )/ + | setup.py +) +''' diff --git a/pytest.ini b/pytest.ini index 11f57888b2..3857248785 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] addopts = --verbose -s -p no:warnings -norecursedirs = inference tank/tflite examples benchmarks shark +norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 7a4cf042c2..f3c0b0e170 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -44,14 +44,10 @@ def upscaler_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[upscaler] response from server was : {res.status_code} {res.reason}" - ) + print(f"[upscaler] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def img2img_test(verbose=False): @@ -96,14 +92,10 @@ def img2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[img2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[img2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") # NOTE Uncomment below to save the picture @@ -133,13 +125,9 @@ def inpainting_test(verbose=False): image_path = r"./rest_api_tests/dog.png" img_file = open(image_path, "rb") - image = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + image = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() img_file = open(image_path, "rb") - mask = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + mask = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() url = "http://127.0.0.1:8080/sdapi/v1/inpaint" @@ -166,14 +154,10 @@ def inpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[inpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[inpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def outpainting_test(verbose=False): @@ -223,14 +207,10 @@ def outpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[outpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[outpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def txt2img_test(verbose=False): @@ -262,14 +242,10 @@ def txt2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[txt2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[txt2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def sd_models_test(verbose=False): @@ -283,9 +259,7 @@ def sd_models_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_models] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_models] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -302,9 +276,7 @@ def sd_samplers_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_samplers] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_samplers] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -321,9 +293,7 @@ def options_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[options] response from server was : {res.status_code} {res.reason}" - ) + print(f"[options] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -340,9 +310,7 @@ def cmd_flags_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[cmd-flags] response from server was : {res.status_code} {res.reason}" - ) + print(f"[cmd-flags] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") diff --git a/setup.py b/setup.py index c387fe9add..061873e7a8 100644 --- a/setup.py +++ b/setup.py @@ -9,11 +9,6 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5" backend_deps = [] -if "NO_BACKEND" in os.environ.keys(): - backend_deps = [ - "iree-compiler>=20221022.190", - "iree-runtime>=20221022.190", - ] setup( name="nodai-SHARK", @@ -39,7 +34,5 @@ install_requires=[ "numpy", "PyYAML", - "torch-mlir", ] - + backend_deps, )