diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py index bbb273354c..9e99c43ec1 100644 --- a/apps/shark_studio/api/initializers.py +++ b/apps/shark_studio/api/initializers.py @@ -1,14 +1,17 @@ import importlib -import logging import os import signal import sys -import re import warnings import json from threading import Thread from apps.shark_studio.modules.timer import startup_timer +from apps.shark_studio.web.utils.tmp_configs import ( + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, + ) def imports(): @@ -21,6 +24,9 @@ def imports(): warnings.filterwarnings( action="ignore", category=UserWarning, module="torchvision" ) + warnings.filterwarnings( + action="ignore", category=UserWarning, message='.*is deprecated, please use.*', module="*torch*" + ) import gradio # noqa: F401 @@ -34,20 +40,27 @@ def imports(): from apps.shark_studio.modules import ( img_processing, ) # noqa: F401 - from apps.shark_studio.modules.schedulers import scheduler_model_map startup_timer.record("other imports") def initialize(): configure_sigint_handler() + # Setup to use shark_tmp for gradio's temporary image files and clear any + # existing temporary images there if they exist. Then we can import gradio. + # It has to be in this order or gradio ignores what we've set up. - # from apps.shark_studio.modules import modelloader - # modelloader.cleanup_models() + config_tmp() + clear_tmp_mlir() + clear_tmp_imgs() + + from apps.shark_studio.web.utils.file_utils import ( + create_checkpoint_folders, + ) + # Create custom models folders if they don't exist + create_checkpoint_folders() - # from apps.shark_studio.modules import sd_models - # sd_models.setup_model() - # startup_timer.record("setup SD model") + import gradio as gr # initialize_rest(reload_script_modules=False) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 7cd39da0b6..4be20b24d9 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -1,125 +1,61 @@ import gc -from unittest import registerResult import torch import time import os import json +import numpy as np from pathlib import Path from turbine_models.custom_models.sd_inference import clip, unet, vae from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.web.utils.state import status_label -from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path +from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path from apps.shark_studio.modules.pipeline import SharkPipelineBase +from apps.shark_studio.modules.schedulers import get_schedulers +from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings from apps.shark_studio.modules.img_processing import ( resize_stencil, save_output_img, ) + from apps.shark_studio.modules.ckpt_processing import ( + preprocessCKPT, process_custom_pipe_weights, ) +from transformers import CLIPTokenizer from math import ceil from PIL import Image sd_model_map = { - "CompVis/stable-diffusion-v1-4": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "clip": { + "initializer": clip.export_clip_model, + "external_weight_file": None, + "ireec_flags": ["--iree-flow-collapse-reduction-dims", + 
"--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + ], }, - "runwayml/stable-diffusion-v1-5": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "vae_encode": { + "initializer": vae.export_vae_model, + "external_weight_file": None, }, - "stabilityai/stable-diffusion-2-1-base": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "unet": { + "initializer": unet.export_unet_model, + "ireec_flags": ["--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + ], + "external_weight_file": None, }, - "stabilityai/stable_diffusion-xl-1.0": { - "clip_1": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "clip_2": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "vae_decode": { + "initializer": vae.export_vae_model, + "external_weight_file": None, }, } -def get_spec(custom_sd_map: dict, sd_embeds: dict): - spec = [] - for key in custom_sd_map: - if "control" in key.split("_"): - spec.append("controlled") - elif key == "custom_vae": - spec.append(custom_sd_map[key]["custom_weights"].split(".")[0]) - num_embeds = 0 - embeddings_spec = None - for embed in sd_embeds: - if embed is not None: - num_embeds += 1 - embeddings_spec = str(num_embeds) + "embeds" - if embeddings_spec: - spec.append(embeddings_spec) - return "_".join(spec) - - class StableDiffusion(SharkPipelineBase): + # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. The init # aims to be as general as possible, and the class will infer and compile @@ -132,39 +68,143 @@ class StableDiffusion(SharkPipelineBase): # embeddings: a dict of embedding checkpoints or model IDs to use when # initializing the compiled modules. 
+ def __init__( self, - base_model_id: str = "runwayml/stable-diffusion-v1-5", - height: int = 512, - width: int = 512, - precision: str = "fp16", - device: str = None, - custom_model_map: dict = {}, - embeddings: dict = {}, + base_model_id, + height: int, + width: int, + batch_size: int, + precision: str, + device: str, + custom_vae: str = None, + num_loras: int = 0, import_ir: bool = True, is_img2img: bool = False, + is_controlled: bool = False, ): - super().__init__( - sd_model_map[base_model_id], base_model_id, device, import_ir - ) + self.model_max_length = 77 + self.batch_size = batch_size self.precision = precision self.is_img2img = is_img2img - self.pipe_id = ( - safe_name(base_model_id) - + str(height) - + str(width) - + precision - + device - + get_spec(custom_model_map, embeddings) + self.scheduler_obj = {} + self.precision = precision + static_kwargs = { + "pipe": {}, + "clip": {"hf_model_name": base_model_id}, + "unet": { + "hf_model_name": base_model_id, + "unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None), + "batch_size": batch_size, + #"is_controlled": is_controlled, + #"num_loras": num_loras, + "height": height, + "width": width, + }, + "vae_encode": { + "hf_model_name": custom_vae if custom_vae else base_model_id, + "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None), + "batch_size": batch_size, + "height": height, + "width": width, + }, + "vae_decode": { + "hf_model_name": custom_vae, + "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None), + "batch_size": batch_size, + "height": height, + "width": width, + }, + } + super().__init__( + sd_model_map, base_model_id, static_kwargs, device, import_ir ) - print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}") + pipe_id_list = [ + safe_name(base_model_id), + str(batch_size), + f"{str(height)}x{str(width)}", + precision, + ] + if num_loras > 0: + pipe_id_list.append(str(num_loras)+"lora") + if is_controlled: + pipe_id_list.append("controlled") + if custom_vae: + pipe_id_list.append(custom_vae) + self.pipe_id = "_".join(pipe_id_list) + print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") + del static_kwargs + gc.collect() + - def prepare_pipe(self, scheduler, custom_model_map, embeddings): + def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings): print( - f"\n[LOG] Preparing pipeline with scheduler {scheduler}, custom map {json.dumps(custom_model_map)}, and embeddings {json.dumps(embeddings)}." + f"\n[LOG] Preparing pipeline with scheduler {scheduler}" + f"\n[LOG] Custom embeddings currently unsupported." + ) + schedulers = get_schedulers(self.base_model_id) + self.weights_path = get_checkpoints_path(self.pipe_id) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + # accepting a list of schedulers in batched cases. + for i in scheduler: + self.scheduler_obj[i] = schedulers[i] + print(f"[LOG] Loaded scheduler: {i}") + for model in adapters: + self.model_map[model] = adapters[model] + if os.path.isfile(custom_weights): + for i in self.model_map: + self.model_map[i]["external_weights_file"] = None + elif custom_weights != "": + print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. 
Did you mean to pass a base model ID?") + self.static_kwargs["pipe"] = { + # "external_weight_path": self.weights_path, +# "external_weights": "safetensors", + } + self.get_compiled_map(pipe_id=self.pipe_id) + print("\n[LOG] Pipeline successfully prepared for runtime.") + return + + + def encode_prompts_weight( + self, + prompt, + negative_prompt, + do_classifier_free_guidance=True, + ): + # Encodes the prompt into text encoder hidden states. + self.load_submodels(["clip"]) + self.tokenizer = CLIPTokenizer.from_pretrained( + self.base_model_id, + subfolder="tokenizer", + ) + clip_inf_start = time.time() + + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt + if do_classifier_free_guidance + else None, ) - self.get_compiled_map(device=self.device, pipe_id=self.pipe_id) - return None + + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + pad = (0, 0) * (len(text_embeddings.shape) - 2) + pad = pad + (0, 512 - text_embeddings.shape[1]) + text_embeddings = torch.nn.functional.pad(text_embeddings, pad) + + # SHARK: Report clip inference time + clip_inf_time = (time.time() - clip_inf_start) * 1000 + if self.ondemand: + self.unload_clip() + gc.collect() + print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") + + return text_embeddings.numpy().astype(np.float16) + def generate_images( self, @@ -181,11 +221,35 @@ def generate_images( hints, ): print("\n[LOG] Generating images...") + batched_args=[ + prompt, + negative_prompt, + steps, + strength, + guidance_scale, + seed, + resample_type, + control_mode, + hints, + ] + for arg in batched_args: + if not isinstance(arg, list): + arg = [arg] * self.batch_size + if len(arg) < self.batch_size: + arg = arg * self.batch_size + else: + arg = [arg[i] for i in range(self.batch_size)] + + text_embeddings = self.encode_prompts_weight( + prompt, + negative_prompt, + ) + print(text_embeddings) test_img = [ Image.open( get_resource_path("../../tests/jupiter.png"), mode="r" ).convert("RGB") - ] + ] * self.batch_size return test_img @@ -257,29 +321,31 @@ def shark_sd_fn( from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj - custom_model_map = {} + adapters = {} + is_controlled = False control_mode = None hints = [] - if custom_weights != "None": - custom_model_map["unet"] = {"custom_weights": custom_weights} - if custom_vae != "None": - custom_model_map["vae"] = {"custom_weights": custom_vae} + num_loras = 0 + for i in embeddings: + num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: for i, model in enumerate(controlnets["model"]): if "xl" not in base_model_id.lower(): - custom_model_map[f"control_adapter_{model}"] = { + adapters[f"control_adapter_{model}"] = { "hf_id": control_adapter_map[ "runwayml/stable-diffusion-v1-5" ][model], "strength": controlnets["strength"][i], } else: - custom_model_map[f"control_adapter_{model}"] = { + adapters[f"control_adapter_{model}"] = { "hf_id": control_adapter_map[ "stabilityai/stable-diffusion-xl-1.0" ][model], "strength": controlnets["strength"][i], } + if model is not None: + is_controlled=True control_mode = controlnets["control_mode"] for i in controlnets["hint"]: hints.append[i] @@ -288,16 +354,19 @@ def shark_sd_fn( "base_model_id": base_model_id, "height": height, "width": width, + "batch_size": batch_size, "precision": precision, "device": device, - "custom_model_map": custom_model_map, - 
"embeddings": embeddings, + "custom_vae": custom_vae, + "num_loras": num_loras, "import_ir": cmd_opts.import_mlir, "is_img2img": is_img2img, + "is_controlled": is_controlled, } submit_prep_kwargs = { "scheduler": scheduler, - "custom_model_map": custom_model_map, + "custom_weights": custom_weights, + "adapters": adapters, "embeddings": embeddings, } submit_run_kwargs = { @@ -313,6 +382,7 @@ def shark_sd_fn( "control_mode": control_mode, "hints": hints, } + print(submit_pipe_kwargs) if ( not global_obj.get_sd_obj() or global_obj.get_pipe_kwargs() != submit_pipe_kwargs diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index 8e72b0cd8a..25edd3109c 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -92,6 +92,8 @@ def process_custom_pipe_weights(custom_weights): custom_weights_tgt = get_path_to_diffusers_checkpoint( custom_weights ) + custom_weights_params = custom_weights + return custom_weights_params, custom_weights_tgt def get_civitai_checkpoint(url: str): diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index 51746d350e..b3350dea0c 100644 --- a/apps/shark_studio/modules/pipeline.py +++ b/apps/shark_studio/modules/pipeline.py @@ -1,8 +1,14 @@ -from shark.iree_utils.compile_utils import get_iree_compiled_module +from msvcrt import kbhit +from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap from apps.shark_studio.web.utils.file_utils import ( get_checkpoints_path, get_resource_path, ) +from apps.shark_studio.modules.shared_cmd_opts import ( + cmd_opts, +) +from iree import runtime as ireert +from pathlib import Path import gc import os @@ -19,86 +25,152 @@ def __init__( self, model_map: dict, base_model_id: str, + static_kwargs: dict, device: str, import_mlir: bool = True, ): self.model_map = model_map + self.static_kwargs = static_kwargs self.base_model_id = base_model_id self.device = device self.import_mlir = import_mlir self.iree_module_dict = {} + self.tempfiles = {} + + + def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: + # First checks whether we have .vmfbs precompiled, then populates the map + # with the precompiled executables and fetches executables for the rest of the map. + # The weights aren't static here anymore so this function should be a part of pipeline + # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters, + # and your model map is populated with any IR - unique model IDs and their static params, + # call this method to get the artifacts associated with your map. + self.pipe_id = pipe_id + self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id)) + self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True) + print("\n[LOG] Checking for pre-compiled artifacts.") + if submodel == "None": + for key in self.model_map: + self.get_compiled_map(pipe_id, submodel=key) + else: + self.get_precompiled(pipe_id, submodel) + ireec_flags = [] + if submodel in self.iree_module_dict: + if "vmfb" in self.iree_module_dict[submodel]: + print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...") + return + elif submodel not in self.tempfiles: + print(f"[LOG] Tempfile for {submodel} not found. 
Fetching torch IR...") + if submodel in self.static_kwargs: + init_kwargs = self.static_kwargs[submodel] + for key in self.static_kwargs["pipe"]: + if key not in init_kwargs: + init_kwargs[key] = self.static_kwargs["pipe"][key] + self.import_torch_ir( + submodel, init_kwargs + ) + self.get_compiled_map(pipe_id, submodel) + else: + ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else [] + + if "external_weights_file" in self.model_map[submodel]: + weights_path = self.model_map[submodel]["external_weights_file"] + else: + weights_path = None + self.iree_module_dict[submodel] = get_iree_compiled_module( + self.tempfiles[submodel], + device=self.device, + frontend="torch", + mmap=True, + external_weight_file=weights_path, + extra_args=ireec_flags, + write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb") + ) + return + + + def hijack_weights(self, weights_path, submodel="None"): + if submodel == "None": + for i in self.model_map: + self.hijack_weights(weights_path, i) + else: + if submodel in self.iree_module_dict: + self.model_map[submodel]["external_weights_file"] = weights_path + return + + + def get_precompiled(self, pipe_id, submodel="None"): + if submodel == "None": + for model in self.model_map: + self.get_precompiled(pipe_id, model) + vmfbs = [] + vmfb_matches = {} + vmfbs_path = self.pipe_vmfb_path + for dirpath, dirnames, filenames in os.walk(vmfbs_path): + vmfbs.extend(filenames) + break + for file in vmfbs: + if submodel in file: + print(f"Found existing .vmfb at {file}") + self.iree_module_dict[submodel] = {} + ( + self.iree_module_dict[submodel]["vmfb"], + self.iree_module_dict[submodel]["config"], + self.iree_module_dict[submodel]["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + os.path.join(vmfbs_path, file), + self.device, + device_idx=0, + rt_flags=[], + external_weight_file=self.model_map[submodel]['external_weight_file'], + ) + return + + + def safe_dict(self, kwargs: dict): + flat_args = {} + for i in kwargs: + if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: + flat_args[i] = [kwargs[i][j] for j in kwargs[i]] + else: + flat_args[i] = kwargs[i] + + return flat_args + def import_torch_ir(self, submodel, kwargs): - weights = ( - submodel["custom_weights"] if submodel["custom_weights"] else None - ) torch_ir = self.model_map[submodel]["initializer"]( - self.base_model_id, **kwargs, compile_to="torch" + **self.safe_dict(kwargs), compile_to="torch" ) - self.model_map[submodel]["tempfile_name"] = get_resource_path( - f"{submodel}.torch.tempfile" - ) - with open(self.model_map[submodel]["tempfile_name"], "w+") as f: + if submodel == "clip": + # clip.export_clip_model returns (torch_ir, tokenizer) + torch_ir = torch_ir[0] + self.tempfiles[submodel] = get_resource_path(os.path.join( + "..", "shark_tmp", f"{submodel}.torch.tempfile" + )) + + with open(self.tempfiles[submodel], "w+") as f: f.write(torch_ir) del torch_ir gc.collect() + return - def load_vmfb(self, submodel): - if submodel in self.iree_module_dict: - print( - f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}" - ) - elif self.model_map[submodel]["tempfile_name"]: - submodel["tempfile_name"] - - return submodel["vmfb"] - - def merge_custom_map(self, custom_model_map): - for submodel in custom_model_map: - for key in submodel: - self.model_map[submodel][key] = key - print(self.model_map) - - def get_local_vmfbs(self, pipe_id): - for submodel in self.model_map: - vmfbs = [] - vmfb_matches = {} - vmfbs_path = 
get_checkpoints_path("../vmfbs") - for dirpath, dirnames, filenames in os.walk(vmfbs_path): - vmfbs.extend(filenames) - break - for file in vmfbs: - if all(keys in file for keys in [submodel, pipe_id]): - print(f"Found existing .vmfb at {file}") - self.iree_module_dict[submodel] = {"vmfb": file} - - def get_compiled_map(self, device, pipe_id) -> None: - # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". - if not self.import_mlir: - self.get_local_vmfbs(pipe_id) - for submodel in self.model_map: + + def load_submodels(self, submodels: list): + for submodel in submodels: if submodel in self.iree_module_dict: - if "vmfb" in self.iree_module_dict[submodel]: - continue - if "tempfile_name" not in self.model_map[submodel]: - sub_kwargs = ( - self.model_map[submodel]["kwargs"] - if self.model_map[submodel]["kwargs"] - else {} - ) - self.import_torch_ir( - submodel, self.base_model_id, **sub_kwargs - ) - self.iree_module_dict[submodel] = get_iree_compiled_module( - submodel["tempfile_name"], - device=self.device, - frontend="torch", - external_weight_file=submodel["custom_weights"], + print( + f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}" ) - # TODO: delete the temp file + else: + self.get_compiled_map(self.pipe_id, submodel) + return + def run(self, submodel, inputs): - return + inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)] + return self.iree_module_dict[submodel]['vmfb']['main'](*inp) + def safe_name(name): return name.replace("/", "_").replace("-", "_") diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py new file mode 100644 index 0000000000..d97d334f29 --- /dev/null +++ b/apps/shark_studio/modules/prompt_encoding.py @@ -0,0 +1,431 @@ + +from typing import List, Optional, Union +from iree import runtime as ireert +import re +import torch + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: + text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights( + pipe, prompt: List[str], max_length: int +): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print( + "Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples" + ) + return tokens, weights + + +def pad_tokens_and_weights( + tokens, + weights, + max_length, + bos, + eos, + no_boseos_middle=True, + chunk_length=77, +): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = 8 + weights_length = ( + max_length + if no_boseos_middle + else max_embeddings_multiples * chunk_length + ) + for i in range(len(tokens)): + tokens[i] = ( + [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + ) + if no_boseos_middle: + weights[i] = ( + [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + ) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][ + j + * (chunk_length - 2) : min( + len(weights[i]), (j + 1) * (chunk_length - 2) + ) + ] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe, + text_input: torch.Tensor, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + + text_embedding = pipe.run("clip", text_input_chunk)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + # SHARK: Convert the result to tensor + # text_embeddings = torch.concat(text_embeddings, axis=1) + text_embeddings_np = np.concatenate(np.array(text_embeddings)) + text_embeddings = torch.from_numpy(text_embeddings_np)[None, :] + else: + text_embeddings = pipe.run("clip", text_input)[0] + # text_embeddings = torch.from_numpy(text_embeddings)[None, :] + return torch.from_numpy(text_embeddings.to_host()) + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = 8 + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + # text_embedding = pipe.text_encoder(text_input_chunk)[0] + + print(text_input_chunk) + breakpoint() + text_embedding = pipe.run("clip", text_input_chunk) + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + # SHARK: Convert the result to tensor + # text_embeddings = torch.concat(text_embeddings, axis=1) + text_embeddings_np = np.concatenate(np.array(text_embeddings)) + text_embeddings = torch.from_numpy(text_embeddings_np)[None, :] + return text_embeddings + + +# This function deals with NoneType values occuring in tokens after padding +# It switches out None with 49407 as truncating None values causes matrix dimension errors, +def filter_nonetype_tokens(tokens: List[List]): + return [[49407 if token is None else token for token in tokens[0]]] + + +def get_weighted_text_embeddings( + pipe, + prompt: List[str], + uncond_prompt: List[str] = None, + max_embeddings_multiples: Optional[int] = 8, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, +): + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights( + pipe, prompt, max_length - 2 + ) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = get_prompts_with_weights( + pipe, uncond_prompt, max_length - 2 + ) + else: + prompt_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + prompt, max_length=max_length, truncation=True + ).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + uncond_prompt, max_length=max_length, truncation=True + ).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max( + max_length, max([len(token) for token in uncond_tokens]) + ) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + prompt_tokens = 
filter_nonetype_tokens(prompt_tokens) + + # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + uncond_tokens = filter_nonetype_tokens(uncond_tokens) + + # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + uncond_tokens = torch.tensor( + uncond_tokens, dtype=torch.long, device="cpu" + ) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + prompt_weights = torch.tensor( + prompt_weights, dtype=torch.float, device="cpu" + ) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + uncond_weights = torch.tensor( + uncond_weights, dtype=torch.float, device="cpu" + ) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = ( + text_embeddings.float() + .mean(axis=[-2, -1]) + .to(text_embeddings.dtype) + ) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = ( + text_embeddings.float() + .mean(axis=[-2, -1]) + .to(text_embeddings.dtype) + ) + text_embeddings *= ( + (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + ) + if uncond_prompt is not None: + previous_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= ( + (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + ) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index c62646f69c..484c8384a6 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -1,4 +1,105 @@ # from shark_turbine.turbine_models.schedulers import export_scheduler_model +from diffusers import ( + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDPMScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, +) + + +def get_schedulers(model_id): + #TODO: switch over to turbine and run all on GPU + print(f"[LOG] Initializing schedulers from model id: {model_id}") + schedulers = dict() + schedulers["PNDM"] = PNDMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDPM"] = DDPMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + 
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDIM"] = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LCMScheduler"] = LCMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "DPMSolverMultistep" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver" + ) + schedulers[ + "DPMSolverMultistep++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver++" + ) + schedulers[ + "DPMSolverMultistepKarras" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + use_karras_sigmas=True, + ) + schedulers[ + "DPMSolverMultistepKarras++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="dpmsolver++", + use_karras_sigmas=True, + ) + schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "DPMSolverSinglestep" + ] = DPMSolverSinglestepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "KDPM2AncestralDiscrete" + ] = KDPM2AncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + return schedulers def export_scheduler_model(model): diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index decbbd693f..535e5d2c7f 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -453,12 +453,6 @@ def is_valid_file(arg): "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", ) -p.add_argument( - "--custom_model_map", - type=str, - default="", - help="path to custom model map to import. This should be a .json file", -) ############################################################################## # IREE - Vulkan supported flags ############################################################################## diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 7d6a56728e..85de5c4126 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -126,26 +126,6 @@ def webui(): # # uvicorn.run(api, host="0.0.0.0", port=args.server_port) # sys.exit(0) - # Setup to use shark_tmp for gradio's temporary image files and clear any - # existing temporary images there if they exist. Then we can import gradio. - # It has to be in this order or gradio ignores what we've set up. 
- from apps.shark_studio.web.utils.tmp_configs import ( - config_tmp, - clear_tmp_mlir, - clear_tmp_imgs, - ) - from apps.shark_studio.web.utils.file_utils import ( - create_checkpoint_folders, - ) - - import gradio as gr - - config_tmp() - clear_tmp_mlir() - clear_tmp_imgs() - - # Create custom models folders if they don't exist - create_checkpoint_folders() def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index 129c7ef88a..e7b8fd72c4 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -50,7 +50,7 @@ def get_generated_imgs_todays_subdir() -> str: def create_checkpoint_folders(): - dir = ["vae", "lora"] + dir = ["vae", "lora", "../vmfb"] if not cmd_opts.ckpt_dir: dir.insert(0, "models") else: diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index bae1908e1c..25c363f652 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -106,7 +106,6 @@ def get_iree_frontend_args(frontend): # Common args to be used given any frontend or device. def get_iree_common_args(debug=False): common_args = [ - "--iree-stream-resource-max-allocation-size=4294967295", "--iree-vm-bytecode-module-strip-source-map=true", "--iree-util-zero-fill-elided-attrs", ]
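
To make the refactored pipeline surface easier to follow, here is a minimal usage sketch based on the `StableDiffusion.__init__` and `prepare_pipe` signatures introduced in this diff. The model ID, device string, and scheduler name are illustrative assumptions, not values mandated by the patch; only the argument names and call order come from the new code above.

```python
# Hypothetical usage sketch for the pipeline API introduced in this PR.
# Argument names follow the new StableDiffusion.__init__ and prepare_pipe
# signatures in apps/shark_studio/api/sd.py; the concrete values below
# (model ID, device, scheduler) are assumptions for illustration only.
from apps.shark_studio.api.sd import StableDiffusion

sd = StableDiffusion(
    base_model_id="stabilityai/stable-diffusion-2-1-base",  # assumed HF model ID
    height=512,
    width=512,
    batch_size=1,
    precision="fp16",
    device="vulkan",   # assumed device string; any IREE-supported device
    custom_vae=None,
    num_loras=0,
    import_ir=True,
    is_img2img=False,
    is_controlled=False,
)

# prepare_pipe takes a list of scheduler names (batched cases), a custom
# checkpoint path ("" when unused), a dict of control adapters, and the
# embeddings dict; it compiles or loads the .vmfb artifacts for the pipe_id.
sd.prepare_pipe(
    scheduler=["EulerDiscrete"],  # assumed key from get_schedulers()
    custom_weights="",
    adapters={},
    embeddings={},
)
```

After preparation, `generate_images(...)` runs the compiled submodules; its full parameter list is only partially visible in this hunk, so it is omitted from the sketch.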