Skip to content

Commit

Permalink
Added support for building ZIP distributions (#1639)
Browse files Browse the repository at this point in the history
* added support for zip files

* making linter happy

* Added temporary fix for NoneType padding

* Removed zip script

* Added shared imports file

* making linter happy
  • Loading branch information
AyaanShah2204 committed Jul 9, 2023
1 parent 9fcae4f commit a517e21
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 51 deletions.
52 changes: 1 addition & 51 deletions apps/stable_diffusion/shark_sd.spec
Original file line number Diff line number Diff line change
@@ -1,60 +1,10 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules

import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('Pillow')
datas += copy_metadata('sentencepiece')
datas += collect_data_files('tokenizers')
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('opencv_python')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google_cloud_storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += collect_data_files('jsonschema')
datas += collect_data_files('jsonschema_specifications')
datas += collect_data_files('cpuinfo')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/ui/css/*', 'ui/css' ),
( 'web/ui/logos/*', 'logos' )
]
from apps.stable_diffusion.shark_studio_imports import datas, hiddenimports

binaries = []

block_cipher = None

hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("transformers") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]

a = Analysis(
['web/index.py'],
pathex=['.'],
Expand Down
58 changes: 58 additions & 0 deletions apps/stable_diffusion/shark_studio_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules

import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("torch-mlir")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += collect_data_files("tokenizers")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/logos/*", "logos"),
]


# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("transformers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,12 @@ def get_unweighted_text_embeddings(
return text_embeddings


# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
return [[49407 if token is None else token for token in tokens[0]]]


def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
Expand Down Expand Up @@ -847,6 +853,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)

# FIXME: This is a hacky fix caused by tokenizer padding with None values
prompt_tokens = filter_nonetype_tokens(prompt_tokens)

# prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu")
if uncond_prompt is not None:
Expand All @@ -859,6 +869,10 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)

# FIXME: This is a hacky fix caused by tokenizer padding with None values
uncond_tokens = filter_nonetype_tokens(uncond_tokens)

# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
uncond_tokens = torch.tensor(
uncond_tokens, dtype=torch.long, device="cpu"
Expand Down
51 changes: 51 additions & 0 deletions apps/stable_diffusion/studio_bundle.spec
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.stable_diffusion.shark_studio_imports import datas, hiddenimports

binaries = []

block_cipher = None

a = Analysis(
['web\\index.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='studio_bundle',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='studio_bundle',
)

0 comments on commit a517e21

Please sign in to comment.