From a517e217b07e54e0591d392262516c72d92da3d4 Mon Sep 17 00:00:00 2001 From: AyaanShah2204 <89650087+AyaanShah2204@users.noreply.github.com> Date: Sun, 9 Jul 2023 06:45:36 -0700 Subject: [PATCH] Added support for building ZIP distributions (#1639) * added support for zip files * making linter happy * Added temporary fix for NoneType padding * Removed zip script * Added shared imports file * making linter happy --- apps/stable_diffusion/shark_sd.spec | 52 +---------------- apps/stable_diffusion/shark_studio_imports.py | 58 +++++++++++++++++++ .../pipeline_shark_stable_diffusion_utils.py | 14 +++++ apps/stable_diffusion/studio_bundle.spec | 51 ++++++++++++++++ 4 files changed, 124 insertions(+), 51 deletions(-) create mode 100644 apps/stable_diffusion/shark_studio_imports.py create mode 100644 apps/stable_diffusion/studio_bundle.spec diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec index 6245fc80f6..9aaf890b0e 100644 --- a/apps/stable_diffusion/shark_sd.spec +++ b/apps/stable_diffusion/shark_sd.spec @@ -1,60 +1,10 @@ # -*- mode: python ; coding: utf-8 -*- -from PyInstaller.utils.hooks import collect_data_files -from PyInstaller.utils.hooks import copy_metadata -from PyInstaller.utils.hooks import collect_submodules - -import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -datas = [] -datas += collect_data_files('torch') -datas += copy_metadata('torch') -datas += copy_metadata('tqdm') -datas += copy_metadata('regex') -datas += copy_metadata('requests') -datas += copy_metadata('packaging') -datas += copy_metadata('filelock') -datas += copy_metadata('numpy') -datas += copy_metadata('importlib_metadata') -datas += copy_metadata('torch-mlir') -datas += copy_metadata('omegaconf') -datas += copy_metadata('safetensors') -datas += copy_metadata('Pillow') -datas += copy_metadata('sentencepiece') -datas += collect_data_files('tokenizers') -datas += collect_data_files('diffusers') -datas += collect_data_files('transformers') -datas += collect_data_files('pytorch_lightning') -datas += collect_data_files('opencv_python') -datas += collect_data_files('skimage') -datas += collect_data_files('gradio') -datas += collect_data_files('gradio_client') -datas += collect_data_files('iree') -datas += collect_data_files('google_cloud_storage') -datas += collect_data_files('shark') -datas += collect_data_files('tkinter') -datas += collect_data_files('webview') -datas += collect_data_files('sentencepiece') -datas += collect_data_files('jsonschema') -datas += collect_data_files('jsonschema_specifications') -datas += collect_data_files('cpuinfo') -datas += [ - ( 'src/utils/resources/prompts.json', 'resources' ), - ( 'src/utils/resources/model_db.json', 'resources' ), - ( 'src/utils/resources/opt_flags.json', 'resources' ), - ( 'src/utils/resources/base_model.json', 'resources' ), - ( 'web/ui/css/*', 'ui/css' ), - ( 'web/ui/logos/*', 'logos' ) - ] +from apps.stable_diffusion.shark_studio_imports import datas, hiddenimports binaries = [] block_cipher = None -hiddenimports = ['shark', 'shark.shark_inference', 'apps'] -hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("transformers") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] - a = Analysis( ['web/index.py'], pathex=['.'], diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py new file mode 100644 index 0000000000..811893d086 --- /dev/null +++ b/apps/stable_diffusion/shark_studio_imports.py @@ -0,0 +1,58 @@ +from PyInstaller.utils.hooks import collect_data_files +from PyInstaller.utils.hooks import copy_metadata +from PyInstaller.utils.hooks import collect_submodules + +import sys + +sys.setrecursionlimit(sys.getrecursionlimit() * 5) + +# datafiles for pyinstaller +datas = [] +datas += collect_data_files("torch") +datas += copy_metadata("torch") +datas += copy_metadata("tqdm") +datas += copy_metadata("regex") +datas += copy_metadata("requests") +datas += copy_metadata("packaging") +datas += copy_metadata("filelock") +datas += copy_metadata("numpy") +datas += copy_metadata("importlib_metadata") +datas += copy_metadata("torch-mlir") +datas += copy_metadata("omegaconf") +datas += copy_metadata("safetensors") +datas += copy_metadata("Pillow") +datas += copy_metadata("sentencepiece") +datas += collect_data_files("tokenizers") +datas += collect_data_files("diffusers") +datas += collect_data_files("transformers") +datas += collect_data_files("pytorch_lightning") +datas += collect_data_files("opencv_python") +datas += collect_data_files("skimage") +datas += collect_data_files("gradio") +datas += collect_data_files("gradio_client") +datas += collect_data_files("iree") +datas += collect_data_files("google_cloud_storage") +datas += collect_data_files("shark") +datas += collect_data_files("tkinter") +datas += collect_data_files("webview") +datas += collect_data_files("sentencepiece") +datas += collect_data_files("jsonschema") +datas += collect_data_files("jsonschema_specifications") +datas += collect_data_files("cpuinfo") +datas += [ + ("src/utils/resources/prompts.json", "resources"), + ("src/utils/resources/model_db.json", "resources"), + ("src/utils/resources/opt_flags.json", "resources"), + ("src/utils/resources/base_model.json", "resources"), + ("web/ui/css/*", "ui/css"), + ("web/ui/logos/*", "logos"), +] + + +# hidden imports for pyinstaller +hiddenimports = ["shark", "shark.shark_inference", "apps"] +hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] +hiddenimports += [ + x for x in collect_submodules("transformers") if "tests" not in x +] +hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 632f44d9b3..dd81f55341 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -756,6 +756,12 @@ def get_unweighted_text_embeddings( return text_embeddings +# This function deals with NoneType values occuring in tokens after padding +# It switches out None with 49407 as truncating None values causes matrix dimension errors, +def filter_nonetype_tokens(tokens: List[List]): + return [[49407 if token is None else token for token in tokens[0]]] + + def get_weighted_text_embeddings( pipe: StableDiffusionPipeline, prompt: Union[str, List[str]], @@ -847,6 +853,10 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, chunk_length=pipe.model_max_length, ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + prompt_tokens = filter_nonetype_tokens(prompt_tokens) + # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") if uncond_prompt is not None: @@ -859,6 +869,10 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, chunk_length=pipe.model_max_length, ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + uncond_tokens = filter_nonetype_tokens(uncond_tokens) + # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) uncond_tokens = torch.tensor( uncond_tokens, dtype=torch.long, device="cpu" diff --git a/apps/stable_diffusion/studio_bundle.spec b/apps/stable_diffusion/studio_bundle.spec new file mode 100644 index 0000000000..990088da64 --- /dev/null +++ b/apps/stable_diffusion/studio_bundle.spec @@ -0,0 +1,51 @@ +# -*- mode: python ; coding: utf-8 -*- +from apps.stable_diffusion.shark_studio_imports import datas, hiddenimports + +binaries = [] + +block_cipher = None + +a = Analysis( + ['web\\index.py'], + pathex=['.'], + binaries=binaries, + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='studio_bundle', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) +coll = COLLECT( + exe, + a.binaries, + a.zipfiles, + a.datas, + strip=False, + upx=True, + upx_exclude=[], + name='studio_bundle', +)