diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml deleted file mode 100644 index 8e9809ee41..0000000000 --- a/.github/workflows/test-models.yml +++ /dev/null @@ -1,164 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Validate Models on Shark Runtime - -on: - push: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - pull_request: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - workflow_dispatch: - -# Ensure that only a single job or workflow using the same -# concurrency group will run at a time. This would cancel -# any in-progress jobs in the same github workflow and github -# ref (e.g. refs/heads/main or refs/pull//merge). -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - build-validate: - strategy: - fail-fast: true - matrix: - os: [7950x, icelake, a100, MacStudio, ubuntu-latest] - suite: [cpu,cuda,vulkan] - python-version: ["3.11"] - include: - - os: ubuntu-latest - suite: lint - - os: MacStudio - suite: metal - exclude: - - os: ubuntu-latest - suite: vulkan - - os: ubuntu-latest - suite: cuda - - os: ubuntu-latest - suite: cpu - - os: MacStudio - suite: cuda - - os: MacStudio - suite: cpu - - os: MacStudio - suite: vulkan - - os: icelake - suite: vulkan - - os: icelake - suite: cuda - - os: a100 - suite: cpu - - os: 7950x - suite: cpu - - os: 7950x - suite: cuda - - runs-on: ${{ matrix.os }} - - steps: - - uses: actions/checkout@v3 - - - name: Set Environment Variables - if: matrix.os != '7950x' - run: | - echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV - echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - - - name: Set up Python Version File ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - run: | - # See https://github.com/actions/setup-python/issues/433 - echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version - - - name: Set up Python ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - uses: actions/setup-python@v4 - with: - python-version: '${{ matrix.python-version }}' - #cache: 'pip' - #cache-dependency-path: | - # **/requirements-importer.txt - # **/requirements.txt - - - name: Install dependencies - if: matrix.suite == 'lint' - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest toml black - - - name: Lint with flake8 - if: matrix.suite == 'lint' - run: | - # black format check - black --version - black --check . - # stop the build if there are Python syntax errors or undefined names - flake8 . --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ - --statistics --exclude lit.cfg.py - - - name: Validate Models on CPU - if: matrix.suite == 'cpu' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cpu - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv - python build_tools/vicuna_testing.py - - - name: Validate Models on NVIDIA GPU - if: matrix.suite == 'cuda' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cuda - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv - # Disabled due to black image bug - # python build_tools/stable_diffusion_testing.py --device=cuda - - - name: Validate Vulkan Models (MacOS) - if: matrix.suite == 'metal' && matrix.os == 'MacStudio' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} ./setup_venv.sh - source shark.venv/bin/activate - echo $PATH - pip list | grep -E "torch|iree" - # disabled due to a low-visibility memory issue with pytest on macos. - # pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal - - - name: Validate Vulkan Models (a100) - if: matrix.suite == 'vulkan' && matrix.os == 'a100' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --update_tank -k vulkan - python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail - - - name: Validate Vulkan Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - pytest -k vulkan -s --ci - - - name: Validate Stable Diffusion Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - python process_skipfiles.py - pyinstaller .\apps\stable_diffusion\shark_sd.spec - python build_tools/stable_diffusion_testing.py --device=vulkan diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml new file mode 100644 index 0000000000..765a6bf761 --- /dev/null +++ b/.github/workflows/test-studio.yml @@ -0,0 +1,86 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Validate Shark Studio + +on: + push: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + pull_request: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + workflow_dispatch: + +# Ensure that only a single job or workflow using the same +# concurrency group will run at a time. This would cancel +# any in-progress jobs in the same github workflow and github +# ref (e.g. refs/heads/main or refs/pull//merge). 
+concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-validate: + strategy: + fail-fast: true + matrix: + os: [nodai-ubuntu-builder-large] + suite: [cpu] #,cuda,vulkan] + python-version: ["3.11"] + include: + - os: nodai-ubuntu-builder-large + suite: lint + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v3 + + - name: Set Environment Variables + run: | + echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV + echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + + - name: Set up Python Version File ${{ matrix.python-version }} + run: | + echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: '${{ matrix.python-version }}' + + - name: Install dependencies + if: matrix.suite == 'lint' + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest toml black + + - name: Lint with flake8 + if: matrix.suite == 'lint' + run: | + # black format check + black --version + black --check apps/shark_studio + # stop the build if there are Python syntax errors or undefined names + flake8 . --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --exclude lit.cfg.py + + - name: Validate Models on CPU + if: matrix.suite == 'cpu' + run: | + cd $GITHUB_WORKSPACE + python${{ matrix.python-version }} -m venv shark.venv + source shark.venv/bin/activate + pip install -r requirements.txt --no-cache-dir + pip install -e . + pip uninstall -y torch + pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python apps/shark_studio/tests/api_test.py diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 9e92e58cb5..1a03b817ff 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -1,9 +1,16 @@ from turbine_models.custom_models import stateless_llama -from shark.iree_utils.compile_utils import get_iree_compiled_module +import time +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, +) from apps.shark_studio.api.utils import get_resource_path import iree.runtime as ireert +from itertools import chain import gc +import os import torch +from transformers import AutoTokenizer llm_model_map = { "llama2_7b": { @@ -11,81 +18,161 @@ "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "stop_token": 2, "max_tokens": 4096, - } + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. 
<>""", + }, } class LanguageModel: def __init__( - self, model_name, hf_auth_token=None, device=None, precision="fp32" + self, + model_name, + hf_auth_token=None, + device=None, + precision="fp32", + external_weights=None, + use_system_prompt=True, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.torch_ir, self.tokenizer = llm_model_map[model_name][ - "initializer" - ](self.hf_model_name, hf_auth_token, compile_to="torch") self.tempfile_name = get_resource_path("llm.torch.tempfile") - with open(self.tempfile_name, "w+") as f: - f.write(self.torch_ir) - del self.torch_ir - gc.collect() - + self.vmfb_name = get_resource_path("llm.vmfb.tempfile") self.device = device self.precision = precision + self.safe_name = self.hf_model_name.strip("/").replace("/", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None - self.compile() + self.external_weight_file = None + if external_weights is not None: + self.external_weight_file = get_resource_path( + self.safe_name + "." + external_weights + ) + self.use_system_prompt = use_system_prompt + self.global_iter = 0 + if os.path.exists(self.vmfb_name) and ( + external_weights is None or os.path.exists(str(self.external_weight_file)) + ): + self.iree_module_dict = dict() + ( + self.iree_module_dict["vmfb"], + self.iree_module_dict["config"], + self.iree_module_dict["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.vmfb_name, + device, + device_idx=0, + rt_flags=[], + external_weight_file=self.external_weight_file, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + elif not os.path.exists(self.tempfile_name): + self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.hf_model_name, + hf_auth_token, + compile_to="torch", + external_weights=external_weights, + external_weight_file=self.external_weight_file, + ) + with open(self.tempfile_name, "w+") as f: + f.write(self.torch_ir) + del self.torch_ir + gc.collect() + self.compile() + else: + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + self.compile() def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". 
self.iree_module_dict = get_iree_compiled_module( - self.tempfile_name, device=self.device, frontend="torch" + self.tempfile_name, + device=self.device, + mmap=True, + frontend="torch", + external_weight_file=self.external_weight_file, + write_to=self.vmfb_name, ) # TODO: delete the temp file + def sanitize_prompt(self, prompt): + print(prompt) + if isinstance(prompt, list): + prompt = list(chain.from_iterable(prompt)) + prompt = " ".join([x for x in prompt if isinstance(x, str)]) + prompt = prompt.replace("\n", " ") + prompt = prompt.replace("\t", " ") + prompt = prompt.replace("\r", " ") + if self.use_system_prompt and self.global_iter == 0: + prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt + prompt += " [/INST]" + print(prompt) + return prompt + def chat(self, prompt): + prompt = self.sanitize_prompt(prompt) + + input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids + + def format_out(results): + return torch.tensor(results.to_host()[0][0]) + history = [] for iter in range(self.max_tokens): - input_tensor = self.tokenizer( - prompt, return_tensors="pt" - ).input_ids - device_inputs = [ - ireert.asdevicearray( - self.iree_module_dict["config"], input_tensor - ) - ] + st_time = time.time() if iter == 0: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_initialize"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, input_tensor + ) + ] + token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) else: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_forward"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, + token, + ) + ] + token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) - history.append(token) - yield self.tokenizer.decode(history) + total_time = time.time() - st_time + history.append(format_out(token)) + yield self.tokenizer.decode(history), total_time - if token == llm_model_map["llama2_7b"]["stop_token"]: + if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) result_output = self.tokenizer.decode(history) - yield result_output + self.global_iter += 1 + return result_output, total_time if __name__ == "__main__": lm = LanguageModel( - "llama2_7b", - hf_auth_token="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", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, device="cpu-task", + external_weights="safetensors", ) + print("model loaded") - for i in lm.chat("Hello, I am a robot."): + for i in lm.chat("hi, what are you?"): print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index bb5e150364..4072491cbf 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,7 +8,5 @@ def get_available_devices(): def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py new file mode 100644 index 0000000000..c88a1e70cb --- /dev/null +++ b/apps/shark_studio/tests/api_test.py @@ -0,0 +1,34 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed 
under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest +from apps.shark_studio.api.llm import LanguageModel + + +class LLMAPITest(unittest.TestCase): + def testLLMSimple(self): + lm = LanguageModel( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + device="cpu-task", + external_weights="safetensors", + ) + count = 0 + for msg, _ in lm.chat("hi, what are you?"): + # skip first token output + if count == 0: + count += 1 + continue + assert ( + msg.strip(" ") == "Hello" + ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + break + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 59b66bee23..3ef6bc5739 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -93,9 +93,7 @@ def launch_app(address): def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") @@ -201,7 +199,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) with gr.Blocks( - css=dark_theme, analytics_enabled=False, title="Stable Diffusion" + css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta" ) as sd_web: with gr.Tabs() as tabs: # NOTE: If adding, removing, or re-ordering tabs, make sure that they diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index dd1c2d94e3..4726eef6e8 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -1,4 +1,5 @@ import gradio as gr +import time import os from pathlib import Path from datetime import datetime as dt @@ -21,104 +22,12 @@ def user(message, history): language_model = None -# NOTE: Each `model_name` should have its own start message -start_message = { - "llama2_7b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_13b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_70b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. 
Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "vicuna": ( - "A chat between a curious user and an artificial intelligence " - "assistant. The assistant gives helpful, detailed, and " - "polite answers to the user's questions.\n" - ), -} - - def create_prompt(model_name, history, prompt_prefix): return "" - system_message = "" - if prompt_prefix: - system_message = start_message[model_name] - - if "llama2" in model_name: - B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - conversation = "".join( - [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]] - ) - if prompt_prefix: - msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}" - else: - msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}" - elif model_name in ["vicuna"]: - conversation = "".join( - [ - "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) - for item in history - ] - ) - msg = system_message + conversation - msg = msg.strip() - else: - conversation = "".join( - ["".join([item[0], item[1]]) for item in history] - ) - msg = system_message + conversation - msg = msg.strip() - return msg def get_default_config(): return False - import torch - from transformers import AutoTokenizer - - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - from apps.language_models.src.model_wrappers.vicuna_model import ( - CombinedModel, - ) - from shark.shark_generate_model_config import GenerateConfigFile - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - c.split_into_layers() # model_vmfb_key = "" @@ -133,153 +42,37 @@ def chat_fn( download_vmfb, config_file, cli=False, - progress=gr.Progress(), ): global language_model if language_model is None: + history[-1][-1] = "Getting the model ready..." 
+ yield history, "" language_model = LanguageModel( - model, device=device, precision=precision - ) - - language_model.chat(prompt_prefix) - return "", "" - global past_key_values - global model_vmfb_key - - device_id = None - model_name, model_path = list(map(str.strip, model.split("=>"))) - if "cuda" in device: - device = "cuda" - elif "sync" in device: - device = "cpu-sync" - elif "task" in device: - device = "cpu-task" - elif "vulkan" in device: - device_id = int(device.split("://")[1]) - device = "vulkan" - elif "rocm" in device: - device = "rocm" - else: - print("unrecognized device") - - from apps.language_models.scripts.vicuna import ShardedVicuna - from apps.language_models.scripts.vicuna import UnshardedVicuna - from apps.stable_diffusion.src import args - - new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" - if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: - model_vmfb_key = new_model_vmfb_key - max_toks = 128 if model_name == "codegen" else 512 - - # get iree flags that need to be overridden, from commandline args - _extra_args = [] - # vulkan target triple - vulkan_target_triple = args.iree_vulkan_target_triple - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - get_vulkan_target_triple, + model, + device=device, + precision=precision, + external_weights="safetensors", + external_weight_file="llama2_7b.safetensors", + use_system_prompt=prompt_prefix, ) - - if device == "vulkan": - vulkaninfo_list = get_all_vulkan_devices() - if vulkan_target_triple == "": - # We already have the device_id extracted via WebUI, so we directly use - # that to find the target triple. - vulkan_target_triple = get_vulkan_target_triple( - vulkaninfo_list[device_id] - ) - _extra_args.append( - f"-iree-vulkan-target-triple={vulkan_target_triple}" - ) - if "rdna" in vulkan_target_triple: - flags_to_add = [ - "--iree-spirv-index-bits=64", - ] - _extra_args = _extra_args + flags_to_add - - if device_id is None: - id = 0 - for device in vulkaninfo_list: - target_triple = get_vulkan_target_triple( - vulkaninfo_list[id] - ) - if target_triple == vulkan_target_triple: - device_id = id - break - id += 1 - - assert ( - device_id - ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" - print(f"Will use vulkan target triple : {vulkan_target_triple}") - - elif "rocm" in device: - # add iree rocm flags - _extra_args.append( - f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" - ) - print(f"extra args = {_extra_args}") - - if model_name == "vicuna4": - vicuna_model = ShardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - compressed=True, - extra_args_cmd=_extra_args, - ) - else: - # if config_file is None: - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - hf_auth_token=args.hf_auth_token, - device=device, - vulkan_target_triple=vulkan_target_triple, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=download_vmfb, - load_mlir_from_shark_tank=True, - extra_args_cmd=_extra_args, - device_id=device_id, - ) - - if vicuna_model is None: - sys.exit("Unable to instantiate the model object, exiting.") - - prompt = create_prompt(model_name, history, prompt_prefix) - - partial_text = "" + history[-1][-1] = "Getting the model ready... 
Done" + yield history, "" + history[-1][-1] = "" token_count = 0 - total_time_ms = 0.001 # In order to avoid divide by zero error + total_time = 0.001 # In order to avoid divide by zero error prefill_time = 0 is_first = True - for text, msg, exec_time in progress.tqdm( - vicuna_model.generate(prompt, cli=cli), - desc="generating response", - ): - if msg is None: - if is_first: - prefill_time = exec_time - is_first = False - else: - total_time_ms += exec_time - token_count += 1 - partial_text += text + " " - history[-1][1] = partial_text + for text, exec_time in language_model.chat(history): + history[-1][-1] = text + if is_first: + prefill_time = exec_time + is_first = False yield history, f"Prefill: {prefill_time:.2f}" - elif "formatted" in msg: - history[-1][1] = text - tokens_per_sec = (token_count / total_time_ms) * 1000 - yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" else: - sys.exit( - "unexpected message from the vicuna generate call, exiting." - ) - - return history, "" + total_time += exec_time + token_count += 1 + tokens_per_sec = token_count / total_time + yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" def llm_chat_api(InputData: dict): @@ -297,17 +90,11 @@ def llm_chat_api(InputData: dict): # print(f"prompt : {InputData['prompt']}") # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) + model_name = InputData["model"] if "model" in InputData.keys() else "codegen" model_path = llm_model_map[model_name] device = "cpu-task" precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) + max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] if max_toks is None: max_toks = 128 if model_name == "codegen" else 512 @@ -344,9 +131,7 @@ def llm_chat_api(InputData: dict): # TODO: add role dict for different models if is_chat_completion_api: # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) + prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) else: prompt = InputData["prompt"] print("prompt = ", prompt) @@ -379,9 +164,7 @@ def llm_chat_api(InputData: dict): end_time = dt.now().strftime("%Y%m%d%H%M%S%f") return { "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", + "object": "chat.completion" if is_chat_completion_api else "text_completion", "created": int(end_time), "choices": choices, } @@ -457,9 +240,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button(label="View as JSON", visible=False) json_view = gr.JSON(interactive=True, visible=False) json_view_button.click( diff --git a/apps/stable_diffusion/web/api/sdapi_v1.py b/apps/stable_diffusion/web/api/sdapi_v1.py index 3eebd5c113..f376f0fe9d 100644 --- a/apps/stable_diffusion/web/api/sdapi_v1.py +++ b/apps/stable_diffusion/web/api/sdapi_v1.py @@ -374,7 +374,8 @@ def inpaint_api( res = inpaint_inf( InputData.prompt, InputData.negative_prompt, - {"image": init_image, "mask": mask}, + init_image, + mask, InputData.height, InputData.width, 
InputData.is_full_res, diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 15ffd1b77a..5f7ac0bc65 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -106,6 +106,7 @@ def cleanup_mei_folders(): # It has to be in this order or gradio ignores what we've set up. from apps.stable_diffusion.web.utils.tmp_configs import ( config_tmp, + shark_tmp, ) config_tmp() @@ -115,6 +116,8 @@ def cleanup_mei_folders(): from apps.stable_diffusion.web.ui.utils import ( create_custom_models_folders, nodicon_loc, + mask_editor_value_for_gallery_data, + mask_editor_value_for_image_file, ) create_custom_models_folders() @@ -206,10 +209,20 @@ def resource_path(relative_path): # init global sd pipeline and config global_obj._init() - def register_button_click(button, selectedid, inputs, outputs): + def register_sendto_click(button, selectedid, inputs, outputs): button.click( lambda x: ( - x[0]["name"] if len(x) != 0 else None, + x.root[0].image.path if len(x.root) != 0 else None, + gr.Tabs(selected=selectedid), + ), + inputs, + outputs, + ) + + def register_sendto_editor_click(button, selectedid, inputs, outputs): + button.click( + lambda x: ( + mask_editor_value_for_gallery_data(x), gr.Tabs(selected=selectedid), ), inputs, @@ -225,9 +238,12 @@ def register_modelmanager_button(button, selectedid, inputs, outputs): ), inputs, outputs, + queue=False, ) - def register_outputgallery_button(button, selectedid, inputs, outputs): + def register_outputgallery_sendto_button( + button, selectedid, inputs, outputs + ): button.click( lambda x: ( x, @@ -237,6 +253,18 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): outputs, ) + def register_outputgallery_sendto_editor_button( + button, selectedid, inputs, outputs + ): + button.click( + lambda x: ( + mask_editor_value_for_image_file(x), + gr.Tabs(selected=selectedid), + ), + inputs, + outputs, + ) + dark_theme = resource_path("ui/css/sd_dark_theme.css") with gr.Blocks( @@ -265,19 +293,6 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): if args.output_gallery: with gr.TabItem(label="Output Gallery", id=5) as og_tab: outputgallery_web.render() - - # extra output gallery configuration - outputgallery_tab_select(og_tab.select) - outputgallery_watch( - [ - txt2img_status, - img2img_status, - inpaint_status, - outpaint_status, - upscaler_status, - txt2img_sdxl_status, - ] - ) # with gr.TabItem(label="Model Manager", id=6): # model_web.render() # with gr.TabItem(label="LoRA Training (Experimental)", id=7): @@ -297,6 +312,19 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): with gr.TabItem(label="Text-to-Image (SDXL)", id=13): txt2img_sdxl_web.render() + # extra output gallery configuration + outputgallery_tab_select(og_tab.select) + outputgallery_watch( + [ + txt2img_status, + img2img_status, + inpaint_status, + outpaint_status, + upscaler_status, + txt2img_sdxl_status, + ], + ) + actual_port = app.usable_port() if actual_port != args.server_port: sd_web.load( @@ -307,134 +335,134 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) # send to buttons - register_button_click( + register_sendto_click( txt2img_sendto_img2img, 1, [txt2img_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( txt2img_sendto_inpaint, 2, [txt2img_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( txt2img_sendto_outpaint, 3, [txt2img_gallery], 
[outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( txt2img_sendto_upscaler, 4, [txt2img_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( img2img_sendto_inpaint, 2, [img2img_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( img2img_sendto_outpaint, 3, [img2img_gallery], [outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( img2img_sendto_upscaler, 4, [img2img_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_img2img, 1, [inpaint_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_outpaint, 3, [inpaint_gallery], [outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_upscaler, 4, [inpaint_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( outpaint_sendto_img2img, 1, [outpaint_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( outpaint_sendto_inpaint, 2, [outpaint_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( outpaint_sendto_upscaler, 4, [outpaint_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( upscaler_sendto_img2img, 1, [upscaler_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( upscaler_sendto_inpaint, 2, [upscaler_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( upscaler_sendto_outpaint, 3, [upscaler_gallery], [outpaint_init_image, tabs], ) if args.output_gallery: - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_txt2img, 0, [outputgallery_filename], [txt2img_png_info_img, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_img2img, 1, [outputgallery_filename], [img2img_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_editor_button( outputgallery_sendto_inpaint, 2, [outputgallery_filename], [inpaint_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_outpaint, 3, [outputgallery_filename], [outpaint_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_upscaler, 4, [outputgallery_filename], [upscaler_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_txt2img_sdxl, 0, [outputgallery_filename], diff --git a/apps/stable_diffusion/web/ui/common_ui_events.py b/apps/stable_diffusion/web/ui/common_ui_events.py index 230619b61d..f467f6b0ed 100644 --- a/apps/stable_diffusion/web/ui/common_ui_events.py +++ b/apps/stable_diffusion/web/ui/common_ui_events.py @@ -1,3 +1,5 @@ +import gradio as gr + from apps.stable_diffusion.web.ui.utils import ( HSLHue, hsl_color, diff --git a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css index 5686f0868c..fa8d50adf2 100644 --- a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css +++ b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css @@ -239,8 +239,9 @@ footer { padding: 0 !important; } -#output_subdir_container :first-child { - border: none; +#output_subdir_container { + background-color: var(--block-background-fill); + padding-right: 8px; } /* 
reduced animation load when generating */ @@ -279,10 +280,19 @@ footer { /* output gallery tab */ .output_parameters_dataframe table.table { - /* works around a gradio bug that always shows scrollbars */ +/* works around a gradio bug that always shows scrollbars */ overflow: clip auto; } +.output_parameters_dataframe .cell-wrap span { + /* inadequate workaround for gradio issue #6086 */ + user-select:text !important; + -moz-user-select:text !important; + -webkit-user-select:text !important; + -o-user-select:text !important; + -ms-user-select:text !important; +} + .output_parameters_dataframe tbody td { font-size: small; line-height: var(--line-xs); @@ -291,7 +301,7 @@ footer { .output_icon_button { max-width: 30px; align-self: end; - padding-bottom: 8px; + padding-bottom: 16px !important; } .outputgallery_sendto { @@ -308,6 +318,11 @@ footer { object-fit: contain !important; } +/* use the whole gallery area for previeews */ +#outputgallery_gallery .preview { + width: inherit; +} + /* centered logo for when there are no images */ #top_logo.logo_centered { height: 100%; diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index f3522656e4..a6df246325 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -326,14 +326,21 @@ def img2img_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + # TODO: make this import image prompt info if it exists + img2img_init_image = gr.Image( + label="Input Image", + type="pil", + interactive=True, + sources=["upload"], + ) with gr.Row(): # janky fix for overflowing text i2i_model_info = ( @@ -380,14 +387,6 @@ def img2img_inf( lines=2, elem_id="negative_prompt_box", ) - # TODO: make this import image prompt info if it exists - img2img_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - interactive=True, - ) - with gr.Accordion(label="Multistencil Options", open=False): choices = [ "None", @@ -958,6 +957,8 @@ def update_cn_input( elem_id="gallery", columns=2, object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{i2i_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 8cd56f452b..4ce4795a82 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -3,8 +3,15 @@ import time import sys import gradio as gr +import PIL.ImageOps from PIL import Image +from gradio.components.image_editor import ( + Brush, + Eraser, + EditorData, + EditorValue, +) from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -37,11 +44,53 @@ init_import_mlir = args.import_mlir +def set_image_states(editor_data): + input_mask = editor_data["layers"][0] + + # inpaint_inf wants white mask on black background (?), whilst ImageEditor + # delivers black mask on transparent (0 opacity) background + inference_mask = Image.new( + mode="RGB", size=input_mask.size, color=(255, 255, 255) + ) + inference_mask.paste(input_mask, input_mask) + inference_mask = PIL.ImageOps.invert(inference_mask) + + return ( + # we set the ImageEditor data again, because it likes to clear + # the image layers (which include the mask) if the user hasn't + # used the upload button, 
and we sent it and image + # TODO: work out what is going wrong in that case so we don't have + # to do this + { + "background": editor_data["background"], + "layers": [input_mask], + "composite": None, + }, + editor_data["background"], + input_mask, + inference_mask, + ) + + +def reload_image_editor(editor_image, editor_mask): + # we set the ImageEditor data again, because it likes to clear + # the image layers (which include the mask) if the user hasn't + # used the upload button, and we sent it the image + # TODO: work out what is going wrong in that case so we don't have + # to do this + return { + "background": editor_image, + "layers": [editor_mask], + "composite": None, + } + + # Exposed to UI. def inpaint_inf( prompt: str, negative_prompt: str, - image_dict, + image, + mask_image, height: int, width: int, inpaint_full_res: bool, @@ -175,8 +224,6 @@ def inpaint_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - image = image_dict["image"] - mask_image = image_dict["mask"] text_output = "" try: seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) @@ -223,6 +270,9 @@ def inpaint_inf( with gr.Blocks(title="Inpainting") as inpaint_web: + editor_image = gr.State() + editor_mask = gr.State() + inference_mask = gr.State() with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(): @@ -231,14 +281,24 @@ def inpaint_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + inpaint_init_image = gr.Sketchpad( + label="Masked Image", + type="pil", + sources=("clipboard", "upload"), + interactive=True, + brush=Brush( + colors=["#000000"], + color_mode="fixed", + ), + ) with gr.Row(): # janky fix for overflowing text inpaint_model_info = ( @@ -288,14 +348,6 @@ def inpaint_inf( lines=2, elem_id="negative_prompt_box", ) - - inpaint_init_image = gr.Image( - label="Masked Image", - sources="upload", - type="pil", - height=350, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -448,6 +500,8 @@ def inpaint_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{inpaint_model_info}\n" @@ -484,7 +538,8 @@ def inpaint_inf( inputs=[ prompt, negative_prompt, - inpaint_init_image, + editor_image, + inference_mask, height, width, inpaint_full_res, @@ -514,18 +569,53 @@ def inpaint_inf( fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs), inputs=[batch_count, batch_size], outputs=inpaint_status, + show_progress="none", + ) + set_image_states_args = dict( + fn=set_image_states, + inputs=[inpaint_init_image], + outputs=[ + inpaint_init_image, + editor_image, + editor_mask, + inference_mask, + ], + show_progress="none", + ) + reload_image_editor_args = dict( + fn=reload_image_editor, + inputs=[editor_image, editor_mask], + outputs=[inpaint_init_image], + show_progress="none", ) - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs + # all these trigger generation + prompt_submit = ( + prompt.submit(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) + neg_prompt_submit 
= ( + negative_prompt.submit(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) + ) + generate_click = ( + stable_diffusion.click(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) + ) + + # Attempts to cancel generation stop_batch.click( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], ) + # Updates LoRA information when one is selected lora_weights.change( fn=lora_changed, inputs=[lora_weights], diff --git a/apps/stable_diffusion/web/ui/lora_train_ui.py b/apps/stable_diffusion/web/ui/lora_train_ui.py index d84728bc26..45c1c3e243 100644 --- a/apps/stable_diffusion/web/ui/lora_train_ui.py +++ b/apps/stable_diffusion/web/ui/lora_train_ui.py @@ -23,10 +23,10 @@ value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): diff --git a/apps/stable_diffusion/web/ui/model_manager.py b/apps/stable_diffusion/web/ui/model_manager.py index 11e01fe873..21c0939f5e 100644 --- a/apps/stable_diffusion/web/ui/model_manager.py +++ b/apps/stable_diffusion/web/ui/model_manager.py @@ -105,6 +105,7 @@ def get_image_from_model(model_json): label="Civitai Model Gallery", value=None, visible=False, + show_download_button=False, ) with gr.Row(visible=False) as sendto_btns: diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index 2a4c0039e7..a515f6c90e 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -236,14 +236,17 @@ def outpaint_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + outpaint_init_image = gr.Image( + label="Input Image", type="pil", sources=["upload"] + ) with gr.Row(): outpaint_model_info = ( f"Custom Model Path: {str(get_custom_model_path())}" @@ -291,13 +294,6 @@ def outpaint_inf( lines=2, elem_id="negative_prompt_box", ) - - outpaint_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -473,6 +469,8 @@ def outpaint_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{outpaint_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/outputgallery_ui.py b/apps/stable_diffusion/web/ui/outputgallery_ui.py index 35ef80736f..d33e5f5393 100644 --- a/apps/stable_diffusion/web/ui/outputgallery_ui.py +++ b/apps/stable_diffusion/web/ui/outputgallery_ui.py @@ -80,28 +80,28 @@ def output_subdirs() -> list[str]: label="Getting subdirectories...", value=nod_logo, interactive=False, + show_download_button=False, visible=True, show_label=True, elem_id="top_logo", elem_classes="logo_centered", - show_download_button=False, ) - gallery = gr.Gallery( label="", value=gallery_files.value, visible=False, show_label=True, columns=4, + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) with gr.Column(scale=4): with gr.Group(): - with gr.Row(): + with gr.Row(elem_id="output_subdir_container"): with gr.Column( scale=15, min_width=160, - elem_id="output_subdir_container", 
): subdirectories = gr.Dropdown( label=f"Subdirectories of {output_dir}", @@ -109,7 +109,7 @@ def output_subdirs() -> list[str]: choices=subdirectory_paths.value, value="", interactive=True, - elem_classes="dropdown_no_container", + # elem_classes="dropdown_no_container", allow_custom_value=True, ) with gr.Column( @@ -149,11 +149,12 @@ def output_subdirs() -> list[str]: ) as parameters_accordian: image_parameters = gr.DataFrame( headers=["Parameter", "Value"], - col_count=2, + col_count=(2, "fixed"), + row_count=(1, "fixed"), wrap=True, elem_classes="output_parameters_dataframe", value=[["Status", "No image selected"]], - interactive=True, + interactive=False, ) with gr.Accordion(label="Send To", open=True): @@ -327,12 +328,18 @@ def on_select_image(images: list[str], evt: gr.SelectData) -> list: else: return [ filename, - list(map(list, params["parameters"].items())), + gr.DataFrame( + value=list(map(list, params["parameters"].items())), + row_count=(len(params["parameters"]), "fixed"), + ), ] return [ filename, - [["Status", "No parameters found"]], + gr.DataFrame( + value=[["Status", "No parameters found"]], + row_count=(1, "fixed"), + ), ] def on_outputgallery_filename_change(filename: str) -> list: @@ -450,11 +457,12 @@ def outputgallery_tab_select(select): # We should have been passed a list of components on other tabs that update # when a new image has generated on that tab, so set things up so the user # will see that new image if they are looking at today's subdirectory - def outputgallery_watch(components: gr.Textbox): + def outputgallery_watch(components: gr.Textbox, queued_components=[]): for component in components: component.change( on_new_image, inputs=[subdirectories, subdirectory_paths, component], outputs=[gallery_files, gallery, logo], - queue=False, + queue=component in queued_components, + show_progress="none", ) diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 85dae66d2e..807c30ad2e 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -240,10 +240,10 @@ def txt2img_sdxl_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): @@ -264,7 +264,7 @@ def txt2img_sdxl_inf( custom_checkpoint_type="sdxl" ), allow_custom_value=True, - scale=2, + scale=11, ) t2i_sdxl_vae_info = ( str(get_custom_model_path("vae")) @@ -283,15 +283,16 @@ def txt2img_sdxl_inf( ] + get_custom_model_files("vae"), allow_custom_value=True, + scale=4, + ) + txt2img_sdxl_png_info_img = gr.Image( scale=1, + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + sources=["upload"], ) - with gr.Column(scale=1, min_width=170): - txt2img_sdxl_png_info_img = gr.Image( - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - ) with gr.Group(elem_id="prompt_box_outer"): txt2img_sdxl_autogen = gr.Checkbox( @@ -477,6 +478,8 @@ def txt2img_sdxl_inf( elem_id="gallery", columns=[2], object_fit="scale_down", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{t2i_sdxl_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 9df392a90a..3b6c936cf8 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ 
b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -427,16 +427,16 @@ def onload_load_settings(): value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): with gr.Row(): - with gr.Column(scale=10): + with gr.Column(): with gr.Row(): t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" txt2img_custom_model = gr.Dropdown( @@ -449,7 +449,7 @@ def onload_load_settings(): choices=get_custom_model_files() + predefined_models, allow_custom_value=True, - scale=2, + scale=11, ) # janky fix for overflowing text t2i_vae_info = ( @@ -464,16 +464,16 @@ def onload_load_settings(): choices=["None"] + get_custom_model_files("vae"), allow_custom_value=True, + scale=4, + ) + txt2img_png_info_img = gr.Image( + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + sources=["upload"], scale=1, ) - with gr.Column(scale=1, min_width=170): - txt2img_png_info_img = gr.Image( - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - ) - with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( label="Prompt", @@ -688,6 +688,8 @@ def onload_load_settings(): elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{t2i_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index 42157dbd98..88d0507adb 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -255,14 +255,19 @@ def upscaler_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + upscaler_init_image = gr.Image( + label="Input Image", + type="pil", + sources=["upload"], + ) with gr.Row(): upscaler_model_info = ( f"Custom Model Path: {str(get_custom_model_path())}" @@ -311,13 +316,6 @@ def upscaler_inf( lines=2, elem_id="negative_prompt_box", ) - - upscaler_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -471,6 +469,8 @@ def upscaler_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{upscaler_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index 9252ecee9f..0572089e84 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -5,11 +5,13 @@ import json import safetensors import gradio as gr +import PIL.Image as Image from pathlib import Path from apps.stable_diffusion.src import args from dataclasses import dataclass from enum import IntEnum +from gradio.components.image_editor import EditorValue from apps.stable_diffusion.src import get_available_devices import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -315,6 +317,25 @@ def default_config_exists(model_ckpt_or_id): return None +def mask_editor_value_for_image_file(filepath): + image = Image.open(filepath) + mask = Image.new(mode="RGBA", 
size=image.size, color=(0, 0, 0, 0)) + return {"background": image, "layers": [mask], "composite": image} + + +def mask_editor_value_for_gallery_data(gallery_data): + filepath = ( + gallery_data.root[0].image.path + if len(gallery_data.root) != 0 + else None + ) + + if os.path.isfile(filepath): + return mask_editor_value_for_image_file(filepath) + + return EditorValue() + + default_configs = { "stabilityai/sdxl-turbo": [ gr.Textbox(label="", interactive=False, value=None, visible=False), @@ -350,6 +371,7 @@ def default_config_exists(model_ckpt_or_id): ], } + nodlogo_loc = resource_path("logos/nod-logo.png") nodicon_loc = resource_path("logos/nod-icon.png") available_devices = get_available_devices() diff --git a/build_tools/stable_diffusion_testing.py b/build_tools/stable_diffusion_testing.py index ced919732c..8eeb1a7395 100644 --- a/build_tools/stable_diffusion_testing.py +++ b/build_tools/stable_diffusion_testing.py @@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir): metrics[val] = line.split(" ")[-1].strip("\n") metrics["Average step"] = metrics["Average step"].strip("ms/it") - metrics["Total image generation"] = metrics[ - "Total image generation" - ].strip("sec") + metrics["Total image generation"] = metrics["Total image generation"].strip("sec") metrics["device"] = device metrics["use_tune"] = use_tune metrics["model_name"] = model_name @@ -84,10 +82,14 @@ def test_loop( ] import_options = ["--import_mlir", "--no-import_mlir"] prompt_text = "--prompt=cyberpunk forest by Salvador Dali" - inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + inpaint_prompt_text = ( + "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + ) if os.name == "nt": prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' - inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + inpaint_prompt_text = ( + '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + ) if beta: extra_flags.append("--beta_models=True") extra_flags.append("--no-progress_bar") @@ -174,9 +176,7 @@ def test_loop( ) print(command) print("Successfully generated image") - os.makedirs( - "./test_images/golden/" + model_name, exist_ok=True - ) + os.makedirs("./test_images/golden/" + model_name, exist_ok=True) download_public_file( "gs://shark_tank/testdata/golden/" + model_name, "./test_images/golden/" + model_name, @@ -191,14 +191,10 @@ def test_loop( ) test_file = glob(test_file_path)[0] - golden_path = ( - "./test_images/golden/" + model_name + "/*.png" - ) + golden_path = "./test_images/golden/" + model_name + "/*.png" golden_file = glob(golden_path)[0] try: - compare_images( - test_file, golden_file, upload=upload_bool - ) + compare_images(test_file, golden_file, upload=upload_bool) except AssertionError as e: print(e) if exit_on_fail == True: @@ -267,9 +263,7 @@ def prepare_artifacts(): parser.add_argument( "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True ) -parser.add_argument( - "-g", "--gen", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False) if __name__ == "__main__": args = parser.parse_args() diff --git a/dataset/annotation_tool.py b/dataset/annotation_tool.py index edd088229f..60f607146d 100644 --- a/dataset/annotation_tool.py +++ b/dataset/annotation_tool.py @@ -10,9 +10,7 @@ shark_root = Path(__file__).parent.parent demo_css = 
shark_root.joinpath("web/demo.css").resolve() -nodlogo_loc = shark_root.joinpath( - "web/models/stable_diffusion/logos/nod-logo.png" -) +nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png") with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web: @@ -23,6 +21,7 @@ value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=100, @@ -75,9 +74,7 @@ def filter_datasets(dataset): with jsonlines.open(dataset_path + "/metadata.jsonl") as reader: for line in reader.iter(type=dict, skip_invalid=True): prompt_data[line["file_name"]] = ( - [line["text"]] - if type(line["text"]) is str - else line["text"] + [line["text"]] if type(line["text"]) is str else line["text"] ) return gr.Dropdown.update(choices=images[dataset]) @@ -103,9 +100,7 @@ def display_image(dataset, image_name): prompt_data[image_name] = [] prompt_choices = ["Add new"] prompt_choices += prompt_data[image_name] - return gr.Image.update(value=img), gr.Dropdown.update( - choices=prompt_choices - ) + return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices) image_name.change( fn=display_image, @@ -122,12 +117,7 @@ def edit_prompt(prompts): prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt) def save_prompt(dataset, image_name, prompts, prompt): - if ( - dataset is None - or image_name is None - or prompts is None - or prompt is None - ): + if dataset is None or image_name is None or prompts is None or prompt is None: return if prompts == "Add new": @@ -136,9 +126,7 @@ def save_prompt(dataset, image_name, prompts, prompt): idx = prompt_data[image_name].index(prompts) prompt_data[image_name][idx] = prompt - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -165,9 +153,7 @@ def delete_prompt(dataset, image_name, prompts): return prompt_data[image_name].remove(prompts) - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -230,9 +216,7 @@ def finish_annotation(dataset): # upload prompt and remove local data dataset_path = str(shark_root) + "/dataset/" + dataset dataset_gs_path = args.gs_url + "/" + dataset + "/" - os.system( - f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"' - ) + os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"') os.system(f'rm -rf "{dataset_path}"') return gr.Dropdown.update(value=None) diff --git a/process_skipfiles.py b/process_skipfiles.py index a846159451..339c7ebec6 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -8,8 +8,7 @@ # Temporary workaround for transformers/__init__.py. path_to_transformers_hook = Path( - get_python_lib() - + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" + get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" ) if path_to_transformers_hook.is_file(): pass @@ -59,9 +58,7 @@ # For getting around timm's packaging. 
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 -path_to_timm_activations = Path( - get_python_lib() + "/timm/layers/activations_jit.py" -) +path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py") for line in fileinput.input(path_to_timm_activations, inplace=True): if "@torch.jit.script" in line: print("@torch.jit._script_if_tracing", end="\n") diff --git a/pyproject.toml b/pyproject.toml index 22e0210c50..876df2f8bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,25 @@ requires = [ "packaging", "numpy>=1.22.4", - "torch-mlir>=20230620.875", "iree-compiler>=20221022.190", "iree-runtime>=20221022.190", ] build-backend = "setuptools.build_meta" [tool.black] -line-length = 79 include = '\.pyi?$' -exclude = "apps/language_models/scripts/vicuna.py" -extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py" +exclude = ''' +( + /( + | apps/stable_diffusion + | apps/language_models + | shark + | benchmarks + | tank + | build + | generated_imgs + | shark.venv + )/ + | setup.py +) +''' diff --git a/pytest.ini b/pytest.ini index 11f57888b2..3857248785 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] addopts = --verbose -s -p no:warnings -norecursedirs = inference tank/tflite examples benchmarks shark +norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio diff --git a/requirements.txt b/requirements.txt index a97baa83a3..3f7e719e67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,13 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://openxla.github.io/iree/pip-release-links.html --pre setuptools wheel +shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=python/turbine_models + # SHARK Runner tqdm @@ -17,16 +21,12 @@ pytest-forked Pillow parameterized -#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main # Add transformers, diffusers and scipy since it most commonly used -tokenizers==0.13.3 -transformers -diffusers #accelerate is now required for diffusers import from ckpt. 
accelerate scipy ftfy -gradio==4.7.1 +gradio==4.8.0 altair omegaconf # 0.3.2 doesn't have binaries for arm64 @@ -49,9 +49,6 @@ pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions pefile pyinstaller -# vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea - # For quantized GPTQ models optimum auto_gptq diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 7a4cf042c2..f3c0b0e170 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -44,14 +44,10 @@ def upscaler_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[upscaler] response from server was : {res.status_code} {res.reason}" - ) + print(f"[upscaler] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def img2img_test(verbose=False): @@ -96,14 +92,10 @@ def img2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[img2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[img2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") # NOTE Uncomment below to save the picture @@ -133,13 +125,9 @@ def inpainting_test(verbose=False): image_path = r"./rest_api_tests/dog.png" img_file = open(image_path, "rb") - image = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + image = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() img_file = open(image_path, "rb") - mask = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + mask = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() url = "http://127.0.0.1:8080/sdapi/v1/inpaint" @@ -166,14 +154,10 @@ def inpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[inpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[inpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def outpainting_test(verbose=False): @@ -223,14 +207,10 @@ def outpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[outpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[outpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def txt2img_test(verbose=False): @@ -262,14 +242,10 @@ def txt2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[txt2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[txt2img] response from server was : {res.status_code} {res.reason}") if verbose or 
res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def sd_models_test(verbose=False): @@ -283,9 +259,7 @@ def sd_models_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_models] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_models] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -302,9 +276,7 @@ def sd_samplers_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_samplers] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_samplers] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -321,9 +293,7 @@ def options_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[options] response from server was : {res.status_code} {res.reason}" - ) + print(f"[options] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -340,9 +310,7 @@ def cmd_flags_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[cmd-flags] response from server was : {res.status_code} {res.reason}" - ) + print(f"[cmd-flags] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") diff --git a/setup.py b/setup.py index c387fe9add..061873e7a8 100644 --- a/setup.py +++ b/setup.py @@ -9,11 +9,6 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5" backend_deps = [] -if "NO_BACKEND" in os.environ.keys(): - backend_deps = [ - "iree-compiler>=20221022.190", - "iree-runtime>=20221022.190", - ] setup( name="nodai-SHARK", @@ -39,7 +34,5 @@ install_requires=[ "numpy", "PyYAML", - "torch-mlir", ] - + backend_deps, ) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 6cfe369426..bae1908e1c 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -305,6 +305,7 @@ def compile_module_to_flatbuffer( model_name="None", debug=False, compile_str=False, + write_to=None, ): # Setup Compile arguments wrt to frontends. input_type = "auto" @@ -342,12 +343,24 @@ def compile_module_to_flatbuffer( extra_args=args, ) + if write_to is not None: + with open(write_to, "wb") as f: + f.write(flatbuffer_blob) + return None + return flatbuffer_blob def get_iree_module( - flatbuffer_blob, device, device_idx=None, rt_flags: list = [] + flatbuffer_blob, + device, + device_idx=None, + rt_flags: list = [], + external_weight_file=None, ): + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) # Returns the compiled module and the configs. 
for flag in rt_flags: ireert.flags.parse_flag(flag) @@ -369,7 +382,10 @@ vm_module = ireert.VmModule.from_buffer( config.vm_instance, flatbuffer_blob, warn_if_copy=False ) - ctx = ireert.SystemContext(config=config) + modules = [] + if external_weight_file is not None: + modules.append(ireert.create_io_parameters_module(config.vm_instance, index.create_provider(scope="model"))) + ctx = ireert.SystemContext(vm_modules=modules, config=config) ctx.add_vm_module(vm_module) ModuleCompiled = getattr(ctx.modules, vm_module.name) return ModuleCompiled, config @@ -380,6 +396,7 @@ def load_vmfb_using_mmap( device: str, device_idx: int = None, rt_flags: list = [], + external_weight_file: str = None, ): print(f"Loading module {flatbuffer_blob_or_path}...") if "task" in device: @@ -440,17 +457,28 @@ def load_vmfb_using_mmap( mmaped_vmfb = ireert.VmModule.mmap( config.vm_instance, flatbuffer_blob_or_path ) + vm_modules = [] + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.append(param_module) + vm_modules.append(mmaped_vmfb) + vm_modules.append( + ireert.create_hal_module(config.vm_instance, config.device) + ) dl.log(f"mmap {flatbuffer_blob_or_path}") - ctx = ireert.SystemContext(config=config) - for flag in shark_args.additional_runtime_args: - ireert.flags.parse_flags(flag) - dl.log(f"ireert.SystemContext created") if "vulkan" in device: # Vulkan pipeline creation consumes significant amount of time. print( "\tCompiling Vulkan shaders. This may take a few minutes." ) - ctx.add_vm_module(mmaped_vmfb) + ctx = ireert.SystemContext(config=config, vm_modules=vm_modules) + dl.log(f"ireert.SystemContext created") + for flag in shark_args.additional_runtime_args: + ireert.flags.parse_flags(flag) dl.log(f"module initialized") mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) else: @@ -475,6 +503,8 @@ def get_iree_compiled_module( mmap: bool = False, debug: bool = False, compile_str: bool = False, + external_weight_file: str = None, + write_to: str = None, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( @@ -485,6 +515,7 @@ def get_iree_compiled_module( extra_args=extra_args, debug=debug, compile_str=compile_str, + write_to=write_to, ) temp_file_to_unlink = None # TODO: Currently mmap=True control flow path has been switched off for mmap. @@ -492,8 +523,14 @@ def get_iree_compiled_module( # we're setting delete=False when creating NamedTemporaryFile. That's why # I'm getting hold of the name of the temporary file in `temp_file_to_unlink`. if mmap: + if write_to is not None: + flatbuffer_blob = write_to vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( - flatbuffer_blob, device, device_idx, rt_flags + flatbuffer_blob, + device, + device_idx, + rt_flags, + external_weight_file=external_weight_file, ) else: vmfb, config = get_iree_module( @@ -501,6 +538,7 @@ def get_iree_compiled_module( device, device_idx=device_idx, rt_flags=rt_flags, + external_weight_file=external_weight_file, ) ret_params = { "vmfb": vmfb,
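For context on the new loader plumbing, a minimal sketch of how load_vmfb_using_mmap's external_weight_file argument might be exercised. The file paths and the .irpa parameter archive below are hypothetical, not part of this change; the only requirements implied by the code above are that the archive is loadable by ireert.ParameterIndex and that its entries live in the "model" scope, since the loader hard-codes scope="model".

from shark.iree_utils.compile_utils import load_vmfb_using_mmap

# Hypothetical artifacts: a .vmfb compiled with its large constants
# externalized, plus the parameter archive holding those weights.
vmfb_path = "stable_diffusion_unet.vmfb"
weights_path = "stable_diffusion_unet.irpa"

# Returns the mmapped module, the runtime config, and a temp file to unlink.
module, config, temp_file = load_vmfb_using_mmap(
    vmfb_path,
    "vulkan",
    external_weight_file=weights_path,
)

The write_to parameter added to get_iree_compiled_module composes with this path: when mmap=True and write_to is set, the freshly written .vmfb file is handed to load_vmfb_using_mmap in place of the in-memory flatbuffer.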