From ebfcfec3389479ddc70999b04ac415ca3066b829 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 14 Dec 2023 21:44:37 -0600 Subject: [PATCH] remove shark 1.0 tests, add support for 2.0 llm * add support for external weights * add tests and edit deps --- .github/workflows/test-models.yml | 164 -------------- .github/workflows/test-studio.yml | 86 ++++++++ apps/shark_studio/api/llm.py | 163 ++++++++++---- apps/shark_studio/api/utils.py | 4 +- apps/shark_studio/tests/api_test.py | 34 +++ apps/shark_studio/web/index.py | 6 +- apps/shark_studio/web/ui/chat.py | 273 +++--------------------- build_tools/stable_diffusion_testing.py | 28 +-- dataset/annotation_tool.py | 31 +-- process_skipfiles.py | 7 +- pyproject.toml | 19 +- pytest.ini | 2 +- requirements.txt | 11 +- rest_api_tests/api_test.py | 64 ++---- setup.py | 7 - shark/iree_utils/compile_utils.py | 54 ++++- 16 files changed, 377 insertions(+), 576 deletions(-) delete mode 100644 .github/workflows/test-models.yml create mode 100644 .github/workflows/test-studio.yml create mode 100644 apps/shark_studio/tests/api_test.py diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml deleted file mode 100644 index 8e9809ee41..0000000000 --- a/.github/workflows/test-models.yml +++ /dev/null @@ -1,164 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Validate Models on Shark Runtime - -on: - push: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - pull_request: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - workflow_dispatch: - -# Ensure that only a single job or workflow using the same -# concurrency group will run at a time. This would cancel -# any in-progress jobs in the same github workflow and github -# ref (e.g. refs/heads/main or refs/pull//merge). -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - build-validate: - strategy: - fail-fast: true - matrix: - os: [7950x, icelake, a100, MacStudio, ubuntu-latest] - suite: [cpu,cuda,vulkan] - python-version: ["3.11"] - include: - - os: ubuntu-latest - suite: lint - - os: MacStudio - suite: metal - exclude: - - os: ubuntu-latest - suite: vulkan - - os: ubuntu-latest - suite: cuda - - os: ubuntu-latest - suite: cpu - - os: MacStudio - suite: cuda - - os: MacStudio - suite: cpu - - os: MacStudio - suite: vulkan - - os: icelake - suite: vulkan - - os: icelake - suite: cuda - - os: a100 - suite: cpu - - os: 7950x - suite: cpu - - os: 7950x - suite: cuda - - runs-on: ${{ matrix.os }} - - steps: - - uses: actions/checkout@v3 - - - name: Set Environment Variables - if: matrix.os != '7950x' - run: | - echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV - echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - - - name: Set up Python Version File ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - run: | - # See https://github.com/actions/setup-python/issues/433 - echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version - - - name: Set up Python ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - uses: actions/setup-python@v4 - with: - python-version: '${{ matrix.python-version }}' - #cache: 'pip' - #cache-dependency-path: | - # **/requirements-importer.txt - # **/requirements.txt - - - name: Install dependencies - if: matrix.suite == 'lint' - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest toml black - - - name: Lint with flake8 - if: matrix.suite == 'lint' - run: | - # black format check - black --version - black --check . - # stop the build if there are Python syntax errors or undefined names - flake8 . --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ - --statistics --exclude lit.cfg.py - - - name: Validate Models on CPU - if: matrix.suite == 'cpu' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cpu - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv - python build_tools/vicuna_testing.py - - - name: Validate Models on NVIDIA GPU - if: matrix.suite == 'cuda' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cuda - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv - # Disabled due to black image bug - # python build_tools/stable_diffusion_testing.py --device=cuda - - - name: Validate Vulkan Models (MacOS) - if: matrix.suite == 'metal' && matrix.os == 'MacStudio' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} ./setup_venv.sh - source shark.venv/bin/activate - echo $PATH - pip list | grep -E "torch|iree" - # disabled due to a low-visibility memory issue with pytest on macos. - # pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal - - - name: Validate Vulkan Models (a100) - if: matrix.suite == 'vulkan' && matrix.os == 'a100' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --update_tank -k vulkan - python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail - - - name: Validate Vulkan Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - pytest -k vulkan -s --ci - - - name: Validate Stable Diffusion Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - python process_skipfiles.py - pyinstaller .\apps\stable_diffusion\shark_sd.spec - python build_tools/stable_diffusion_testing.py --device=vulkan diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml new file mode 100644 index 0000000000..765a6bf761 --- /dev/null +++ b/.github/workflows/test-studio.yml @@ -0,0 +1,86 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Validate Shark Studio + +on: + push: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + pull_request: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + workflow_dispatch: + +# Ensure that only a single job or workflow using the same +# concurrency group will run at a time. This would cancel +# any in-progress jobs in the same github workflow and github +# ref (e.g. refs/heads/main or refs/pull//merge). +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-validate: + strategy: + fail-fast: true + matrix: + os: [nodai-ubuntu-builder-large] + suite: [cpu] #,cuda,vulkan] + python-version: ["3.11"] + include: + - os: nodai-ubuntu-builder-large + suite: lint + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v3 + + - name: Set Environment Variables + run: | + echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV + echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + + - name: Set up Python Version File ${{ matrix.python-version }} + run: | + echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: '${{ matrix.python-version }}' + + - name: Install dependencies + if: matrix.suite == 'lint' + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest toml black + + - name: Lint with flake8 + if: matrix.suite == 'lint' + run: | + # black format check + black --version + black --check apps/shark_studio + # stop the build if there are Python syntax errors or undefined names + flake8 . --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --exclude lit.cfg.py + + - name: Validate Models on CPU + if: matrix.suite == 'cpu' + run: | + cd $GITHUB_WORKSPACE + python${{ matrix.python-version }} -m venv shark.venv + source shark.venv/bin/activate + pip install -r requirements.txt --no-cache-dir + pip install -e . + pip uninstall -y torch + pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python apps/shark_studio/tests/api_test.py diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 9e92e58cb5..1a03b817ff 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -1,9 +1,16 @@ from turbine_models.custom_models import stateless_llama -from shark.iree_utils.compile_utils import get_iree_compiled_module +import time +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, +) from apps.shark_studio.api.utils import get_resource_path import iree.runtime as ireert +from itertools import chain import gc +import os import torch +from transformers import AutoTokenizer llm_model_map = { "llama2_7b": { @@ -11,81 +18,161 @@ "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "stop_token": 2, "max_tokens": 4096, - } + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, } class LanguageModel: def __init__( - self, model_name, hf_auth_token=None, device=None, precision="fp32" + self, + model_name, + hf_auth_token=None, + device=None, + precision="fp32", + external_weights=None, + use_system_prompt=True, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.torch_ir, self.tokenizer = llm_model_map[model_name][ - "initializer" - ](self.hf_model_name, hf_auth_token, compile_to="torch") self.tempfile_name = get_resource_path("llm.torch.tempfile") - with open(self.tempfile_name, "w+") as f: - f.write(self.torch_ir) - del self.torch_ir - gc.collect() - + self.vmfb_name = get_resource_path("llm.vmfb.tempfile") self.device = device self.precision = precision + self.safe_name = self.hf_model_name.strip("/").replace("/", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None - self.compile() + self.external_weight_file = None + if external_weights is not None: + self.external_weight_file = get_resource_path( + self.safe_name + "." + external_weights + ) + self.use_system_prompt = use_system_prompt + self.global_iter = 0 + if os.path.exists(self.vmfb_name) and ( + external_weights is None or os.path.exists(str(self.external_weight_file)) + ): + self.iree_module_dict = dict() + ( + self.iree_module_dict["vmfb"], + self.iree_module_dict["config"], + self.iree_module_dict["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.vmfb_name, + device, + device_idx=0, + rt_flags=[], + external_weight_file=self.external_weight_file, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + elif not os.path.exists(self.tempfile_name): + self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.hf_model_name, + hf_auth_token, + compile_to="torch", + external_weights=external_weights, + external_weight_file=self.external_weight_file, + ) + with open(self.tempfile_name, "w+") as f: + f.write(self.torch_ir) + del self.torch_ir + gc.collect() + self.compile() + else: + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + self.compile() def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". self.iree_module_dict = get_iree_compiled_module( - self.tempfile_name, device=self.device, frontend="torch" + self.tempfile_name, + device=self.device, + mmap=True, + frontend="torch", + external_weight_file=self.external_weight_file, + write_to=self.vmfb_name, ) # TODO: delete the temp file + def sanitize_prompt(self, prompt): + print(prompt) + if isinstance(prompt, list): + prompt = list(chain.from_iterable(prompt)) + prompt = " ".join([x for x in prompt if isinstance(x, str)]) + prompt = prompt.replace("\n", " ") + prompt = prompt.replace("\t", " ") + prompt = prompt.replace("\r", " ") + if self.use_system_prompt and self.global_iter == 0: + prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt + prompt += " [/INST]" + print(prompt) + return prompt + def chat(self, prompt): + prompt = self.sanitize_prompt(prompt) + + input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids + + def format_out(results): + return torch.tensor(results.to_host()[0][0]) + history = [] for iter in range(self.max_tokens): - input_tensor = self.tokenizer( - prompt, return_tensors="pt" - ).input_ids - device_inputs = [ - ireert.asdevicearray( - self.iree_module_dict["config"], input_tensor - ) - ] + st_time = time.time() if iter == 0: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_initialize"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, input_tensor + ) + ] + token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) else: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_forward"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, + token, + ) + ] + token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) - history.append(token) - yield self.tokenizer.decode(history) + total_time = time.time() - st_time + history.append(format_out(token)) + yield self.tokenizer.decode(history), total_time - if token == llm_model_map["llama2_7b"]["stop_token"]: + if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) result_output = self.tokenizer.decode(history) - yield result_output + self.global_iter += 1 + return result_output, total_time if __name__ == "__main__": lm = LanguageModel( - "llama2_7b", - hf_auth_token="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", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, device="cpu-task", + external_weights="safetensors", ) + print("model loaded") - for i in lm.chat("Hello, I am a robot."): + for i in lm.chat("hi, what are you?"): print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index bb5e150364..4072491cbf 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,7 +8,5 @@ def get_available_devices(): def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py new file mode 100644 index 0000000000..c88a1e70cb --- /dev/null +++ b/apps/shark_studio/tests/api_test.py @@ -0,0 +1,34 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest +from apps.shark_studio.api.llm import LanguageModel + + +class LLMAPITest(unittest.TestCase): + def testLLMSimple(self): + lm = LanguageModel( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + device="cpu-task", + external_weights="safetensors", + ) + count = 0 + for msg, _ in lm.chat("hi, what are you?"): + # skip first token output + if count == 0: + count += 1 + continue + assert ( + msg.strip(" ") == "Hello" + ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + break + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 59b66bee23..3ef6bc5739 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -93,9 +93,7 @@ def launch_app(address): def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") @@ -201,7 +199,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) with gr.Blocks( - css=dark_theme, analytics_enabled=False, title="Stable Diffusion" + css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta" ) as sd_web: with gr.Tabs() as tabs: # NOTE: If adding, removing, or re-ordering tabs, make sure that they diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index dd1c2d94e3..4726eef6e8 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -1,4 +1,5 @@ import gradio as gr +import time import os from pathlib import Path from datetime import datetime as dt @@ -21,104 +22,12 @@ def user(message, history): language_model = None -# NOTE: Each `model_name` should have its own start message -start_message = { - "llama2_7b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_13b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_70b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "vicuna": ( - "A chat between a curious user and an artificial intelligence " - "assistant. The assistant gives helpful, detailed, and " - "polite answers to the user's questions.\n" - ), -} - - def create_prompt(model_name, history, prompt_prefix): return "" - system_message = "" - if prompt_prefix: - system_message = start_message[model_name] - - if "llama2" in model_name: - B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - conversation = "".join( - [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]] - ) - if prompt_prefix: - msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}" - else: - msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}" - elif model_name in ["vicuna"]: - conversation = "".join( - [ - "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) - for item in history - ] - ) - msg = system_message + conversation - msg = msg.strip() - else: - conversation = "".join( - ["".join([item[0], item[1]]) for item in history] - ) - msg = system_message + conversation - msg = msg.strip() - return msg def get_default_config(): return False - import torch - from transformers import AutoTokenizer - - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - from apps.language_models.src.model_wrappers.vicuna_model import ( - CombinedModel, - ) - from shark.shark_generate_model_config import GenerateConfigFile - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - c.split_into_layers() # model_vmfb_key = "" @@ -133,153 +42,37 @@ def chat_fn( download_vmfb, config_file, cli=False, - progress=gr.Progress(), ): global language_model if language_model is None: + history[-1][-1] = "Getting the model ready..." + yield history, "" language_model = LanguageModel( - model, device=device, precision=precision - ) - - language_model.chat(prompt_prefix) - return "", "" - global past_key_values - global model_vmfb_key - - device_id = None - model_name, model_path = list(map(str.strip, model.split("=>"))) - if "cuda" in device: - device = "cuda" - elif "sync" in device: - device = "cpu-sync" - elif "task" in device: - device = "cpu-task" - elif "vulkan" in device: - device_id = int(device.split("://")[1]) - device = "vulkan" - elif "rocm" in device: - device = "rocm" - else: - print("unrecognized device") - - from apps.language_models.scripts.vicuna import ShardedVicuna - from apps.language_models.scripts.vicuna import UnshardedVicuna - from apps.stable_diffusion.src import args - - new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" - if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: - model_vmfb_key = new_model_vmfb_key - max_toks = 128 if model_name == "codegen" else 512 - - # get iree flags that need to be overridden, from commandline args - _extra_args = [] - # vulkan target triple - vulkan_target_triple = args.iree_vulkan_target_triple - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - get_vulkan_target_triple, + model, + device=device, + precision=precision, + external_weights="safetensors", + external_weight_file="llama2_7b.safetensors", + use_system_prompt=prompt_prefix, ) - - if device == "vulkan": - vulkaninfo_list = get_all_vulkan_devices() - if vulkan_target_triple == "": - # We already have the device_id extracted via WebUI, so we directly use - # that to find the target triple. - vulkan_target_triple = get_vulkan_target_triple( - vulkaninfo_list[device_id] - ) - _extra_args.append( - f"-iree-vulkan-target-triple={vulkan_target_triple}" - ) - if "rdna" in vulkan_target_triple: - flags_to_add = [ - "--iree-spirv-index-bits=64", - ] - _extra_args = _extra_args + flags_to_add - - if device_id is None: - id = 0 - for device in vulkaninfo_list: - target_triple = get_vulkan_target_triple( - vulkaninfo_list[id] - ) - if target_triple == vulkan_target_triple: - device_id = id - break - id += 1 - - assert ( - device_id - ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" - print(f"Will use vulkan target triple : {vulkan_target_triple}") - - elif "rocm" in device: - # add iree rocm flags - _extra_args.append( - f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" - ) - print(f"extra args = {_extra_args}") - - if model_name == "vicuna4": - vicuna_model = ShardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - compressed=True, - extra_args_cmd=_extra_args, - ) - else: - # if config_file is None: - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - hf_auth_token=args.hf_auth_token, - device=device, - vulkan_target_triple=vulkan_target_triple, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=download_vmfb, - load_mlir_from_shark_tank=True, - extra_args_cmd=_extra_args, - device_id=device_id, - ) - - if vicuna_model is None: - sys.exit("Unable to instantiate the model object, exiting.") - - prompt = create_prompt(model_name, history, prompt_prefix) - - partial_text = "" + history[-1][-1] = "Getting the model ready... Done" + yield history, "" + history[-1][-1] = "" token_count = 0 - total_time_ms = 0.001 # In order to avoid divide by zero error + total_time = 0.001 # In order to avoid divide by zero error prefill_time = 0 is_first = True - for text, msg, exec_time in progress.tqdm( - vicuna_model.generate(prompt, cli=cli), - desc="generating response", - ): - if msg is None: - if is_first: - prefill_time = exec_time - is_first = False - else: - total_time_ms += exec_time - token_count += 1 - partial_text += text + " " - history[-1][1] = partial_text + for text, exec_time in language_model.chat(history): + history[-1][-1] = text + if is_first: + prefill_time = exec_time + is_first = False yield history, f"Prefill: {prefill_time:.2f}" - elif "formatted" in msg: - history[-1][1] = text - tokens_per_sec = (token_count / total_time_ms) * 1000 - yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" else: - sys.exit( - "unexpected message from the vicuna generate call, exiting." - ) - - return history, "" + total_time += exec_time + token_count += 1 + tokens_per_sec = token_count / total_time + yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" def llm_chat_api(InputData: dict): @@ -297,17 +90,11 @@ def llm_chat_api(InputData: dict): # print(f"prompt : {InputData['prompt']}") # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) + model_name = InputData["model"] if "model" in InputData.keys() else "codegen" model_path = llm_model_map[model_name] device = "cpu-task" precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) + max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] if max_toks is None: max_toks = 128 if model_name == "codegen" else 512 @@ -344,9 +131,7 @@ def llm_chat_api(InputData: dict): # TODO: add role dict for different models if is_chat_completion_api: # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) + prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) else: prompt = InputData["prompt"] print("prompt = ", prompt) @@ -379,9 +164,7 @@ def llm_chat_api(InputData: dict): end_time = dt.now().strftime("%Y%m%d%H%M%S%f") return { "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", + "object": "chat.completion" if is_chat_completion_api else "text_completion", "created": int(end_time), "choices": choices, } @@ -457,9 +240,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button(label="View as JSON", visible=False) json_view = gr.JSON(interactive=True, visible=False) json_view_button.click( diff --git a/build_tools/stable_diffusion_testing.py b/build_tools/stable_diffusion_testing.py index ced919732c..8eeb1a7395 100644 --- a/build_tools/stable_diffusion_testing.py +++ b/build_tools/stable_diffusion_testing.py @@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir): metrics[val] = line.split(" ")[-1].strip("\n") metrics["Average step"] = metrics["Average step"].strip("ms/it") - metrics["Total image generation"] = metrics[ - "Total image generation" - ].strip("sec") + metrics["Total image generation"] = metrics["Total image generation"].strip("sec") metrics["device"] = device metrics["use_tune"] = use_tune metrics["model_name"] = model_name @@ -84,10 +82,14 @@ def test_loop( ] import_options = ["--import_mlir", "--no-import_mlir"] prompt_text = "--prompt=cyberpunk forest by Salvador Dali" - inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + inpaint_prompt_text = ( + "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + ) if os.name == "nt": prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' - inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + inpaint_prompt_text = ( + '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + ) if beta: extra_flags.append("--beta_models=True") extra_flags.append("--no-progress_bar") @@ -174,9 +176,7 @@ def test_loop( ) print(command) print("Successfully generated image") - os.makedirs( - "./test_images/golden/" + model_name, exist_ok=True - ) + os.makedirs("./test_images/golden/" + model_name, exist_ok=True) download_public_file( "gs://shark_tank/testdata/golden/" + model_name, "./test_images/golden/" + model_name, @@ -191,14 +191,10 @@ def test_loop( ) test_file = glob(test_file_path)[0] - golden_path = ( - "./test_images/golden/" + model_name + "/*.png" - ) + golden_path = "./test_images/golden/" + model_name + "/*.png" golden_file = glob(golden_path)[0] try: - compare_images( - test_file, golden_file, upload=upload_bool - ) + compare_images(test_file, golden_file, upload=upload_bool) except AssertionError as e: print(e) if exit_on_fail == True: @@ -267,9 +263,7 @@ def prepare_artifacts(): parser.add_argument( "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True ) -parser.add_argument( - "-g", "--gen", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False) if __name__ == "__main__": args = parser.parse_args() diff --git a/dataset/annotation_tool.py b/dataset/annotation_tool.py index 8c8c85cdfd..60f607146d 100644 --- a/dataset/annotation_tool.py +++ b/dataset/annotation_tool.py @@ -10,9 +10,7 @@ shark_root = Path(__file__).parent.parent demo_css = shark_root.joinpath("web/demo.css").resolve() -nodlogo_loc = shark_root.joinpath( - "web/models/stable_diffusion/logos/nod-logo.png" -) +nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png") with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web: @@ -76,9 +74,7 @@ def filter_datasets(dataset): with jsonlines.open(dataset_path + "/metadata.jsonl") as reader: for line in reader.iter(type=dict, skip_invalid=True): prompt_data[line["file_name"]] = ( - [line["text"]] - if type(line["text"]) is str - else line["text"] + [line["text"]] if type(line["text"]) is str else line["text"] ) return gr.Dropdown.update(choices=images[dataset]) @@ -104,9 +100,7 @@ def display_image(dataset, image_name): prompt_data[image_name] = [] prompt_choices = ["Add new"] prompt_choices += prompt_data[image_name] - return gr.Image.update(value=img), gr.Dropdown.update( - choices=prompt_choices - ) + return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices) image_name.change( fn=display_image, @@ -123,12 +117,7 @@ def edit_prompt(prompts): prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt) def save_prompt(dataset, image_name, prompts, prompt): - if ( - dataset is None - or image_name is None - or prompts is None - or prompt is None - ): + if dataset is None or image_name is None or prompts is None or prompt is None: return if prompts == "Add new": @@ -137,9 +126,7 @@ def save_prompt(dataset, image_name, prompts, prompt): idx = prompt_data[image_name].index(prompts) prompt_data[image_name][idx] = prompt - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -166,9 +153,7 @@ def delete_prompt(dataset, image_name, prompts): return prompt_data[image_name].remove(prompts) - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -231,9 +216,7 @@ def finish_annotation(dataset): # upload prompt and remove local data dataset_path = str(shark_root) + "/dataset/" + dataset dataset_gs_path = args.gs_url + "/" + dataset + "/" - os.system( - f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"' - ) + os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"') os.system(f'rm -rf "{dataset_path}"') return gr.Dropdown.update(value=None) diff --git a/process_skipfiles.py b/process_skipfiles.py index a846159451..339c7ebec6 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -8,8 +8,7 @@ # Temporary workaround for transformers/__init__.py. path_to_transformers_hook = Path( - get_python_lib() - + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" + get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" ) if path_to_transformers_hook.is_file(): pass @@ -59,9 +58,7 @@ # For getting around timm's packaging. # Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 -path_to_timm_activations = Path( - get_python_lib() + "/timm/layers/activations_jit.py" -) +path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py") for line in fileinput.input(path_to_timm_activations, inplace=True): if "@torch.jit.script" in line: print("@torch.jit._script_if_tracing", end="\n") diff --git a/pyproject.toml b/pyproject.toml index 22e0210c50..876df2f8bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,25 @@ requires = [ "packaging", "numpy>=1.22.4", - "torch-mlir>=20230620.875", "iree-compiler>=20221022.190", "iree-runtime>=20221022.190", ] build-backend = "setuptools.build_meta" [tool.black] -line-length = 79 include = '\.pyi?$' -exclude = "apps/language_models/scripts/vicuna.py" -extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py" +exclude = ''' +( + /( + | apps/stable_diffusion + | apps/language_models + | shark + | benchmarks + | tank + | build + | generated_imgs + | shark.venv + )/ + | setup.py +) +''' diff --git a/pytest.ini b/pytest.ini index 11f57888b2..3857248785 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] addopts = --verbose -s -p no:warnings -norecursedirs = inference tank/tflite examples benchmarks shark +norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio diff --git a/requirements.txt b/requirements.txt index ff649a4468..3f7e719e67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,13 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://openxla.github.io/iree/pip-release-links.html --pre setuptools wheel +shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=python/turbine_models + # SHARK Runner tqdm @@ -17,11 +21,7 @@ pytest-forked Pillow parameterized -#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main # Add transformers, diffusers and scipy since it most commonly used -tokenizers==0.13.3 -transformers -diffusers #accelerate is now required for diffusers import from ckpt. accelerate scipy @@ -49,9 +49,6 @@ pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions pefile pyinstaller -# vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea - # For quantized GPTQ models optimum auto_gptq diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 7a4cf042c2..f3c0b0e170 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -44,14 +44,10 @@ def upscaler_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[upscaler] response from server was : {res.status_code} {res.reason}" - ) + print(f"[upscaler] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def img2img_test(verbose=False): @@ -96,14 +92,10 @@ def img2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[img2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[img2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") # NOTE Uncomment below to save the picture @@ -133,13 +125,9 @@ def inpainting_test(verbose=False): image_path = r"./rest_api_tests/dog.png" img_file = open(image_path, "rb") - image = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + image = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() img_file = open(image_path, "rb") - mask = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + mask = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() url = "http://127.0.0.1:8080/sdapi/v1/inpaint" @@ -166,14 +154,10 @@ def inpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[inpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[inpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def outpainting_test(verbose=False): @@ -223,14 +207,10 @@ def outpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[outpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[outpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def txt2img_test(verbose=False): @@ -262,14 +242,10 @@ def txt2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[txt2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[txt2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def sd_models_test(verbose=False): @@ -283,9 +259,7 @@ def sd_models_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_models] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_models] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -302,9 +276,7 @@ def sd_samplers_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_samplers] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_samplers] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -321,9 +293,7 @@ def options_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[options] response from server was : {res.status_code} {res.reason}" - ) + print(f"[options] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -340,9 +310,7 @@ def cmd_flags_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[cmd-flags] response from server was : {res.status_code} {res.reason}" - ) + print(f"[cmd-flags] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") diff --git a/setup.py b/setup.py index c387fe9add..061873e7a8 100644 --- a/setup.py +++ b/setup.py @@ -9,11 +9,6 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5" backend_deps = [] -if "NO_BACKEND" in os.environ.keys(): - backend_deps = [ - "iree-compiler>=20221022.190", - "iree-runtime>=20221022.190", - ] setup( name="nodai-SHARK", @@ -39,7 +34,5 @@ install_requires=[ "numpy", "PyYAML", - "torch-mlir", ] - + backend_deps, ) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 6cfe369426..bae1908e1c 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -305,6 +305,7 @@ def compile_module_to_flatbuffer( model_name="None", debug=False, compile_str=False, + write_to=None, ): # Setup Compile arguments wrt to frontends. input_type = "auto" @@ -342,12 +343,24 @@ def compile_module_to_flatbuffer( extra_args=args, ) + if write_to is not None: + with open(write_to, "wb") as f: + f.write(flatbuffer_blob) + return None + return flatbuffer_blob def get_iree_module( - flatbuffer_blob, device, device_idx=None, rt_flags: list = [] + flatbuffer_blob, + device, + device_idx=None, + rt_flags: list = [], + external_weight_file=None, ): + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) # Returns the compiled module and the configs. for flag in rt_flags: ireert.flags.parse_flag(flag) @@ -369,7 +382,10 @@ def get_iree_module( vm_module = ireert.VmModule.from_buffer( config.vm_instance, flatbuffer_blob, warn_if_copy=False ) - ctx = ireert.SystemContext(config=config) + modules = [] + if external_weight_file is not None: + modules.append(index.create_provider(scope="model")) + ctx = ireert.SystemContext(vm_modules=modules, config=config) ctx.add_vm_module(vm_module) ModuleCompiled = getattr(ctx.modules, vm_module.name) return ModuleCompiled, config @@ -380,6 +396,7 @@ def load_vmfb_using_mmap( device: str, device_idx: int = None, rt_flags: list = [], + external_weight_file: str = None, ): print(f"Loading module {flatbuffer_blob_or_path}...") if "task" in device: @@ -440,17 +457,28 @@ def load_vmfb_using_mmap( mmaped_vmfb = ireert.VmModule.mmap( config.vm_instance, flatbuffer_blob_or_path ) + vm_modules = [] + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.append(param_module) + vm_modules.append(mmaped_vmfb) + vm_modules.append( + ireert.create_hal_module(config.vm_instance, config.device) + ) dl.log(f"mmap {flatbuffer_blob_or_path}") - ctx = ireert.SystemContext(config=config) - for flag in shark_args.additional_runtime_args: - ireert.flags.parse_flags(flag) - dl.log(f"ireert.SystemContext created") if "vulkan" in device: # Vulkan pipeline creation consumes significant amount of time. print( "\tCompiling Vulkan shaders. This may take a few minutes." ) - ctx.add_vm_module(mmaped_vmfb) + ctx = ireert.SystemContext(config=config, vm_modules=vm_modules) + dl.log(f"ireert.SystemContext created") + for flag in shark_args.additional_runtime_args: + ireert.flags.parse_flags(flag) dl.log(f"module initialized") mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) else: @@ -475,6 +503,8 @@ def get_iree_compiled_module( mmap: bool = False, debug: bool = False, compile_str: bool = False, + external_weight_file: str = None, + write_to: bool = None, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( @@ -485,6 +515,7 @@ def get_iree_compiled_module( extra_args=extra_args, debug=debug, compile_str=compile_str, + write_to=write_to, ) temp_file_to_unlink = None # TODO: Currently mmap=True control flow path has been switched off for mmap. @@ -492,8 +523,14 @@ def get_iree_compiled_module( # we're setting delete=False when creating NamedTemporaryFile. That's why # I'm getting hold of the name of the temporary file in `temp_file_to_unlink`. if mmap: + if write_to is not None: + flatbuffer_blob = write_to vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( - flatbuffer_blob, device, device_idx, rt_flags + flatbuffer_blob, + device, + device_idx, + rt_flags, + external_weight_file=external_weight_file, ) else: vmfb, config = get_iree_module( @@ -501,6 +538,7 @@ def get_iree_compiled_module( device, device_idx=device_idx, rt_flags=rt_flags, + external_weight_file=external_weight_file, ) ret_params = { "vmfb": vmfb,