From 92cde1b78264a8a609e93cae9e2faeea0a8d0d07 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 11 Dec 2023 23:19:01 -0600 Subject: [PATCH] Enable UI / bugfixes / tweaks --- apps/shark_studio/api/controlnet.py | 134 +++++ apps/shark_studio/api/initializers.py | 13 +- apps/shark_studio/api/sd.py | 191 ++++++- apps/shark_studio/api/utils.py | 143 ++++- apps/shark_studio/modules/checkpoint_proc.py | 66 +++ apps/shark_studio/modules/embeddings.py | 60 +++ apps/shark_studio/modules/img_processing.py | 115 ++-- apps/shark_studio/modules/pipeline.py | 71 +++ .../{api => modules}/schedulers.py | 0 apps/shark_studio/modules/shared_cmd_opts.py | 42 +- apps/shark_studio/web/api/compat.py | 209 +++++--- apps/shark_studio/web/configs/foo.json | 1 + apps/shark_studio/web/index.py | 33 +- apps/shark_studio/web/ui/common_events.py | 55 ++ .../shark_studio/web/ui/css/sd_dark_theme.css | 324 ++++++++++++ apps/shark_studio/web/ui/logos/nod-icon.png | Bin 0 -> 16058 bytes apps/shark_studio/web/ui/logos/nod-logo.png | Bin 0 -> 10641 bytes apps/shark_studio/web/ui/outputgallery.py | 416 +++++++++++++++ apps/shark_studio/web/ui/sd.py | 493 +++++++++--------- apps/shark_studio/web/ui/utils.py | 35 +- apps/shark_studio/web/utils/globals.py | 74 +++ .../web/utils/metadata/__init__.py | 6 + .../web/utils/metadata/csv_metadata.py | 45 ++ .../web/utils/metadata/display.py | 53 ++ .../web/utils/metadata/exif_metadata.py | 52 ++ .../shark_studio/web/utils/metadata/format.py | 143 +++++ .../web/utils/metadata/png_metadata.py | 222 ++++++++ apps/shark_studio/web/utils/state.py | 41 ++ apps/shark_studio/web/utils/tmp_configs.py | 77 +++ 29 files changed, 2648 insertions(+), 466 deletions(-) create mode 100644 apps/shark_studio/api/controlnet.py create mode 100644 apps/shark_studio/modules/checkpoint_proc.py create mode 100644 apps/shark_studio/modules/pipeline.py rename apps/shark_studio/{api => modules}/schedulers.py (100%) create mode 100644 apps/shark_studio/web/configs/foo.json create mode 100644 apps/shark_studio/web/ui/common_events.py create mode 100644 apps/shark_studio/web/ui/css/sd_dark_theme.css create mode 100644 apps/shark_studio/web/ui/logos/nod-icon.png create mode 100644 apps/shark_studio/web/ui/logos/nod-logo.png create mode 100644 apps/shark_studio/web/ui/outputgallery.py create mode 100644 apps/shark_studio/web/utils/globals.py create mode 100644 apps/shark_studio/web/utils/metadata/__init__.py create mode 100644 apps/shark_studio/web/utils/metadata/csv_metadata.py create mode 100644 apps/shark_studio/web/utils/metadata/display.py create mode 100644 apps/shark_studio/web/utils/metadata/exif_metadata.py create mode 100644 apps/shark_studio/web/utils/metadata/format.py create mode 100644 apps/shark_studio/web/utils/metadata/png_metadata.py create mode 100644 apps/shark_studio/web/utils/state.py create mode 100644 apps/shark_studio/web/utils/tmp_configs.py diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py new file mode 100644 index 0000000000..ea8cdf0cc9 --- /dev/null +++ b/apps/shark_studio/api/controlnet.py @@ -0,0 +1,134 @@ +# from turbine_models.custom_models.controlnet import control_adapter, preprocessors + + +class control_adapter: + def __init__( + self, + model: str, + ): + self.model = None + + def export_control_adapter_model(model_keyword): + return None + + def export_xl_control_adapter_model(model_keyword): + return None + + +class preprocessors: + def __init__( + self, + model: str, + ): + self.model = None + + def export_controlnet_model(model_keyword): + return None + + +control_adapter_map = { + "sd15": { + "canny": {"initializer": control_adapter.export_control_adapter_model}, + "openpose": { + "initializer": control_adapter.export_control_adapter_model + }, + "scribble": { + "initializer": control_adapter.export_control_adapter_model + }, + "zoedepth": { + "initializer": control_adapter.export_control_adapter_model + }, + }, + "sdxl": { + "canny": { + "initializer": control_adapter.export_xl_control_adapter_model + }, + }, +} +preprocessor_model_map = { + "canny": {"initializer": preprocessors.export_controlnet_model}, + "openpose": {"initializer": preprocessors.export_controlnet_model}, + "scribble": {"initializer": preprocessors.export_controlnet_model}, + "zoedepth": {"initializer": preprocessors.export_controlnet_model}, +} + + +class PreprocessorModel: + def __init__( + self, + hf_model_id, + device, + ): + self.model = None + + def compile(self, device): + print("compile not implemented for preprocessor.") + return + + def run(self, inputs): + print("run not implemented for preprocessor.") + return + + +def cnet_preview(model, input_img, stencils, images, preprocessed_hints): + if isinstance(input_image, PIL.Image.Image): + img_dict = { + "background": None, + "layers": [None], + "composite": input_image, + } + input_image = EditorValue(img_dict) + images[index] = input_image + if model: + stencils[index] = model + match model: + case "canny": + canny = CannyDetector() + result = canny( + np.array(input_image["composite"]), + 100, + 200, + ) + preprocessed_hints[index] = Image.fromarray(result) + return ( + Image.fromarray(result), + stencils, + images, + preprocessed_hints, + ) + case "openpose": + openpose = OpenposeDetector() + result = openpose(np.array(input_image["composite"])) + preprocessed_hints[index] = Image.fromarray(result[0]) + return ( + Image.fromarray(result[0]), + stencils, + images, + preprocessed_hints, + ) + case "zoedepth": + zoedepth = ZoeDetector() + result = zoedepth(np.array(input_image["composite"])) + preprocessed_hints[index] = Image.fromarray(result) + return ( + Image.fromarray(result), + stencils, + images, + preprocessed_hints, + ) + case "scribble": + preprocessed_hints[index] = input_image["composite"] + return ( + input_image["composite"], + stencils, + images, + preprocessed_hints, + ) + case _: + preprocessed_hints[index] = None + return ( + None, + stencils, + images, + preprocessed_hints, + ) diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py index 432eaf5331..bbb273354c 100644 --- a/apps/shark_studio/api/initializers.py +++ b/apps/shark_studio/api/initializers.py @@ -26,22 +26,21 @@ def imports(): startup_timer.record("import gradio") - # from apps.shark_studio.modules import shared_init - # shared_init.initialize() - # startup_timer.record("initialize shared") + import apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + startup_timer.record("initialize globals") from apps.shark_studio.modules import ( - processing, - gradio_extensons, - ui, + img_processing, ) # noqa: F401 + from apps.shark_studio.modules.schedulers import scheduler_model_map startup_timer.record("other imports") def initialize(): configure_sigint_handler() - configure_opts_onchange() # from apps.shark_studio.modules import modelloader # modelloader.cleanup_models() diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index f4b979d1fc..a601a068f7 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -1,9 +1,13 @@ from turbine_models.custom_models.sd_inference import clip, unet, vae from shark.iree_utils.compile_utils import get_iree_compiled_module from apps.shark_studio.api.utils import get_resource_path +from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.web.utils.state import status_label +from apps.shark_studio.modules.pipeline import SharkPipelineBase import iree.runtime as ireert import gc import torch +import gradio as gr sd_model_map = { "CompVis/stable-diffusion-v1-4": { @@ -86,16 +90,15 @@ class StableDiffusion(SharkPipelineBase): - # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. The init # aims to be as general as possible, and the class will infer and compile # a list of necessary modules or a combined "pipeline module" for a # specified job based on the inference task. - # + # # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels. # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"} - # + # # embeddings: a dict of embedding checkpoints or model IDs to use when # initializing the compiled modules. @@ -107,7 +110,6 @@ def __init__( precision: str = "fp16", device: str = None, custom_model_map: dict = {}, - custom_weights_map: dict = {}, embeddings: dict = {}, import_ir: bool = True, ): @@ -118,12 +120,185 @@ def __init__( self.iree_module_dict = None self.get_compiled_map() + def prepare_pipeline(self, scheduler, custom_model_map): + return None def generate_images( - self, - prompt, - ): - return result_output, + self, + prompt, + negative_prompt, + steps, + strength, + guidance_scale, + seed, + ondemand, + repeatable_seeds, + resample_type, + control_mode, + preprocessed_hints, + ): + return None, None, None, None, None + + +# NOTE: Each `hf_model_id` should have its own starting configuration. + +# model_vmfb_key = "" + + +def shark_sd_fn( + prompt, + negative_prompt, + image_dict, + height: int, + width: int, + steps: int, + strength: float, + guidance_scale: float, + seed: str | int, + batch_count: int, + batch_size: int, + scheduler: str, + base_model_id: str, + custom_weights: str, + custom_vae: str, + precision: str, + device: str, + lora_weights: str | list, + ondemand: bool, + repeatable_seeds: bool, + resample_type: str, + control_mode: str, + stencils: list, + images: list, + preprocessed_hints: list, + progress=gr.Progress(), +): + # Handling gradio ImageEditor datatypes so we have unified inputs to the SD API + for i, stencil in enumerate(stencils): + if images[i] is None and stencil is not None: + continue + elif stencil is None and any( + img is not None for img in [images[i], preprocessed_hints[i]] + ): + images[i] = None + preprocessed_hints[i] = None + elif images[i] is not None: + if isinstance(images[i], dict): + images[i] = images[i]["composite"] + images[i] = images[i].convert("RGB") + + if isinstance(image_dict, PIL.Image.Image): + image = image_dict.convert("RGB") + elif image_dict: + image = image_dict["image"].convert("RGB") + else: + image = None + is_img2img = False + if image: + ( + image, + _, + _, + ) = resize_stencil(image, width, height) + is_img2img = True + print("Performing Stable Diffusion Pipeline setup...") + + device_id = None + + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + import apps.shark_studio.web.utils.globals as global_obj + + custom_model_map = {} + if custom_weights != "None": + custom_model_map["unet"] = {"custom_weights": custom_weights} + if custom_vae != "None": + custom_model_map["vae"] = {"custom_weights": custom_vae} + if stencils: + for i, stencil in enumerate(stencils): + if "xl" not in base_model_id.lower(): + custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[ + "runwayml/stable-diffusion-v1-5" + ][stencil] + else: + custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[ + "stabilityai/stable-diffusion-xl-1.0" + ][stencil] + + submit_pipe_kwargs = { + "base_model_id": base_model_id, + "height": height, + "width": width, + "precision": precision, + "device": device, + "custom_model_map": custom_model_map, + "import_ir": cmd_opts.import_mlir, + "is_img2img": is_img2img, + } + submit_prep_kwargs = { + "scheduler": scheduler, + "custom_model_map": custom_model_map, + "embeddings": lora_weights, + } + submit_run_kwargs = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "steps": steps, + "strength": strength, + "guidance_scale": guidance_scale, + "seed": seed, + "ondemand": ondemand, + "repeatable_seeds": repeatable_seeds, + "resample_type": resample_type, + "control_mode": control_mode, + "preprocessed_hints": preprocessed_hints, + } + + global sd_pipe + global sd_pipe_kwargs + + if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs: + sd_pipe = None + sd_pipe_kwargs = submit_pipe_kwargs + gc.collect() + + if sd_pipe is None: + history[-1][-1] = "Getting the pipeline ready..." + yield history, "" + + # Initializes the pipeline and retrieves IR based on all + # parameters that are static in the turbine output format, + # which is currently MLIR in the torch dialect. + + sd_pipe = SharkStableDiffusionPipeline( + **submit_pipe_kwargs, + ) + + sd_pipe.prepare_pipe(**submit_prep_kwargs) + + for prompt, msg, exec_time in progress.tqdm( + out_imgs=sd_pipe.generate_images(**submit_run_kwargs), + desc="Generating Image...", + ): + text_output = get_generation_text_info( + seeds[: current_batch + 1], device + ) + save_output_img( + out_imgs[0], + seeds[current_batch], + extra_info, + ) + generated_imgs.extend(out_imgs) + yield generated_imgs, text_output, status_label( + "Stable Diffusion", current_batch + 1, batch_count, batch_size + ), stencils, images + + return generated_imgs, text_output, "", stencils, images + + +def cancel_sd(): + print("Inject call to cancel longer API calls.") + return + if __name__ == "__main__": sd = StableDiffusion( diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 8139ed9cb5..120ec3adfa 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -2,6 +2,7 @@ import sys import os import numpy as np +import glob from random import ( randint, seed as seed_random, @@ -12,6 +13,19 @@ from pathlib import Path from safetensors.torch import load_file from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from cpuinfo import get_cpu_info + +# TODO: migrate these utils to studio +from shark.iree_utils.vulkan_utils import ( + set_iree_vulkan_runtime_flags, + get_vulkan_target_triple, + get_iree_vulkan_runtime_flags, +) + +checkpoints_filetypes = ( + "*.ckpt", + "*.safetensors", +) def get_available_devices(): @@ -75,6 +89,67 @@ def get_devices_by_name(driver_name): return available_devices +def set_init_device_flags(): + if "vulkan" in cmd_opts.device: + # set runtime flags for vulkan. + set_iree_runtime_flags() + + # set triple flag to avoid multiple calls to get_vulkan_triple_flag + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_vulkan_target_triple: + triple = get_vulkan_target_triple(device_name) + if triple is not None: + cmd_opts.iree_vulkan_target_triple = triple + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_vulkan_target_triple}." + ) + elif "cuda" in cmd_opts.device: + cmd_opts.device = "cuda" + elif "metal" in cmd_opts.device: + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_metal_target_platform: + triple = get_metal_target_triple(device_name) + if triple is not None: + cmd_opts.iree_metal_target_platform = triple.split("-")[-1] + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_metal_target_platform}." + ) + elif "cpu" in cmd_opts.device: + cmd_opts.device = "cpu" + + +def set_iree_runtime_flags(): + # TODO: This function should be device-agnostic and piped properly + # to general runtime driver init. + vulkan_runtime_flags = get_iree_vulkan_runtime_flags() + if cmd_opts.enable_rgp: + vulkan_runtime_flags += [ + f"--enable_rgp=true", + f"--vulkan_debug_utils=true", + ] + if cmd_opts.device_allocator_heap_key: + vulkan_runtime_flags += [ + f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", + ] + set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) + + +def get_all_devices(driver_name): + """ + Inputs: driver_name + Returns a list of all the available devices for a given driver sorted by + the iree path names of the device as in --list_devices option in iree. + """ + from iree.runtime import get_driver + + driver = get_driver(driver_name) + device_list_src = driver.query_available_devices() + device_list_src.sort(key=lambda d: d["path"]) + return device_list_src + + def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" base_path = getattr( @@ -83,26 +158,52 @@ def get_resource_path(relative_path): return os.path.join(base_path, relative_path) - def get_generated_imgs_path() -> Path: return Path( - cmd_opts.output_dir - if cmd_opts.output_dir + cmd_opts.output_dir + if cmd_opts.output_dir else get_resource_path("..\web\generated_imgs") -) + ) def get_generated_imgs_todays_subdir() -> str: return dt.now().strftime("%Y%m%d") -def get_checkpoints_path(model = ""): +def create_checkpoint_folders(): + dir = ["vae", "lora"] + if not cmd_opts.ckpt_dir: + dir.insert(0, "models") + else: + if not os.path.isdir(cmd_opts.ckpt_dir): + sys.exit( + f"Invalid --ckpt_dir argument, " + f"{args.ckpt_dir} folder does not exists." + ) + for root in dir: + Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True) + + +def get_checkpoints_path(model=""): return get_resource_path(f"..\web\models\{model}") -def get_checkpoints(path): - files = [] - for file in +def get_checkpoints(model="models"): + ckpt_files = [] + file_types = checkpoints_filetypes + if model == "lora": + file_types = file_types + ("*.pt", "*.bin") + for extn in file_types: + files = [ + os.path.basename(x) + for x in glob.glob(os.path.join(get_checkpoints_path(model), extn)) + ] + ckpt_files.extend(files) + return sorted(ckpt_files, key=str.casefold) + + +def get_checkpoint_pathfile(checkpoint_name, model="models"): + return os.path.join(get_checkpoints_path(model), checkpoint_name) def get_device_mapping(driver, key_combination=3): @@ -144,6 +245,30 @@ def get_output_value(dev_dict): return device_map +def get_opt_flags(model, precision="fp16"): + iree_flags = [] + if len(cmd_opts.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" + ) + if "rocm" in cmd_opts.device: + rocm_args = get_iree_rocm_args() + iree_flags.extend(rocm_args) + if cmd_opts.iree_constant_folding == False: + iree_flags.append("--iree-opt-const-expr-hoisting=False") + iree_flags.append( + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" + ) + if cmd_opts.data_tiling == False: + iree_flags.append("--iree-opt-data-tiling=False") + + if "vae" not in model: + # Due to lack of support for multi-reduce, we always collapse reduction + # dims before dispatch formation right now. + iree_flags += ["--iree-flow-collapse-reduction-dims"] + return iree_flags + + def map_device_to_name_path(device, key_combination=3): """Gives the appropriate device data (supported name/path) for user selected execution device @@ -250,6 +375,7 @@ def parse_seed_input(seed_input: str | list | int): "Seed input must be an integer or an array of integers in JSON format" ) + # Generate and return a new seed if the provided one is not in the # supported range (including -1) def sanitize_seed(seed: int | str): @@ -260,6 +386,7 @@ def sanitize_seed(seed: int | str): seed = randint(uint32_min, uint32_max) return seed + # take a seed expression in an input format and convert it to # a list of integers, where possible def parse_seed_input(seed_input: str | list | int): diff --git a/apps/shark_studio/modules/checkpoint_proc.py b/apps/shark_studio/modules/checkpoint_proc.py new file mode 100644 index 0000000000..e924de4640 --- /dev/null +++ b/apps/shark_studio/modules/checkpoint_proc.py @@ -0,0 +1,66 @@ +import os +import json +import re +from pathlib import Path +from omegaconf import OmegaConf + + +def get_path_to_diffusers_checkpoint(custom_weights): + path = Path(custom_weights) + diffusers_path = path.parent.absolute() + diffusers_directory_name = os.path.join("diffusers", path.stem) + complete_path_to_diffusers = diffusers_path / diffusers_directory_name + complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) + path_to_diffusers = complete_path_to_diffusers.as_posix() + return path_to_diffusers + + +def preprocessCKPT(custom_weights, is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) + if next(Path(path_to_diffusers).iterdir(), None): + print("Checkpoint already loaded at : ", path_to_diffusers) + return + else: + print( + "Diffusers' checkpoint will be identified here : ", + path_to_diffusers, + ) + from_safetensors = ( + True if custom_weights.lower().endswith(".safetensors") else False + ) + # EMA weights usually yield higher quality images for inference but + # non-EMA weights have been yielding better results in our case. + # TODO: Add an option `--ema` (`--no-ema`) for users to specify if + # they want to go for EMA weight extraction or not. + extract_ema = False + print( + "Loading diffusers' pipeline from original stable diffusion checkpoint" + ) + num_in_channels = 9 if is_inpaint else 4 + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=custom_weights, + extract_ema=extract_ema, + from_safetensors=from_safetensors, + num_in_channels=num_in_channels, + ) + pipe.save_pretrained(path_to_diffusers) + print("Loading complete") + + +def convert_original_vae(vae_checkpoint): + vae_state_dict = {} + for key in list(vae_checkpoint.keys()): + vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key) + + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/" + "main/configs/stable-diffusion/v1-inference.yaml" + ) + original_config_file = BytesIO(requests.get(config_url).content) + original_config = OmegaConf.load(original_config_file) + vae_config = create_vae_diffusers_config(original_config, image_size=512) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + vae_state_dict, vae_config + ) + return converted_vae_checkpoint diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py index 5fc64c0ccc..d8cf544f81 100644 --- a/apps/shark_studio/modules/embeddings.py +++ b/apps/shark_studio/modules/embeddings.py @@ -1,5 +1,10 @@ +import os +import sys import torch +import json +import safetensors from safetensors.torch import load_file +from apps.shark_studio.api.utils import get_checkpoint_pathfile def processLoRA(model, use_lora, splitting_prefix): @@ -109,3 +114,58 @@ def update_lora_weight(model, use_lora, model_name): return processLoRA(model, use_lora, "lora_te_") except: return None + + +def get_lora_metadata(lora_filename): + # get the metadata from the file + filename = get_checkpoint_pathfile(lora_filename, "lora") + with safetensors.safe_open(filename, framework="pt", device="cpu") as f: + metadata = f.metadata() + + # guard clause for if there isn't any metadata + if not metadata: + return None + + # metadata is a dictionary of strings, the values of the keys we're + # interested in are actually json, and need to be loaded as such + tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}"))) + dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}"))) + tag_dirs = [dir for dir in tag_frequencies.keys()] + + # gather the tag frequency information for all the datasets trained + all_frequencies = {} + for dataset in tag_dirs: + frequencies = sorted( + [entry for entry in tag_frequencies[dataset].items()], + reverse=True, + key=lambda x: x[1], + ) + + # get a figure for the total number of images processed for this dataset + # either then number actually listed or in its dataset_dir entry or + # the highest frequency's number if that doesn't exist + img_count = dataset_dirs.get(dir, {}).get( + "img_count", frequencies[0][1] + ) + + # add the dataset frequencies to the overall frequencies replacing the + # frequency counts on the tags with a percentage/ratio + all_frequencies.update( + [(entry[0], entry[1] / img_count) for entry in frequencies] + ) + + trained_model_id = " ".join( + [ + metadata.get("ss_sd_model_hash", ""), + metadata.get("ss_sd_model_name", ""), + metadata.get("ss_base_model_version", ""), + ] + ).strip() + + # return the topmost of all frequencies in all datasets + return { + "model": trained_model_id, + "frequencies": sorted( + all_frequencies.items(), reverse=True, key=lambda x: x[1] + ), + } diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py index e709facbbf..b5cf28ce47 100644 --- a/apps/shark_studio/modules/img_processing.py +++ b/apps/shark_studio/modules/img_processing.py @@ -1,4 +1,8 @@ -from +import os +import sys +from PIL import Image +from pathlib import Path + # save output images and the inputs corresponding to it. def save_output_img(output_img, img_seed, extra_info=None): @@ -10,43 +14,45 @@ def save_output_img(output_img, img_seed, extra_info=None): generated_imgs_path.mkdir(parents=True, exist_ok=True) csv_path = Path(generated_imgs_path, "imgs_details.csv") - prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15]) + prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15]) out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}" - img_model = args.hf_model_id - if args.ckpt_loc: - img_model = Path(os.path.basename(args.ckpt_loc)).stem + img_model = cmd_opts.hf_model_id + if cmd_opts.ckpt_loc: + img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem img_vae = None - if args.custom_vae: - img_vae = Path(os.path.basename(args.custom_vae)).stem + if cmd_opts.custom_vae: + img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem img_lora = None - if args.use_lora: - img_lora = Path(os.path.basename(args.use_lora)).stem + if cmd_opts.use_lora: + img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem - if args.output_img_format == "jpg": + if cmd_opts.output_img_format == "jpg": out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") output_img.save(out_img_path, quality=95, subsampling=0) else: out_img_path = Path(generated_imgs_path, f"{out_img_name}.png") pngInfo = PngImagePlugin.PngInfo() - if args.write_metadata_to_png: + if cmd_opts.write_metadata_to_png: # Using a conditional expression caused problems, so setting a new # variable for now. - if args.use_hiresfix: - png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}" + if cmd_opts.use_hiresfix: + png_size_text = ( + f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}" + ) else: - png_size_text = f"{args.width}x{args.height}" + png_size_text = f"{cmd_opts.width}x{cmd_opts.height}" pngInfo.add_text( "parameters", - f"{args.prompts[0]}" - f"\nNegative prompt: {args.negative_prompts[0]}" - f"\nSteps: {args.steps}," - f"Sampler: {args.scheduler}, " - f"CFG scale: {args.guidance_scale}, " + f"{cmd_opts.prompts[0]}" + f"\nNegative prompt: {cmd_opts.negative_prompts[0]}" + f"\nSteps: {cmd_opts.steps}," + f"Sampler: {cmd_opts.scheduler}, " + f"CFG scale: {cmd_opts.guidance_scale}, " f"Seed: {img_seed}," f"Size: {png_size_text}, " f"Model: {img_model}, " @@ -56,9 +62,9 @@ def save_output_img(output_img, img_seed, extra_info=None): output_img.save(out_img_path, "PNG", pnginfo=pngInfo) - if args.output_img_format not in ["png", "jpg"]: + if cmd_opts.output_img_format not in ["png", "jpg"]: print( - f"[ERROR] Format {args.output_img_format} is not " + f"[ERROR] Format {cmd_opts.output_img_format} is not " f"supported yet. Image saved as png instead." f"Supported formats: png / jpg" ) @@ -68,18 +74,20 @@ def save_output_img(output_img, img_seed, extra_info=None): # importance for each data point. Something to consider. new_entry = { "VARIANT": img_model, - "SCHEDULER": args.scheduler, - "PROMPT": args.prompts[0], - "NEG_PROMPT": args.negative_prompts[0], + "SCHEDULER": cmd_opts.scheduler, + "PROMPT": cmd_opts.prompts[0], + "NEG_PROMPT": cmd_opts.negative_prompts[0], "SEED": img_seed, - "CFG_SCALE": args.guidance_scale, - "PRECISION": args.precision, - "STEPS": args.steps, - "HEIGHT": args.height - if not args.use_hiresfix - else args.hiresfix_height, - "WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width, - "MAX_LENGTH": args.max_length, + "CFG_SCALE": cmd_opts.guidance_scale, + "PRECISION": cmd_opts.precision, + "STEPS": cmd_opts.steps, + "HEIGHT": cmd_opts.height + if not cmd_opts.use_hiresfix + else cmd_opts.hiresfix_height, + "WIDTH": cmd_opts.width + if not cmd_opts.use_hiresfix + else cmd_opts.hiresfix_width, + "MAX_LENGTH": cmd_opts.max_length, "OUTPUT": out_img_path, "VAE": img_vae, "LORA": img_lora, @@ -95,37 +103,23 @@ def save_output_img(output_img, img_seed, extra_info=None): dictwriter_obj.writerow(new_entry) csv_obj.close() - if args.save_metadata_to_json: + if cmd_opts.save_metadata_to_json: del new_entry["OUTPUT"] json_path = Path(generated_imgs_path, f"{out_img_name}.json") with open(json_path, "w") as f: json.dump(new_entry, f, indent=4) -def get_generation_text_info(seeds, device): - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, " f"device={device}" - text_output += ( - f"\nsteps={args.steps}, " - f"guidance_scale={args.guidance_scale}, " - f"seed={seeds}" - ) - text_output += ( - f"\nsize={args.height}x{args.width}, " - if not args.use_hiresfix - else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, " - ) - text_output += ( - f"batch_count={args.batch_count}, " - f"batch_size={args.batch_size}, " - f"max_length={args.max_length}" - ) +resamplers = { + "Lanczos": Image.Resampling.LANCZOS, + "Nearest Neighbor": Image.Resampling.NEAREST, + "Bilinear": Image.Resampling.BILINEAR, + "Bicubic": Image.Resampling.BICUBIC, + "Hamming": Image.Resampling.HAMMING, + "Box": Image.Resampling.BOX, +} - return text_output +resampler_list = resamplers.keys() # For stencil, the input image can be of any size, but we need to ensure that @@ -133,7 +127,7 @@ def get_generation_text_info(seeds, device): # Both width and height should be in the range of [128, 768] and multiple of 8. # This utility function performs the transformation on the input image while # also maintaining the aspect ratio before sending it to the stencil pipeline. -def resize_stencil(image: Image.Image, width, height): +def resize_stencil(image: Image.Image, width, height, resampler_type=None): aspect_ratio = width / height min_size = min(width, height) if min_size < 128: @@ -166,6 +160,9 @@ def resize_stencil(image: Image.Image, width, height): n_height = height // 8 n_width *= 8 n_height *= 8 - new_image = image.resize((n_width, n_height)) + if resampler_type in resamplers: + resampler = resamplers[resampler_type] + else: + resampler = resamplers["Nearest Neighbor"] + new_image = image.resize((n_width, n_height), resampler=resampler) return new_image, n_width, n_height - diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py new file mode 100644 index 0000000000..c087175de4 --- /dev/null +++ b/apps/shark_studio/modules/pipeline.py @@ -0,0 +1,71 @@ +from shark.iree_utils.compile_utils import get_iree_compiled_module + + +class SharkPipelineBase: + # This class is a lightweight base for managing an + # inference API class. It should provide methods for: + # - compiling a set (model map) of torch IR modules + # - preparing weights for an inference job + # - loading weights for an inference job + # - utilites like benchmarks, tests + + def __init__( + self, + model_map: dict, + device: str, + import_mlir: bool = True, + ): + self.model_map = model_map + self.device = device + self.import_mlir = import_mlir + + def import_torch_ir(self, base_model_id): + for submodel in self.model_map: + hf_id = ( + submodel["custom_hf_id"] + if submodel["custom_hf_id"] + else base_model_id + ) + torch_ir = submodel["initializer"]( + hf_id, **submodel["init_kwargs"], compile_to="torch" + ) + submodel["tempfile_name"] = get_resource_path( + f"{submodel}.torch.tempfile" + ) + with open(submodel["tempfile_name"], "w+") as f: + f.write(torch_ir) + del torch_ir + gc.collect() + + def load_vmfb(self, submodel): + if self.iree_module_dict[submodel]: + print( + f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}" + ) + elif self.model_map[submodel]["tempfile_name"]: + submodel["tempfile_name"] + + return submodel["vmfb"] + + def merge_custom_map(self, custom_model_map): + for submodel in custom_model_map: + for key in submodel: + self.model_map[submodel][key] = key + print(self.model_map) + + def get_compiled_map(self, device) -> None: + # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". + for submodel in self.model_map: + if not self.iree_module_dict[submodel][vmfb]: + self.iree_module_dict[submodel] = get_iree_compiled_module( + submodel.tempfile_name, + device=self.device, + frontend="torch", + ) + # TODO: delete the temp file + + def run(self, submodel, inputs): + return + + def safe_name(name): + return name.replace("/", "_").replace("-", "_") diff --git a/apps/shark_studio/api/schedulers.py b/apps/shark_studio/modules/schedulers.py similarity index 100% rename from apps/shark_studio/api/schedulers.py rename to apps/shark_studio/modules/schedulers.py diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index 88434ff580..dfb166a52e 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -2,7 +2,7 @@ import os from pathlib import Path -from apps.stable_diffusion.src.utils.resamplers import resampler_list +from apps.shark_studio.modules.img_processing import resampler_list def path_expand(s): @@ -36,7 +36,7 @@ def is_valid_file(arg): nargs="+", default=[ "a photo taken of the front of a super-car drifting on a road near " - "mountains at high speeds with smokes coming off the tires, front " + "mountains at high speeds with smoke coming off the tires, front " "angle, front point of view, trees in the mountains of the " "background, ((sharp focus))" ], @@ -306,21 +306,6 @@ def is_valid_file(arg): "downloads the model from shark_tank.", ) -p.add_argument( - "--load_vmfb", - default=True, - action=argparse.BooleanOptionalAction, - help="Attempts to load the model from a precompiled flat-buffer " - "and compiles + saves it if not found.", -) - -p.add_argument( - "--save_vmfb", - default=False, - action=argparse.BooleanOptionalAction, - help="Saves the compiled flat-buffer to the local directory.", -) - p.add_argument( "--use_tuned", default=False, @@ -446,7 +431,7 @@ def is_valid_file(arg): ) p.add_argument( - "--ondemand", + "--lowvram", default=False, action=argparse.BooleanOptionalAction, help="Load and unload models for low VRAM.", @@ -469,10 +454,10 @@ def is_valid_file(arg): ) p.add_argument( - "--autogen", - type=bool, - default="False", - help="Only used for a gradio workaround.", + "--custom_model_map", + type=str, + default="", + help="path to custom model map to import. This should be a .json file", ) ############################################################################## # IREE - Vulkan supported flags @@ -612,6 +597,13 @@ def is_valid_file(arg): # Web UI flags ############################################################################## +p.add_argument( + "--webui", + default=True, + action=argparse.BooleanOptionalAction, + help="controls whether the webui is launched.", +) + p.add_argument( "--progress_bar", default=True, @@ -764,8 +756,8 @@ def is_valid_file(arg): "or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name", ) -args, unknown = p.parse_known_args() -if args.import_debug: +cmd_opts, unknown = p.parse_known_args() +if cmd_opts.import_debug: os.environ["IREE_SAVE_TEMPS"] = os.path.join( - os.getcwd(), args.hf_model_id.replace("/", "_") + os.getcwd(), cmd_opts.hf_model_id.replace("/", "_") ) diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py index c5fafd7ad2..80399505c4 100644 --- a/apps/shark_studio/web/api/compat.py +++ b/apps/shark_studio/web/api/compat.py @@ -15,7 +15,7 @@ from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder -from apps.shark_studio. import sd_samplers, postprocessing, errors, restart +from apps.shark_studio.modules.img_processing import sampler_list from sdapi_v1 import shark_sd_api from api.llm import chat_api @@ -26,15 +26,21 @@ def decode_base64_to_image(encoding): raise HTTPException(status_code=500, detail="Requests not allowed") if opts.api_forbid_local_requests and not verify_url(encoding): - raise HTTPException(status_code=500, detail="Request to local resource not allowed") + raise HTTPException( + status_code=500, detail="Request to local resource not allowed" + ) - headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + headers = ( + {"user-agent": opts.api_useragent} if opts.api_useragent else {} + ) response = requests.get(encoding, timeout=30, headers=headers) try: image = Image.open(BytesIO(response.content)) return image except Exception as e: - raise HTTPException(status_code=500, detail="Invalid image url") from e + raise HTTPException( + status_code=500, detail="Invalid image url" + ) from e if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] @@ -42,32 +48,54 @@ def decode_base64_to_image(encoding): image = Image.open(BytesIO(base64.b64decode(encoding))) return image except Exception as e: - raise HTTPException(status_code=500, detail="Invalid encoded image") from e + raise HTTPException( + status_code=500, detail="Invalid encoded image" + ) from e def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - - if opts.samples_format.lower() == 'png': + if opts.samples_format.lower() == "png": use_metadata = False metadata = PngImagePlugin.PngInfo() for key, value in image.info.items(): if isinstance(key, str) and isinstance(value, str): metadata.add_text(key, value) use_metadata = True - image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + image.save( + output_bytes, + format="PNG", + pnginfo=(metadata if use_metadata else None), + quality=opts.jpeg_quality, + ) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): if image.mode == "RGBA": image = image.convert("RGB") - parameters = image.info.get('parameters', None) - exif_bytes = piexif.dump({ - "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } - }) + parameters = image.info.get("parameters", None) + exif_bytes = piexif.dump( + { + "Exif": { + piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump( + parameters or "", encoding="unicode" + ) + } + } + ) if opts.samples_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + image.save( + output_bytes, + format="JPEG", + exif=exif_bytes, + quality=opts.jpeg_quality, + ) else: - image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + image.save( + output_bytes, + format="WEBP", + exif=exif_bytes, + quality=opts.jpeg_quality, + ) else: raise HTTPException(status_code=500, detail="Invalid image format") @@ -80,10 +108,11 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): rich_available = False try: - if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None: import anyio # importing just so it can be placed on silent list import starlette # importing just so it can be placed on silent list from rich.console import Console + console = Console() rich_available = True except Exception: @@ -95,35 +124,49 @@ async def log_and_time(req: Request, call_next): res: Response = await call_next(req) duration = str(round(time.time() - ts, 4)) res.headers["X-Process-Time"] = duration - endpoint = req.scope.get('path', 'err') - if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): - print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code=res.status_code, - ver=req.scope.get('http_version', '0.0'), - cli=req.scope.get('client', ('0:0.0.0', 0))[0], - prot=req.scope.get('scheme', 'err'), - method=req.scope.get('method', 'err'), - endpoint=endpoint, - duration=duration, - )) + endpoint = req.scope.get("path", "err") + if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"): + print( + "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format( + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get("http_version", "0.0"), + cli=req.scope.get("client", ("0:0.0.0", 0))[0], + prot=req.scope.get("scheme", "err"), + method=req.scope.get("method", "err"), + endpoint=endpoint, + duration=duration, + ) + ) return res def handle_exception(request: Request, e: Exception): err = { "error": type(e).__name__, - "detail": vars(e).get('detail', ''), - "body": vars(e).get('body', ''), + "detail": vars(e).get("detail", ""), + "body": vars(e).get("body", ""), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance( + e, HTTPException + ): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) - console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) + console.print_exception( + show_locals=True, + max_frames=2, + extra_lines=1, + suppress=[anyio, starlette], + word_wrap=False, + width=min([console.width, 200]), + ) else: errors.report(message, exc_info=True) - return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) + return JSONResponse( + status_code=vars(e).get("status_code", 500), + content=jsonable_encoder(err), + ) @app.middleware("http") async def exception_handling(request: Request, call_next): @@ -143,52 +186,48 @@ async def http_exception_handler(request: Request, e: HTTPException): class ApiCompat: def __init__(self, queue_lock: Lock): - self.router = APIRouter() self.app = FastAPI() self.queue_lock = queue_lock api_middleware(self.app) self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"]) self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"]) - #self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"]) - #self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) - #self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) - #self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) - #self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) - #self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) - #self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) - #self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - #self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) - #self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) - #self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) - #self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) - #self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) - #self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) - #self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) - #self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) - #self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) - #self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) - #self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) - #self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) - #self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) - #self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) - #self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) - #self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) - #self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) - #self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) - #self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) - #self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) - #self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - + # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"]) + # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) + # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) + # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) + # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) + # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) + # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) + # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) + # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) + # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) + # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) + # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) + # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) + # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) + # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) + # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) + # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) + # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) + # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) # chat APIs needed for compatibility with multiple extensions using OpenAI API - self.add_api_route( - "/v1/chat/completions", chat_api, methods=["post"] - ) + self.add_api_route("/v1/chat/completions", chat_api, methods=["post"]) self.add_api_route("/v1/completions", chat_api, methods=["post"]) self.add_api_route("/chat/completions", chat_api, methods=["post"]) self.add_api_route("/completions", chat_api, methods=["post"]) @@ -196,16 +235,26 @@ def __init__(self, queue_lock: Lock): "/v1/engines/codegen/completions", chat_api, methods=["post"] ) if studio.cmd_opts.api_server_stop: - self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"]) + self.add_api_route( + "/sdapi/v1/server-kill", self.kill_studio, methods=["POST"] + ) + self.add_api_route( + "/sdapi/v1/server-restart", + self.restart_studio, + methods=["POST"], + ) + self.add_api_route( + "/sdapi/v1/server-stop", self.stop_studio, methods=["POST"] + ) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - def add_api_route(self, path:str, endpoint, **kwargs): + def add_api_route(self, path: str, endpoint, **kwargs): if studio.cmd_opts.api_auth: - return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs + return self.app.add_api_route( + path, endpoint, dependencies=[Depends(self.auth)], **kwargs + ) return self.app.add_api_route(path, endpoint, **kwargs) def refresh_checkpoints(self): @@ -231,7 +280,13 @@ def skip(self): def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path) + uvicorn.run( + self.app, + host=server_name, + port=port, + timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, + root_path=root_path, + ) def kill_studio(self): restart.stop_program() @@ -246,7 +301,7 @@ def preprocess(self, args: dict): studio.state.begin(job="preprocess") preprocess(**args) studio.state.end() - return models.PreprocessResponse(info='preprocess complete') + return models.PreprocessResponse(info="preprocess complete") except: studio.state.end() diff --git a/apps/shark_studio/web/configs/foo.json b/apps/shark_studio/web/configs/foo.json new file mode 100644 index 0000000000..0967ef424b --- /dev/null +++ b/apps/shark_studio/web/configs/foo.json @@ -0,0 +1 @@ +{} diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index d678d0b647..6ff90b4dbc 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -3,12 +3,13 @@ import time import sys import logging +import apps.shark_studio.api.initializers as initialize from ui.chat import chat_element from ui.sd import sd_element from ui.outputgallery import outputgallery_element -from modules import timer, initialize +from apps.shark_studio.modules import timer startup_timer = timer.startup_timer startup_timer.record("launcher") @@ -72,15 +73,13 @@ def launch_webui(address): def webui(): - from apps.shark_studio.shared_cmd_options import cmd_opts + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts logging.basicConfig(level=logging.DEBUG) launch_api = cmd_opts.api initialize.initialize() - from modules import shared, ui_tempdir, script_callbacks, ui, progress - # required to do multiprocessing in a pyinstaller freeze freeze_support() @@ -131,16 +130,23 @@ def webui(): # Setup to use shark_tmp for gradio's temporary image files and clear any # existing temporary images there if they exist. Then we can import gradio. # It has to be in this order or gradio ignores what we've set up. - from apps.shark_studio.web.initializers import ( - config_gradio_tmp_imgs_folder, - create_custom_models_folders, + from apps.shark_studio.web.utils.tmp_configs import ( + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, + ) + from apps.shark_studio.api.utils import ( + create_checkpoint_folders, ) - config_gradio_tmp_imgs_folder() import gradio as gr + config_tmp() + clear_tmp_mlir() + clear_tmp_imgs() + # Create custom models folders if they don't exist - create_custom_models_folders() + create_checkpoint_folders() def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" @@ -151,10 +157,7 @@ def resource_path(relative_path): dark_theme = resource_path("ui/css/sd_dark_theme.css") - from apps.shark_studio.web.ui import load_ui_from_script - - # init global sd pipeline and config - studio.state._init() + # from apps.shark_studio.web.ui import load_ui_from_script def register_button_click(button, selectedid, inputs, outputs): button.click( @@ -211,9 +214,9 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): if __name__ == "__main__": - from apps.shark_studio.shared_cmd_options import cmd_opts + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - if cmd_opts.nowebui: + if cmd_opts.webui == False: api_only() else: webui() diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py new file mode 100644 index 0000000000..37555ed7ee --- /dev/null +++ b/apps/shark_studio/web/ui/common_events.py @@ -0,0 +1,55 @@ +from apps.shark_studio.web.ui.utils import ( + HSLHue, + hsl_color, +) +from apps.shark_studio.modules.embeddings import get_lora_metadata + + +# Answers HTML to show the most frequent tags used when a LoRA was trained, +# taken from the metadata of its .safetensors file. +def lora_changed(lora_file): + # tag frequency percentage, that gets maximum amount of the staring hue + TAG_COLOR_THRESHOLD = 0.55 + # tag frequency percentage, above which a tag is displayed + TAG_DISPLAY_THRESHOLD = 0.65 + # template for the html used to display a tag + TAG_HTML_TEMPLATE = '{tag}' + + if lora_file == "None": + return ["
No LoRA selected
"] + elif not lora_file.lower().endswith(".safetensors"): + return [ + "
Only metadata queries for .safetensors files are currently supported
" + ] + else: + metadata = get_lora_metadata(lora_file) + if metadata: + frequencies = metadata["frequencies"] + return [ + "".join( + [ + f'
Trained against weights in: {metadata["model"]}
' + ] + + [ + TAG_HTML_TEMPLATE.format( + color=hsl_color( + (tag[1] - TAG_COLOR_THRESHOLD) + / (1 - TAG_COLOR_THRESHOLD), + start=HSLHue.RED, + end=HSLHue.GREEN, + ), + tag=tag[0], + ) + for tag in frequencies + if tag[1] > TAG_DISPLAY_THRESHOLD + ], + ) + ] + elif metadata is None: + return [ + "
This LoRA does not publish tag frequency metadata
" + ] + else: + return [ + "
This LoRA has empty tag frequency metadata, or we could not parse it
" + ] diff --git a/apps/shark_studio/web/ui/css/sd_dark_theme.css b/apps/shark_studio/web/ui/css/sd_dark_theme.css new file mode 100644 index 0000000000..5686f0868c --- /dev/null +++ b/apps/shark_studio/web/ui/css/sd_dark_theme.css @@ -0,0 +1,324 @@ +/* +Apply Gradio dark theme to the default Gradio theme. +Procedure to upgrade the dark theme: +- Using your browser, visit http://localhost:8080/?__theme=dark +- Open your browser inspector, search for the .dark css class +- Copy .dark class declarations, apply them here into :root +*/ + +:root { + --body-background-fill: var(--background-fill-primary); + --body-text-color: var(--neutral-100); + --color-accent-soft: var(--neutral-700); + --background-fill-primary: var(--neutral-950); + --background-fill-secondary: var(--neutral-900); + --border-color-accent: var(--neutral-600); + --border-color-primary: var(--neutral-700); + --link-text-color-active: var(--secondary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --body-text-color-subdued: var(--neutral-400); + --shadow-spread: 1px; + --block-background-fill: var(--neutral-800); + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --block-title-text-color: var(--neutral-200); + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--neutral-800); + --checkbox-background-color-focus: var(--checkbox-background-color); + --checkbox-background-color-hover: var(--checkbox-background-color); + --checkbox-background-color-selected: var(--secondary-600); + --checkbox-border-color: var(--neutral-700); + --checkbox-border-color-focus: var(--secondary-500); + --checkbox-border-color-hover: var(--neutral-600); + --checkbox-border-color-selected: var(--secondary-600); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error_border_width: None; + --error-text-color: #ef4444; + --input-background-fill: var(--neutral-800); + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--border-color-primary); + --input-border-color-focus: var(--neutral-700); + --input-border-color-hover: var(--input-border-color); + --input_border_width: None; + --input-placeholder-color: var(--neutral-500); + --input_shadow: None; + --input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset); + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--neutral-950); + --table-odd-background-fill: var(--neutral-900); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600)); + --button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500)); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700)); + --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600)); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color: white; + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --block-border-width: 1px; + --block-label-border-width: 1px; + --form-gap-width: 1px; + --error-border-width: 1px; + --input-border-width: 1px; +} + +/* SHARK theme */ +body { + background-color: var(--background-fill-primary); +} + +.generating.svelte-zlszon.svelte-zlszon { + border: none; +} + +.generating { + border: none !important; +} + +#chatbot { + height: 100% !important; +} + +/* display in full width for desktop devices */ +@media (min-width: 1536px) +{ + .gradio-container { + max-width: var(--size-full) !important; + } +} + +.gradio-container .contain { + padding: 0 var(--size-4) !important; +} + +#top_logo { + color: transparent; + background-color: transparent; + border-radius: 0 !important; + border: 0; +} + +#ui_title { + padding: var(--size-2) 0 0 var(--size-1); +} + +#demo_title_outer { + border-radius: 0; +} + +#prompt_box_outer div:first-child { + border-radius: 0 !important +} + +#prompt_box textarea, #negative_prompt_box textarea { + background-color: var(--background-fill-primary) !important; +} + +#prompt_examples { + margin: 0 !important; +} + +#prompt_examples svg { + display: none !important; +} + +#ui_body { + padding: var(--size-2) !important; + border-radius: 0.5em !important; +} + +#img_result+div { + display: none !important; +} + +footer { + display: none !important; +} + +#gallery + div { + border-radius: 0 !important; +} + +/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */ +#gallery .thumbnail-item.thumbnail-lg { + aspect-ratio: unset; + max-height: calc(55vh - (2 * var(--spacing-lg))); +} +@media (min-width: 1921px) { + /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */ + #gallery .grid-wrap, #gallery .preview{ + min-height: calc(768px + 4px + var(--size-14)); + max-height: calc(768px + 4px + var(--size-14)); + } + /* Limit height to 768px_height + 2px_margin_height for the thumbnails */ + #gallery .thumbnail-item.thumbnail-lg { + max-height: 770px !important; + } +} +/* Don't upscale when viewing in solo image mode */ +#gallery .preview img { + object-fit: scale-down; +} +/* Navbar images in cover mode*/ +#gallery .preview .thumbnail-item img { + object-fit: cover; +} + +/* Limit the stable diffusion text output height */ +#std_output textarea { + max-height: 215px; +} + +/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */ +#gallery .wrap.default { + pointer-events: none; +} + +/* Import Png info box */ +#txt2img_prompt_image { + height: var(--size-32) !important; +} + +/* Hide "remove buttons" from ui dropdowns */ +#custom_model .token-remove.remove-all, +#lora_weights .token-remove.remove-all, +#scheduler .token-remove.remove-all, +#device .token-remove.remove-all, +#stencil_model .token-remove.remove-all { + display: none; +} + +/* Hide selected items from ui dropdowns */ +#custom_model .options .item .inner-item, +#scheduler .options .item .inner-item, +#device .options .item .inner-item, +#stencil_model .options .item .inner-item { + display:none; +} + +/* workarounds for container=false not currently working for dropdowns */ +.dropdown_no_container { + padding: 0 !important; +} + +#output_subdir_container :first-child { + border: none; +} + +/* reduced animation load when generating */ +.generating { + animation-play-state: paused !important; +} + +/* better clarity when progress bars are minimal */ +.meta-text { + background-color: var(--block-label-background-fill); +} + +/* lora tag pills */ +.lora-tags { + border: 1px solid var(--border-color-primary); + color: var(--block-info-text-color) !important; + padding: var(--block-padding); +} + +.lora-tag { + display: inline-block; + height: 2em; + color: rgb(212 212 212) !important; + margin-right: 5pt; + margin-bottom: 5pt; + padding: 2pt 5pt; + border-radius: 5pt; + white-space: nowrap; +} + +.lora-model { + margin-bottom: var(--spacing-lg); + color: var(--block-info-text-color) !important; + line-height: var(--line-sm); +} + +/* output gallery tab */ +.output_parameters_dataframe table.table { + /* works around a gradio bug that always shows scrollbars */ + overflow: clip auto; +} + +.output_parameters_dataframe tbody td { + font-size: small; + line-height: var(--line-xs); +} + +.output_icon_button { + max-width: 30px; + align-self: end; + padding-bottom: 8px; +} + +.outputgallery_sendto { + min-width: 7em !important; +} + +/* output gallery should take up most of the viewport height regardless of image size/number */ +#outputgallery_gallery .fixed-height { + min-height: 89vh !important; +} + +/* don't stretch non-square images to be square, breaking their aspect ratio */ +#outputgallery_gallery .thumbnail-item.thumbnail-lg > img { + object-fit: contain !important; +} + +/* centered logo for when there are no images */ +#top_logo.logo_centered { + height: 100%; + width: 100%; +} + +#top_logo.logo_centered img{ + object-fit: scale-down; + position: absolute; + width: 80%; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); +} diff --git a/apps/shark_studio/web/ui/logos/nod-icon.png b/apps/shark_studio/web/ui/logos/nod-icon.png new file mode 100644 index 0000000000000000000000000000000000000000..29f7e32220bd6df9cb4cbd168c0456818c5bc994 GIT binary patch literal 16058 zcmV;rK1IQaP)Px#1ZP1_K>z@;j|==^1poj532;bRa{vGi!vFvd!vV){sAK>DK2J$RK~#8N?R^QD z996mYS*ohLr+bzxB!MJ^C6JH=LjpG<+pi$tf?fobAc*=4-V0X&LE*Y9iu@=7dKIsV zBFGX@)Q3$#R0LdzS!J?^NFZb*%Pc*;RngdKV*Q}L#V~&4tdQ%WDV)xBoJcDU|HdoLa_CJSly4lTW(oqx=cosZA0@XnE�SP46vkguM=wz) zZHx!zbvS%PURx5~w{Y=6B2X7<%|7G^F*PRzE7ONk7&5JRg>fA+&N*%BGNz^E0=pT6 znn=tzZK!EGvrRff`=@BOEeO^U*S|6IxG6uxws$^aMFhtDzTUo>;tG5@L)w{IAmva?w!AQ3Vp|=MBdA?^gwil%BM#`JCf5^FjfiJ zxikN5(y!~p?HvXcX`x%KWt}yD!J%XguOlsbVOA1Z(Z&k~!U$^qU8_Gemv?He@2w>& zjup=`Xkjz$nrEguJDtOs!@S5oquxK^Eb}CQydu;;Y;omDZ51-*g6uTYu1gO_ow&Wn zpdc-z0Eh(L3bHBv%9m*IHPmFMqvoEM=|6|aDqW&lhaWB^_WciYya(Xv;~3UL=f(@w zSC1dMIloc~OU7AqTW@J{gE>L8Q)H1CB0duW}3q}LOY!&veIVXAdSkKsRNek?~1h0gK(iOA_S0e{SnFjgGvrBk+lwQsu}w%lqm%vbHB< zJ8ZwJzkh!~cW`k(0tg4P8p6(L{W$m4B9a2|YzH}FcB*3=>{pgHX~pha zuwn!cD?cT&o|On+aX4Anwit$YYe!s*rWl`hmu1}^aIW^dX3c4RKIh6AAT7CZuF0;+ zVA9O#QylzcoiX#+DIW%5 zceiET9`Ic_XY)Y@Gf7VuCFk@&@p2A1VC8%lt87G~j&AH1qy_O34$duEmp|wT(b7>- z$YQCCD?FrGw}$!#IdbUH2!`ahpZTVyc+tNAo8SfvyM0u4<95> zmE3h0?DO&fgaz?t&0ji}g~6{{TxW)jfyUZZkfXIwOYw>Gr%jrTeLw3 zKx-3^jY3CZ8~wZw%{*@M@33s3S=UHWjQTyFMS9pGSn%Va}09h3mI` z2n)I`^5DikMOy1{-Lqg-!>l=ro~7pVVnhe%mYW=Y&T(ZV2PU6<{xL^vw0d`6uugao z(Sn?R%#`^`^ZT0|b8GxCI`dEBKg4&?}KU17LHJN!+!(U#o%va@2XEz|tRW=eGJM9&As--B{W zJdC!kVz6RB|9C42Oh}7|fA*xvm|_x3RE9 zSWSdgnj<9(4QoPLWMTgc+rsMXiw@Tt0rduW7)?!rU{B`gLNF~V#_CG_!FaPlivdMi zYhlqdmI+0KCI7<`7BvJxcpXAoXc|FS+}vo-{SVe53pu~rEa2S9N4*^0d*A$(L(sYm zR*5iRNNX)DddEs#c?wlokt7INlZP4t5O(rcV1e?vC;9$NSO6U&%~Uv#<%4h|y5j_) zu8`Fop>*X8tQ)KnVL*_!s$u6XoWPa-gC`srYgmS={!QAw)e;tcwf1+y0=6VE?pXN{ zy0T+b8{9Cd)d#T6V2#=Cb<*O-7z?{##R#V5A2U+q;wXW#PS-`pjQ=#`01Bc9Wx+_u zAHYaABoVwWWXd!<4cBY1#%%ZM54N5^EG-Ypa5xS^NgmckVXfe9>vZMvXXwFTVUw^i z+aRT$DF)$cq^r>jBj?H%x?X2u9Y)E)bTDDB{orJ(HPhJs>lW9U8ugKi3G0VXpi2xs z!1Or#y5@G(zA$afyTNGih6oEqym{{zb)=SC4-qsdW$l^+d5nhU6xK@kw zl=G*J-wG?Uy~F`-6vlV0=IRru9j3^_4p-FG(ZRD#P1SYS9@`GYhr@>UP>v~s!eUQ~ zag~51f+xIuMVq0Nc66QgN9=LZTEZe0n_}uuK|GeL+5n2=ihJs|V07+Gd#VM&RicT+>`2Ss|=Gtm5yP&nuPG0Cz$>`~I%iyMte zP=3J^Vnjum(3ZrAg{6SKmn_G_cQ1S)<%Q$SP1v3VF_uSnjAWsE5Yv1lHJkLuN0B{k zqSZP2OVeKl+a|D8GYp6<+5;}y+|XYa3g;Eh)h5ooS7F`b?Y4m>@~#>lsjsM#_8*Ji zlXdw+UB0SdD5r{2Yqj(yA}qGQCtH%(W2A+}v^N{Q4sxzIDR#U}JpoybFOt=`9)yMC zf$c0na;3saQU_D%;}jz4L~=+DqthQs9(5Yy#!3Dd&uc1v=YA%=aP|sn~Uld63QtkPGM3^Q=MBJUQ*A5z%tI|@3ixGj=p^;M`&zb$|Pj3xH{H#4D5C0?CCi;g5M*J8Y90%DdV zCF}^?Nb`ooNaAVKmDL~`HvaY4cAJRHjJYIF7>cdwAu!~QM)3ASrp3SLBL=2>R!jDb zR0qG1BnPZ8bn@<$EP!pLsvRHw&`N|`k=nM`H{N`|#vUPUx8t?wof$_w77Yg<>Wp+{ zF^C3SLpIFk>3?@ra3xY|Mc^7(a#1dhN zlY40jEPk>z+n=%9NxOX=OLi;5?L?5V)ZCv*PKF^;Atd9exliQ5I}E_IG%O(g>|15b zHj1^;x=r*m*3s2h#j7T0SBx zg?Rt~^R5N-ne2Wk2yD4aL=mv~c$m6EwV5`3ine(`yUy=x$Fkj(V}n`E8V^dcj7~G7 z6<2t01J2%zo*VXB$Ghye*zKgPT8O!e9`26EYBLBNF@kiN8x3S^U#4LeGtZI(7SpWL zj4T{qN^9qZoG%taEoUtCs+(E1X=imMv?6hOv|X6id;63iN}}cgK#Z91i$5)-AgS8-p62o-@mK0uO;q%a~J)g(GjPTPf((3ih-ex zKQLHy?QSD2E?{S{PM{D5qofRRo!hr1tOZ$AG|5C%j0l5a;}gyB--$d=Dklmi#T|9b%kMZE{2k>k#1YevwG``x(sJ zZbCJizWm3@k;HL55JKA-OK4>l#wxX{cZsmeNn3k!&>@k? z_c%E^Il9TDWi=7YTw4rmOlB-1su`JEv)M$9$F<;(7JRv&mZY#WHrH+w0;oyY*_&iZ zZxa>UBbRwkL+G4D=ZI|G+jEW>hLfow@HX!Z!#f*D;rwvsw280a7&}AWZmHEeO@wS7 zS$0~ApU%rn_(!2=d@zK*b{T1LK{2f%>wn`FN2d2d7{Aew$clkA$WmC{xP(sQmrC$o z_Vh9iSVnX|*%jpM3qWKfHSmz&9BTC`Bn9ePqcTBGdc5YHWAr~9C0JirL?YMSjE6U2 z%vtP@-_}>zfOE3vKn~f0j3cKOLsh2lPHpvri_br9%1?3J{aqsLG->fPJJMb1xLp)r3k*w82Bnp1oB9C~=-U!?j)Fi+X4TMZo5=m=ZYbnj}_`X5` z%SSyaW`L;X)uK9jrvn$-!$V5srG4}rN%Eq}?~TJ|AEHueoQt?I>DnI*apYyVmQAkA zknVqwbnW{Q=k~V_QC9x!>4$w!V5dl%j13}W#k9QM6XM(iS#gsjMz#iFNvF0Ep>JZs zxyv?}0FKk2x5n&EBpF-S#UqfdIh&r^K}WDuhg(%!Mi^mjH(iS-18;{ZR{k1&VhJ(V>xrI0_oMhhzuXNz#k{szyT51`=Ty;&gXLuRGqY zVA-EGJ$yy_;R}vAtbn~>h~Vt*EBB>Wq_v?I%tzK$-%e5>M;8I{)eslQGTGi_&9lSzd?VM$)|ntz8Rgxf&-oBjN7H7&bSOYa9578T2IfA-`X ze)@|yByEytsOQW%^5PP?5wUZ@sY!C9+5mN{7U}DFSbiwuh?5YCK~q)?!i7ZGW3WFQ zi~Y@$?Je-zlR7neH?MmGJ=gDNOhim#W}IDV-j!=QcZOslCAYC`6MfLS6^21`v5L6< zBGLC#X(^^+0A?brL}G)uyPPvz0{N9wxpy&F_9rM%{!gLM6q8Hu-{qO;V`}H}Tp9t<$7p^N_sZM0z6!9Dic`nx`sE(Y#+e zjS2M;I{w*^BGs1y67U-T9LO64OIXOSNJe=Enz2cE`Q50Fr!{BMJBhH*6G?~V<3Wr` z%ettTuY~#}5Ed67%cajRY;q+C3u31rA++RQ#yXoG9jt)W!*v2#@rND=nZ~w-M4T_< zeP@va#lx(OWCC<41}vdjpJbRq<}1tj3#diLdcTF5cAjSGk$)W()C!ViX)MNtVkX^#-8 z-!s^)ti7I?j}kD>O{DeBMW<_3`(kPHzB3-MF+`*v6RC%F%Q1#GPDXM}1~XwvZ7+|g zN3(M!BtkMfmb~^hQISlBn1Ub`TtzbEa=azd6uslS$ncu6)BIp{D4dU8Fn!|F*dH0e zfp9OkA#IhgvlpFC?b4H2W(aFvA{WP6?Y2l2;z>kAp;_`-Bp>jDOq1phvKr4?z`EH0 zJ`ylwOh3o9x`rNT8NYR$n%vNVA&8O;8s&xN)uKSuwWj<-l3C{!y9;tSeRk`zyj3@bH6%~XS$2S>b1lekxU)&Q^#MM4bbnNaBCIO*1Op`J&`4H|cROJCK zGEN5XH6Tg#h*4EBVeu5wg~f=4hOqe2*cL(^pM}<+z?9(A4@Yl+CqP&qVA%vO|M$qfWZ8voHcjREPATWGow!nx+vsfZq&}L7yvHW++jz1 zF{;%@LHN43>|g-Ml}sL>r4mZtu#Ij-^@XiB@BN;tkD{({i z=99Q5fNQXAvcNbsU#1InaU(fQap4fH3mv?aFS1(fb}dt1=xY$)8Wqe#ztRJ6F9xAY zh9D8y4K>U#A`s~4$ytB;j^QD+>3XAemAs!U_f#@U*k93cUBr4zp)y%c9D7ffk(!f( z7b=yy5&8w*BQ64--cUANgpJ@eCQr)hB&C9PVUS-mee`39uwh<1;+I%geQ&^i2CD;c zJwRA=*t~1O%3ILckSyuNxJY=fEfE&u5~Q*umk_TG1)GlM+P+B=r0)%30T2$`V?QKn z$klxlY2Dc{EVNp{x}sA^L|6kK5`-&tIzA-3A-xeG{cv8e@X)$))X-z0yV2@Gl0Y+K zOfh~!3g;5U_HdzUpHehun4)Um96&IY zMIUa)9z}g9p1y&YN!b}#(%^SRWL@4!a_Vb5C@&kqgDa&F86&F{uh2tJQb4+pN?p+R z$4!0*ULOj}3esc$zFH>xkQm}loA?S@xi>?%8Zw!Lgo)f(2z10`n!f`d`eQ&KNpQZn z_izZehWZ9L< z)9*C{hG5%0(NjVw^$K#*L|D^LJ7u&O55Iwx)}kg+?N}I8PRdAWhUbCKA$_$OgI8c3 z#5tz>GH?TGEi3Z;(D2ChL&w9i((f;$I}IEn^vs|e_(hDbWeE#{;TXG$umA{)^TV<6 zJ)}k6M(2?yIpV_7(|rV~LNX*d?kn>OCv?*>JNFSdcMuWR0)5ab0XtW4^3v&^(_?CS zVIZ*qi;Hddgj5ikv$IU8n-+Kp1)BQDpk>}rybE<|w4bg-(m+3r*dq4Z-A~%?1kT?Q zEh6iR5k>~okQJ?xAuBnbNT0&mnAi0mJLP(G8^=w?&D&iq4$cq9hW13piUh%%NRF&T z8aI{-vMUOY{RrVVWX^C-SkPB&BTc9!4Fi1!oIZ5(=_aoHvXu%w#fEn!1lRgR)Lx=! zI1;$}0W;Z7y+HLQX!bU0mucS+6NWC5wTIEN1<79Lm-1JS&^%c z$?D^s5T^S!HILTGf)&I~-V+Zmw5M^JrXBOI&}o>8LUL2LfY+~~N{qBNvYUtwS>pBA z5@LVzT{9;CKIGop10pV2To6~$wNWH3iMXv-;QICV0qG6+){_tIB&uO(Oampv(Fh** zWNC3T@Pyl;c}O_LLuqo_k8gLsfeQSdq+SK0sy6j0`c3B%G2LnL**4xmB{XqM!OJWGaz9?H=_&(3zn7Cei0r!lffi}uRvI9Sz%rF7GVL9 z2)_%~o^IUiOcc!_A_<9}5$LM3tHM;)xJE*^(=^m3yAZx5+Ks6tyiEz4TBZuTyZ9>6Ud6X z)HrmNqGfq#RwC_`ITmawLwDVJfV_bupXt5g;zE zCo*G!QWv2c5CO;XnU`S&ui+19k#T)+t*~-efxd+2si{NfL;Tnr(`qDEAVg!fNWhMd zMgJXP77(K0Gx;LS_a%tECx9Db$z`o7mct4VS$`7gRp(ZeRFZ`V^MS$ucp*KADMSNM zJVFG%E-q;KG<2gq9!NjXWKebdGR>bx_ZlPv%x54P-6!ERwcdgU7?FYg64zQ%k{rM2 z(BfK|ZI)${`UTC)jh>s@oNE5lCp#O=8`b;!9X)}RnqC|~C1ngG!9#WmKGYHrO^L1Q zB6=T&?s#N!xtWqYFdGd3;5+_!E&kC&y&7Xpz{0}zQmesGw%G0MJ6ED^ zWLFRu*ZRZLCVZ<{D&5+UPJ0#M@y!9RD26#I0>9g z?GFJ7u#9Ux47qNCi_|aPTYwK3CB`h>!$=YpHvz=LA1d#SR-L(w+lF07M0ZzM=CNHH818@j@8#`PGeWVOI=$Ae}1pK&(FNH@j;5EdOC>Gy7? zWifHX5QR>bxc1(f1N9Nsm`yzA!K=f3`xBW=#;Z6kYxjdnD&6qDZ?A0sG>#9eYj)U4 z2n_s2dW~jvif*Rg;yUA%e)-#23D4U85TxlaMW;NPv_MW!^TUygkgU88>w0qDb_YOG zoL5goMI8maTHFj77FJ|Ps}d>CqbDubZB>KgLl^r#<6YmqP85)blR&FPbB~KY-$D0NJ}w3E#k3aO{Pym%f4<5 zS-Pk9YMHvvR$bZe*c;_`y9yu=S_D%f8R&)ZSE4mM)Y2WobZ`{<{h1Dv)t6>V^A2LX z(iPXBzo}U81P?jB`<2GgbG4K5$Zx1%O3%u=!`s$QMgr=9HEH?Um{Px6EkoLP#F%CD zLvl8Moh6;y6^16Kj16JYV9DZcAS+a1GQeBRmvC(r7r_$C?E?1~PIlx^%^fmvp zYvmmyxcVrFTcGy4*%f>(OMm)*hqvDfnE<&E@6sn2OI!Vsb6P)4P3v8_gK?Q6H$@nP zVi1}k*|`XxV_N>vBg+Pd2Ktbc@ev3Ifdp{IwwKDa1vqc?N>B4(sLrt=bB_sycf1#r zk0a6{q0s`T5oLEp6*3ie5BY;+K!T_~5%qam_bVoXhmvCe>`zv8d$7$k?k7Ficjbd^ zNBiksZ^jb-Gg7hSqtFFq-1N%zqs*DzEC;bj!m2J{X3yK$W;r-=+uX$wWA)T78ecUmy$-0)v=2M4~jpDP{;yjDe77 zZhHeXLBVTZXz$2;YK8lV#k)~sD0YO5LG=41;h!$VO&_#|N8WmwT zy3j?<`QM1V50M3V+@@Cnd>0<1y_sg+ldvsbJoM5Mx}IiWK)T&NM-}FRyx+TAQdE;vA84+$%T0+bknRHs*3i~ zo(lYOIxL#^6!+I6S_$x+c>R=BZYCPF~JSE`!qLtIini;m<3!~CFiSPug=`j* z18H(#yO?~jrBZwou5o{zCo2{HV{^;1B9l9*R4yCxce!c)BGX6C&8=D0*J#4)NDJTu zla=-uFN%lkBI4Xx=-~j8M%Fh5vMvI}8RieH?3pYPAuSLqG_UuBHODb6;o~7L@p28) zqWqt}aOMfEXAABuOgd7)U~Bq&iz;dVvG|;<%OC6tm5-1gwN5LAIQS^6vyldCsD3ej z*){1*_C{p0kQ``s9gjTz-&{Fi_&NP0ttI5K`7wvR$uAqp2Et}sS9S*K(O+j@I?EE< z?9dDJNLma0@FW@V7Zl!>+vTYuCt4W{euknaa~G>S;1u~@YlZ(mX3HS_iEmrui0q_=oM*` zu!EkV+cbKiwgq7cDN%&PNZ`LY-mK5S38R}vGG05_b?MHc>R8WP^36ucKDSxB6%@3D zUG3(r@7B`Bn^z?TGFmSWIqU^2>*7v;#J~nEm)~?yFagVYORgl2&rNs!o}Xu;2_20i z7J#%^?lv??(b@LB-!%N|ir!k=K9kn|&{*friP+TII9bARQ*f}F{Y{f$d};}3iGv~+mTkG&X0+UF*71wQ>%Tr$e6cT0 zc@mbm?<*YoUDeofdAlhQjL_Nj+W6i&Wql?sV2L|x?xKgO8957sQ;^6{!MoSe8;|O! z+NZdVJ&&rcU9e(AyVLLrA_67TTbHrsVC(T1i=iiy*!@2LiheYmNx#b`27tCy<21;G z{?rdxhwI4v*de1DhLO(Z!?8TfHu7@0{M*^#^Y637Halzwz4x4YfEN7gg{Vf+RERY> z68)3zFe?doisRWj_5vDPCu-aS6a+Pw+n(d828O5rN(a~E#=aHXd~4;(*Ky#|?EB4q zJ9^8KOOlx^mTl&aR^P|xiBqOjiQHR&u&~n0bFXaOoNl@aWG%R!Dp$&3n97_zE`QK; z5Ec(ayi1Sh7inQZ(ez-PC|~Q_NE(91N}nEOUa^)k31>4t5fRwZMGTx;Mt9OwY z09m($JU}+9(+^&rE!|^aG5Z>01#Cgh!MQU&bbzYdQwc-E4V`IpR?Ey=IyXP~_Z`WE zev=lkWQ?3LJ`ah45T$N0aXSN5ttvw1WwbnWkBkMdw^2(*kkEDt1^KK4+7nTBC%*Bo z!di!m0}&C(>cAPdb(I-a5zR5I+tSuneqM7FY`?!DQQecthJG@7+A%S~+aljXi4Sogyt@!z#$Uq#gEP zS(&;7`m(sx%@3brBDmNVuweWBp&kIlg#4+sihYeGM${a9z>ifwUsY;AX%w16cp3V$=4 z&B8KU+L)EXMNd~+biJ1DO=ym=9#$IDftoPbp0$Lf>`l)GGj^|5Et&_zdB8_$rVpZq zXcVm?kC`7xJ0F0)cJw(*q->%mok+?u7=Hq3ZJXXs)^ga|jbf@XT981&ZHVV+o-}JM z@?MtN@D+giR!r#}R$$th%55ZauS|#jt7{rZuI_SDFTKdviBQyD8ubPpmm zb{KK%U|M_#UD8oZ>#_9t0Ge+i@{v;7mVlOg(Ll&FZ>FY&!cw|PQkjhe{*UW8I1kG* z<%Jk~gMgwIGj+OsAvjsJ$h+BX9AQh=Dq<(gB1w1@xA;P?&3g?SiQIpszr&@Mes?l4 zkZOPJj0bXSeh-~`$iZXvi%S>l-dbI2Nuh0dL(@H-L3csRx>U>L)ViN9>xsC#jI=;i zD$HLxR;%#mLmYXk8CFPZ^L9YIEycdtlJ3tVMCW_8<34AtuLT@~>S#8EWbwFJK(Gmk zo&f6@&sN88sBOl9H?s0*x*ZiB6zcRip@PGikcX4?JdB8RD4qWpI-{XP$aJ#U2k~-{ zF&bhXhoe}r4N%k)&oUIZ6r%T(FiTjVKzFmH)IN7{=$`?nEM9C3!AB1r?SzF+wIc#s zU5j~GTb9mF?+Su|H#9o!VR*Cq>YZ0xr_JmTW0#T^u*AhY#u=b{_*E2?SIJ8t>ytvM zV^qx*lEvlE(hdEymSQgB&XQUZ?@r)+(=EjzjFm@oUyNabNXgg~vU-QXve5hpPl}8! zkpytwu#l$a*XVa})i4x{1Ml!xtzbjhrhTQoMYV1i%;t1mI7i_Sn&CD&2!4MTEyw4r zbS@RtZkYYJXj|Vns#lIrN2(F>gO#K^+AJW@u}&zsnep+*Dp}!NnFWM`j0kEATk^FJ z2$40OH9)MD4*YPJO#LmtW#vPwb8DBzs~9ZhyI=i1-(DoZl2VMg;N$6Z>ULe|G^8z7 z%7q0ZCyt!D@JspH2*oZZEr90UHuG^oIV!>~ZgFHTs*9zkRg>oB+(EgVtqzNbt!zR0 zQo7A6=q9Zu4s1X-b@ETP(mG#pQLeOk=Em6v1VDsB6tdcDsklO`6dl9mx+|-hcL4Ww zbLjA)QmauFnLfk|%Lmehrkklp(Rq;8)H!s3Oj^upF*+*#$>14F2{ULqOfnCo#QO)F z8cQ3{ex>Nml^wJXSs=ci&SeAbyn!>lmX5NXYq2quqLXvAi4N9I8mcYi<;w72LM16n zW*>H=xj5*ajd@J*+3G>e*CM%HEf8ZJD)f(&;j2_wd2grJFd`e24{7sqQ}D1VAx4Ib zj|qiy0AxVCJ4>cs0Yg{<;;oprg*xDxEunqyPH$QDE2L(z9Oy~=T5G#->eqAswyt#% zJw#(16_hI3mWEVk+r?_*JwIJI{YGc%ylbnZ-z}sCa8Z*l(uhjr)W_&2e-trYbls^T zFx~Z#tdk>Q)vRU0H8O8K$4fA@pzT;EE$X0YL{?3oBS6r>sX_m+2AX1kJeY_S5yi!u zy_41wKsO%qhlod16HNOiEoMl36w;$vOH7y)6t5TQ0V8aN@Sq3jWg;bdO}s>_U&i1i z<{{(RjZ%;3>}!k~*RH3?#LnQEBSvvV2$FTJ7XNfp6O!bzI3zzhcUq-*_^Y`Cj#yUd zeA5BVV~1-``hUzpF=nJbZUCX z!P0^Z=6zeRY_^*X5a~ZW#pH2lYHlO#X8~(xfhOAmIxy-M1R#2W-WD}Q>Y7WuwZA7Dj)dK2@cnb^G?cKgD!dyCUmSj8R4&_L$=#VYXS1s zjs+XVDCN|E$+x=7pFxC}N+cbJ9>8{}CXfddgGGc35qa4Ft%r|oUNawqm}I8$02N8+ z2IR&eZ8Hu)n?eHMjfnK|RXQe`@t0_RX~@M&Aruib^{O4rV|%c_4fB?Wm>*~rY}Q*x z%(Y4K%b{75JqVKVR#^M%PdkAl;c)oGN&L8VOF0q_cKsopV@5O%KlD{kN(ULOE+UoI5$2_S-OshqZS=Xy&N#-IyC$E4C0RI+M?%t^Ad#zsxapNKI5h`Cl~14 zDw#&vlHX(;vV~K=6K zBP@Av9KY$nEg&qcYBDhZ%Xs=L^WIMC?$Qn?^G=YpGYA!&tUx$vXjUA{?P~wU!jTg$ zLzDoYpFZy91CHFhYApy0;#P7)ob@by!tpPB;cr$>>{Zg%!jdNGfb1%Z21{u(Tk#?T z+msKM{$@#ufI@7)P%xymy$5F(V5d|1)9vO_vWy((VW# zbUa&49waNHEI{yhcAEsmU?5~~U`hFo%7pi)G7Z19AuOf!rVZL3hbFo=m0Pg3n8M=P z#DJ<4Ly#7__0V;#6TiD)e(AtVZjm-z1@dmlG=B@46`!G7;cUUG;&WIjT{Y@oKe*+Y z>EAkkxGLSt3Y{TqXq3Frx%}lD{$nD(3)}58q^)Duz-m@FQU_;;w~9bVR0)-8=AxnSzum|tUtaH(U^0U<0*Uu$3U8?%5=V1i}*%JFcn zKt|jAzY$2q8p0p4;_1&I^$I5opP^ktDSm=^gtKrw92c276TTwH>xdQKd0#re83xqrlI9nC#dLxaLw^cZu(eEVYTa|9fYt%)0#hy0Fr}zDBLbe_p@-Yuza%;MJJt|vFymt zd=H(=9(kR~KZokRU*!0I^pz`yjU2xEH|ohr3%$)BbAn=uaS{ zg;*;h1BePVlsoQudfK-p_C#FctX==0vFy#ouFsJ3g#xGUZ9f2N1|Mz#83hyg6kCKQAIOV)zUgm`L4a1(c1u#jAyZ#I38V9i@ zto(%)@|&6Kxv;ioC1IC zqsFuBjM-Oy2+oy_3W!oZWZC8u);p;`S1oQ#Tyl-IFEtJ~!&;%>1=GF&4}>NAsB-6-q)S&K-fE%(g}iTwn;o&(I%hUZW-Yts zVmU%xwWyP}4rASn#UtzH3;Au370Cf)w1#s1slOk6{adV)6z|e&7{?4UfR3FC{S%0+ zD_KJ>k^>-aLvDQ8%ls4JFq&T_3!~ z9kub9ZG!&20rtW1k6-YG&aUAn+|d>4X{>3;1n!q43qzlzgJRjn2|+r0VV=9^zc~8L znT*TRc~F5=P%NF1q>GoHI6ux8*HH&xvr@=t1*NXp=<0&yCNWUSe|vk0f!@M~xF8kI zn0@)Xy!O|=rX6y$l5!!kyyKyFw@EILwk7ks8nboM)=@<=E43y!>LVblt#h@B0mMly z`LMjNfK-6PWmC4z?VCQ^+jTU3wZ(@CVaL8$9!J_s}75ekQl5kvS*MC zp@j8G|AQQUnk1V%;&@HyTVu_1h1UlRJ&L?Kv z`j@IyAma8l)oX!}7^e6>91jtOA*#Gab3TKJYPLZIny37GtaY!GwvHZ=jIVN$Is;ZV zNJVLX(YYncXr)``n;0$;7fn9M+tQ^*55SUFzoW6mMU(+fn?090Y3q1HB#B%0G7l4x#>4ty9W0^FwJ`d;nLgd! zTxh@Fjug`uAR|_EGuM%Y9bY9Zq6uB?e}$ulkI=p#cb&9#^a@qt9(VHr72({en})m- z9=!=q0kFF0{SYl%Nd$g0;}@RyeE+4A<8Ecr**wJW!eD@=Qj3MbFwXTjCEd4FgnNM@ zD>SS<9tc~=Z-!$A3l7;~^>3ZDb@UoW4fPcstuxrYgnvAoP~}UZ9@Ft)tf%weBf1gnsMo|I3il zU31{+cZ7c6F%=#`M9VS9k?1C42v{j8nkrhSj5FnW4NS_R@m~Yip{#Caw`s`1iIcq0 zao?zN$63cO`25;h!UAK>2nyD|5ml2v84r(tKDW2 zF!9&8LuC{p&`+Vq0%EpU)DvNWI%(_ZH(+&BElebd1!oQ86BONsH*>9y)SMqdi^si= zuu*S>)nsx1f%L3m;!#Ketb=|tP&|Su9)!tlt!ymK zEIt7LkpRfT-24M@DM(^xZTq3IG`IXeU#N=hJ7MyVAOyur(BHRjh5q9|=voN$45s)m zmzg!xkHZhJ^#>+z?BEO4g$VhF9;)H=BkTurF8rVPAJW|N;(zsgF#o0g&-!0^mtV*( zglR7Y{+CpM%)#8|Qi%8ic>$)m5Tf`1{Ew{8!1Ncw{~!FHi;cPUNB@6W938wcQ7v+#FrAEFCcN=PV-&Edw)X^2dlX`eh<30^Gx@3=PX|9Jj^UK2`!! z(IR}=ntE6PV)2r3P$qHV^r|$sqpkhNYJo1_Kq~&N7o0PsCd-o23Xb`aZDAka7FKJZ zFBo47Rk35?pE{9Hx=vqD zI=TYBVCGVc{uoS~Y2}!2Zq1q0^wr8K$=0pJ#C$`NV_Ak*D*C%ZWC?kCuZL$0D5LL> zBY4Y68-??^CkHe|)`%d>t8Pd+2421@2{_KCF!rPw74#i7YO9x?AJ#XFE3!o>l?(grR{%Jeds@0 z3g1WfYWB(kd3tvX=9PMcJVD;eb*92yXqxH<2a16B0*% z)OR;(i+AhhQz1uVIn7I}P%KXO*-%g*C>e1Pb?>#aymh7ZQIZk2>z(Ad1hXqN+X`+& zty&yW)iQFGff^>du_ACsL=~KBvmA<4Xl$SDDib{-rpCwKg+%tKTN}OnnaXC9>0F-& z(iHuHo)IG;N>b=8c-kEP)?vI(D77HR0cG!^&X0CzI}*BO?mk#yb(sVUf~A&W5Z@A^P;)~fzuXXPeZAV5TgND31W zrHm6iu1SQIi5gl0N?Cw*7{yW(@{d2?H=@IeI8fq<6`++D8Nhs-2ZOR(slmT?6dJlh zgDOVd7^xZPeq_st(8P690*MgRs;><*@cY6eLIijsfmc~`+6_=L^C@raV2~T{B-E6$ zh^yPN6v4D~JZuJBu3}ajDo}cBRq$)W&s6u+sxasRWD^;48;>;?#dCxZeTYb8iqGeG z`QIbP`CTNKsMk2Ii5Cv=A!$Y4IN4#!r9Tc6o*fWtp*7=3d9&4WMF6{=Na+alv@?T| zLsTYoj?xqvxvsjM;KDrRterEZlBTn;Cb!Jf$A zlK44(Q9kN9ZlLrtpil=5$C8jUio^FMES9mBSEbkh6zW_HW!_F+elBQPxIRse z1|UaN!$zIR`G?%3g$r~Rsfptf0j9)n()>2bW)_@ zxlk7MdDwN;Q(}8})>_!2j!0oYQHM?30>6cw3e?c7yAA!H>tkYY^_4$PtC-VfK>4%p2!2g5}5w1;_b&kle0b`8*x$&;EL38u(=` z19Z$>lWo78?(XbUq~T{0r3Kexzu&K3K6DBRd#pA^{Vk2z8|0U6X3a;+g2Jr$+6TXT zC-BDqdv0#`!g*EA*;7T-=ae^4aT7UOMh1_)QIT%<+S~D;d%N;!3v6g;|IcXNyrxHi zknm2YWG7e6_i1n&8zMhm$k6tn5=60~K9K8=S9 zjvR40m|8CGDMyNoF7$hkW{lAf)V)Bn@_~N7v)nwZ<*8sWocPxaI5UDf=|p5K;jF3S z+&B(;1}D*mNLTO*oJjvFg_@>s>khL@{)pSD8iHNB`<^*Y^N{!T!uZIpe}ZHH0bnEc z_GoXv`Ww$GLWnt2^bb_-UFLU>doyko?ZqixSSVw0l_B8}S~~v2xxM}-kHkByM+}xs z;n?x%eX6>h7CR-O|5Cd1&hkEe#U z%ja-;;5?KU1<0AtTNc7QD(U_UM}>w%g@AY_KzG8j(QZ}$+1&C#I)*~6=qb7=2h`K( zYwVB?nwiLLp}^J1>fPI2V@8D0a^4t_Q5Gb zH9xs!NUJt}Yp<9{|68mmq1G#L=&;C6{(_p{Kw(uqE-2w3l8IZSc1wVoZvYwPsZ7Og zg3Ad;knJrYOqAv;6)cgV!}cOAB6BobR&nunnCKshsorL?VoK!6HpVE@HRpV7 zGaREq?y+-C{CCe3ISOK+hX@FDp#(8ZzFO4MjgYh6>2uB6de3^VhxNTBo&gm*)=n-4h3AgIU?PtVX)oRii@Sug0uXAisni^BK@9xb)q zU{s|P&Mbf-H`&JG>fdHX&EMx=MiWznS027IN0AtJ_Ql7l-HEML9v|0HdLaJrzjerS)B&s&Iz%CIWvc|=9t)Surr zZcrZ?Sb=uX06ixwI_G{QOXwc(XZkgs`E`T*=kroIQO_`nsf}y&?^j)8lc;}^q~xoQ7jytR9L`keSqEG5M8(_eu{n}kYrUTW1iYX2 zk}5DVFP#LsuAlGgdi9=tKjY0kOr(W8@NmywB@kXB1imGCOs*Gld6=qoyr1#7LlRyl z2)&#Fy#p_ydw$L=01-J~GE;w({E87&OX=6OixVY~==mNNOAWh^C>>k0fCUAL?fPc( zmLz&sNZT_uCAxme88ox|tNP(`ySGW&8yG32WVrbkCBpal<>mRsPZVW$Wl;EX*C(Js zN~}V__uh9E*AP)F3epU`Sb5yfu4BV_yuZo}>0I`mzSv&g#x_K}+iZV&_4SC=2f+Nz z0G0?4;?F;YnoA#M@O1C?)%R4*HR-eIpy5)w`ZRlMU!+gX%jn>T(CzQBsO3e}95zBf z?z$=Q5}gFabV!ft!Fmmt1bHzYU7SePFyNiuJVd>E6?hx>`4z?XVr*O@-L-F6|I6I1 zeOI2jaVRalEbfKd+M^V*5r8)o-YP}z^s=a za zOV7@?-S>)7#(=wzFXn6F=+E-(FlVoUA;JaK>Hw1ibR*b;nm1{jc9ZG=G(4T>q#Z-yLK%UvH=f)p8(4!46!$2^D^sVW;P;ig%3n0;D9lo~}9tX^R&ef?0#fffBsepXVVi(x7b z&Lab_@R+#8m@OLRV20AVwF!ubFVl1XHFhovr>d!U87?y%WwTL?#@)D3gb1*zBg>G~ za#|TCG2{MyF@uRp6+c4pW(LcEPK;4^fg%yGi@P?vK=kMZ=n(qmr@zS@Inh7QH`@@J zaj#xV!7s>t3+6a-W_yr305fHFKse5+B5e4v&c-XDhxI`hb0+VXwzf82-tjh|fwHg# zs&E~y!~imnnW!#TG4$bS&!**}C&u(=efY8LJ?R@~r+>jc-^_*zH`Ayuowm-!gbCGI zB_8JO_KT}&ftJVDWu?-_06FcSmd9M0=kOYHRIlsa1-(fpB6t%NG*Hwe%DuNtbU z4;n{2ITkaK&`Tcph zT=xT2*l_$K(y{2wp5ei*JJvrKQjMNfNbv5LD^+P~vA`k=`waio+uc$jiRto(E{#?OnhuReM>&8-$w5s;~~BtVBe?;r}UTGrQ#H zf%a*=<(Wn%`HWM%J3=m94B&81E#KY2oZeEdtzWe+G2b&O*aSD=n8C-!h^UP^GzbZt zB@haGv!*TrE5_#pdHE=b?6tf+6^dbvbrofVJlljn$)YbGs?+frV$tS~&I_q;2aH39 zKL4<)K+7Jsr68s*8MFN5*}E!5%uTHKr@xeKKN7r8MvP6;vOAE-qvPsCYO9Gxhpew3 z7moC3SIY@Eza$?kb@zscoy5@J5+xchDNmhGcoSno1x#NNPV?YE*&0qZ?zCgBa>0Bk zyqoeXdfE9@i3*bq-XuKaKUqI$EO10v{cVNcI1WrZp9c+(8phCY7(toGf#$tF7?)Sn$l z6OcMk7en&8Htk#%^7!p8B26mR;ObZ<&8?byEWE@duoPw4TE^39e z05isFGvlk5jX*$p^ea3vN16ECzeG72A7fvp84Y>gIt%j5z#ZYnDTK432T=I8(OLQR z0Fe|=JbxLbE;UM;(_J%0oEfN(`994C=#_KV2$a^|dA3$&D+xg8(=5N~`B{6+&emoZ z!T6+|;Dk&0Ysb2Me$!G!x>OGOT0PM+af4&9hiE+dRo7U0&!fGT-Enm#`spV&ey8bc~s`-2kTrDc_v?HRY1v$Z$#BoS^P!yH?&|(rW6cBsB z!o%>-%E8%+lNJxbROz0tk+ACCd;_s6He0&Fw6hVVPL*G~GN^MvR|#R#P(hNM4k8P! zPwb2%LDkbpzmx$Niaq}@4r5Z1=~_%9b|dhk61qXTVQR|Cp{Gy?bQ?%5eD^ijl_Kk0 z@DO?~QCvO=3)Y)@ei!jU9UC(%&G;fAT-E*KG)G5JL>D`x&R_Kz<509FHW_?|! z6T=A1a(P6afa9k$uxuZ*4)4c4D~Yf8B%i)5vB18sqh8lCz;q)&-S^9bQJcVv-G-bL zQHqj6Yl_xH`l(qn!niNbd{y|qF)Aj!z=3j~BR$cLhW?7432P!sL4hMa(pzx$OU3)*3cF2 z5DLb{inOn^X)jojt(CzwuAVRw*4V=kOWE8*BxvW~rP zaqZW5QMoc~r+}*p4eEG%CgxCx$`1{cl*jkC-XK57h?Q^LI3w`BId(DFmPggsWgyv< z3!AoPwka{!tx;SO|H@hI31bFyoUK?tAUSD(8m*z@V59wY`t*DpOuJx;->H@Ms8ksa z8?x1Tt07;kbJa5}%^TBf8w8DI>HYK`McVVG)C_ipFZwB`yGw)wHC*O?k!(p$qWE2M zCJSC^Y>JquTMV_%m-qIGiFh(xiulo^%zh4vZGbv;(DGGJudkCPN;$=#4mM)-t;SOm zoku3{chFr#Yeyx4e&0)KbrRv%v?$JW;@88+^@70X3&O`A^q`IYkW*X-n+UkfENm!l zg5oA_#15j*)dr%D@F^fGz&=;JyvTN`d`J1X9IQCT#jCq1e=1rV=7{+xttM|Ih*22ziL4lCK)18s;Xr(WxjSm}~9ZdPxgrR4(r*!$#!l#wR2pI5nj}`s%%&OO$eOD8NlxK-9VJ^Gb_ltoK?} zUT&R=#%Xx|EqvkJ>1x0eCMH}aD}pUZ8yV6A2|B73{9ehBkVcrA94-ZdP9;D~R`!}t zR0x5NP-y@~gH^Vyq4V)bL2bt2>GZ|%sL^_TpS<8qBJoW6=HPtXX|(})9$!5)TZD#f z15UpPRke5+@Hd{Xx=nZUP=-)yVOjdV0|{8XMjtx2UtR*20Z3A;s=w!4Zji$$C-6Xs z^a;U8-&oABm0+*=Qo960g4w0@sio?gzG~8^r|Y;(z*JXhQffrNF~^~7mn}Na<{Ebg zylS8n8mwiQ~37ge|4EgOqGuspJu>05YiwC_Nq%( zR-oj!16<_Q2rx(dHbT8MPza6}Bk|JdBcju_dcZFU@Kl&`*_5(rYcaXq!)0#U*GZHx zzTM9?a|)E_xSq6kE^90;tzD_|KZ6d3ty1D;yTc?YO`^?jNQzd;f`;*0pdTpaV&~!E z5oSB`1o@B*$yUPekw0L}2>!-H3GU{uCF<{iVw0uxX)9;r7NSbS3pI@%Jf^C~8yx-~ zGZ|O;W5=PGnmsMDZKs3#%>+3|j8bmcR8y%z+$UPTY;$yoZIcVOwn`y81a04*FKf zlzulB%v?i_|3*_vj5fJ-5_ZEp%a6rozfMzIwrv!fB7?N{wYivkSvLHXxSBB4a12$0L zIpm1Gn4mHoTjLjAWm>`uif*QA=bSXKbcCzfR1drSKtp|Af8 zcY;`h7zo)i;0P%vzeUM`<9$a#(ZSz^6Nz$(BuIj zs_{AQjS8sXtG<)S>&QWKavDGpYf|JB;rfu#_Q{uk=Z?dZFb6QmDM6tJM4cXzV70(O zTFgN=O#B7mcLWZ#@@LK(CN_*hGV@CxB0cNAPwVAoXf!>KG4q0Xy&id8PWE};;Kz>Y z>t~BNte=6Y;H-RIAv~2@a{_$yk7{FiEOeGYym57{78#{#hYR zyhU=K2|cToe2_ipSRFU(Y|-NE`%1SFenzLSo_j)4t~y2L?%Go3F()Ude+YWKc2Y+A zW>3K3^w67?^~~QxYlCen<1XDkSbT&TG+k9B!jRb{da_On#Oo<9>*4K?=@@;Lh2k;Ht0!4#Fl zQVNS3Mf-||glCnNpDI}SSH7f&wt#l3l$&Ld)s_8)r#?b_)QS;d7+lChRd*ykH@mVI z*2z(iQSsF0b09LbWBw4WnRQd+@RH`bZ1Gg`xk$bz_#o=8KzIh@WRR#`o;po!EqG~{ z!MjJSQcOw;;U5~+EN#NQ2$5}d0X1m%*voULqGM%i;bpcf*d7jMi!Z{E4f;iXVil=6pO7_nzDZi3yz|_$O5%D-_IT7yv)Cr^x93})k_%3uZae>Cz1b7 zv3fvY77Q?4-`q@a;oV$XdVJPo)Aad)fc3-l8>hlCSE7Jj02n(3UPUj*NMY%)J zi$%jyNq=aF!)L=w34oO?CqgD}vU1a8!K_dABt6i%7#i`2^oOc(M*1UTvhBZb^#FiQV3LSsq+RkWV!|7eV>2UB3Zj&xOS&^{}n8dY-yo#&JWK`%3 z5FbIob`R60ynPBmWGtI=AKvtK^q;weqqO1{knVa?{j{C8DHMf|!)F9YURW--SNkcZ zNH4Rvk)D&2durk;=PzyQr?2d{-yW7cd$Vwmpl@=bl;2yfw!*$`JxkpwCy!jJGO8U zW7n)OMZw9FXK0sDG_JMB$=UhmWjaFM&|<$kebV?~++Fb3otk>b_eClTy4M!nzc-g1 z8toMP>E(kELQlemOw3LO7v3vv7woz0&d{-C!*pvBEO6Cp#||7MWZc(7=c4Ux3+~gI zx>R(AA-Wv>qGi&Df`Y^U?=HZ>PFJFN#jY!H-{rv-7--x7iQDgKhw&deYnovjyeQkN zn4H}yPsMxB!wW0|`m4orTMW39TDQj2kvfc+n=E>yDb5YDJ^d8k;_%DGKb`EF6m^hM zkSa<^M35nd)KB5a za=$UVR#IlH?YU;r{%_GE;^1oilBVL9>d8(5uXPT|%IMExf!>*y?(NH$tp%r3j+cIA zETbty##d>&r{RrHPtKD;K5j_WUiJ?q0@!hX|El|6#J|8Rs`d?Ryt(~@RS`BU7(gCj zd4vbFOynk55hA>y5rlauy=xnkjg@JWG(D`}`3C$M^y?lagxxC#n75{m9sJ&4eTk+R z*sQ_rdc3iE&^HzfZ1lHa=L`9^i+VG_C85X9C}fV0A zsePwC0m)eR117f*@3qu}YVPyL1m7Wl zB&=~6e#O?FmUW9jpN%YtIWXLmDhTOPRnY%oVX`2ZF8{O8=*7`3FuX5DQ|5W_H$OY# z>}}QV_O=m&l#ylY#r@+Xmh^%ExWDGf6}Bh&_u$cw>3n)x=ckg2+8lckOq-4{9Whco zr|mv^whd!QB*&4MY_={jad4G4D`>MP2^>GDq3Wz;Ko%dL zX!6#2juADj6IrfK?^4cLTVJxwFU{a#^gM@V6wmy#hj>SxWJIv$U z4WXt|v&^#{b;+W!nN++b8h?C4<5=0_9+pL33zsi&wnr9;Uz4$RA{ErVop${L=D-=o zDT%+)jja6~O5h=~S|15B-FZb3`Zw8__>1E`I9$lI)KJu~v!$o|GN8>s5{up}Rn=i^ zcDQIPSqxahjWq8hRo-qfTSYHbqKujw@nsLED;(;D|4!+-CdfUsJbVya7Y-x+uF^V& z{lq#z>9)n)FgT}%7RwA!j5*Df5a?|CRk}AMGdx&fyGqv@m?zh~|4j=?Bu^gEuVtL! zOWi<Qox+9r?Pei=2%cwx4mj=`%hxC6D!jM)7$VMf-7a^|skjUQWd+O;LdO&G(A^2VU tb68rh4voRxzIe})cyar7%5_>kwNwHFQJk2{r6G7^g7wW7wM{}081Mmzuj literal 0 HcmV?d00001 diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py new file mode 100644 index 0000000000..dd58541aae --- /dev/null +++ b/apps/shark_studio/web/ui/outputgallery.py @@ -0,0 +1,416 @@ +import glob +import gradio as gr +import os +import subprocess +import sys +from PIL import Image + +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from apps.shark_studio.api.utils import ( + get_generated_imgs_path, + get_generated_imgs_todays_subdir, +) +from apps.shark_studio.web.ui.utils import nodlogo_loc +from apps.shark_studio.web.utils.metadata import displayable_metadata + +# -- Functions for file, directory and image info querying + +output_dir = get_generated_imgs_path() + + +def outputgallery_filenames(subdir) -> list[str]: + new_dir_path = os.path.join(output_dir, subdir) + if os.path.exists(new_dir_path): + filenames = [ + glob.glob(new_dir_path + "/" + ext) + for ext in ("*.png", "*.jpg", "*.jpeg") + ] + + return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True) + else: + return [] + + +def output_subdirs() -> list[str]: + # Gets a list of subdirectories of output_dir and below, as relative paths. + relative_paths = [ + os.path.relpath(entry[0], output_dir) + for entry in os.walk( + output_dir, followlinks=cmd_opts.output_gallery_followlinks + ) + ] + + # It is less confusing to always including the subdir that will take any + # images generated today even if it doesn't exist yet + if get_generated_imgs_todays_subdir() not in relative_paths: + relative_paths.append(get_generated_imgs_todays_subdir()) + + # sort subdirectories so that the date named ones we probably + # created in this or previous sessions come first, sorted with the most + # recent first. Other subdirs are listed after. + generated_paths = sorted( + [path for path in relative_paths if path.isnumeric()], reverse=True + ) + result_paths = generated_paths + sorted( + [ + path + for path in relative_paths + if (not path.isnumeric()) and path != "." + ] + ) + + return result_paths + + +# --- Define UI layout for Gradio + +with gr.Blocks() as outputgallery_element: + nod_logo = Image.open(nodlogo_loc) + + with gr.Row(elem_id="outputgallery_gallery"): + # needed to workaround gradio issue: + # https://github.com/gradio-app/gradio/issues/2907 + dev_null = gr.Textbox("", visible=False) + + gallery_files = gr.State(value=[]) + subdirectory_paths = gr.State(value=[]) + + with gr.Column(scale=6): + logo = gr.Image( + label="Getting subdirectories...", + value=nod_logo, + interactive=False, + visible=True, + show_label=True, + elem_id="top_logo", + elem_classes="logo_centered", + show_download_button=False, + ) + + gallery = gr.Gallery( + label="", + value=gallery_files.value, + visible=False, + show_label=True, + columns=4, + ) + + with gr.Column(scale=4): + with gr.Group(): + with gr.Row(): + with gr.Column( + scale=15, + min_width=160, + elem_id="output_subdir_container", + ): + subdirectories = gr.Dropdown( + label=f"Subdirectories of {output_dir}", + type="value", + choices=subdirectory_paths.value, + value="", + interactive=True, + elem_classes="dropdown_no_container", + allow_custom_value=True, + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + open_subdir = gr.Button( + variant="secondary", + value="\U0001F5C1", # unicode open folder + interactive=False, + size="sm", + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + refresh = gr.Button( + variant="secondary", + value="\u21BB", # unicode clockwise arrow circle + size="sm", + ) + + image_columns = gr.Slider( + label="Columns shown", value=4, minimum=1, maximum=16, step=1 + ) + outputgallery_filename = gr.Textbox( + label="Filename", + value="None", + interactive=False, + show_copy_button=True, + ) + + with gr.Accordion( + label="Parameter Information", open=False + ) as parameters_accordian: + image_parameters = gr.DataFrame( + headers=["Parameter", "Value"], + col_count=2, + wrap=True, + elem_classes="output_parameters_dataframe", + value=[["Status", "No image selected"]], + interactive=True, + ) + + with gr.Accordion(label="Send To", open=True): + with gr.Row(): + outputgallery_sendto_sd = gr.Button( + value="Stable Diffusion", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) + + # --- Event handlers + + def on_clear_gallery(): + return [ + gr.Gallery( + value=[], + visible=False, + ), + gr.Image( + visible=True, + ), + ] + + def on_image_columns_change(columns): + return gr.Gallery(columns=columns) + + def on_select_subdir(subdir) -> list: + # evt.value is the subdirectory name + new_images = outputgallery_filenames(subdir) + new_label = ( + f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" + ) + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_open_subdir(subdir): + subdir_path = os.path.normpath(os.path.join(output_dir, subdir)) + + if os.path.isdir(subdir_path): + if sys.platform == "linux": + subprocess.run(["xdg-open", subdir_path]) + elif sys.platform == "darwin": + subprocess.run(["open", subdir_path]) + elif sys.platform == "win32": + os.startfile(subdir_path) + + def on_refresh(current_subdir: str) -> list: + # get an up-to-date subdirectory list + refreshed_subdirs = output_subdirs() + # get the images using either the current subdirectory or the most + # recent valid one + new_subdir = ( + current_subdir + if current_subdir in refreshed_subdirs + else refreshed_subdirs[0] + ) + new_images = outputgallery_filenames(new_subdir) + new_label = ( + f"{len(new_images)} images in " + f"{os.path.join(output_dir, new_subdir)}" + ) + + return [ + gr.Dropdown( + choices=refreshed_subdirs, + value=new_subdir, + ), + refreshed_subdirs, + new_images, + gr.Gallery( + value=new_images, label=new_label, visible=len(new_images) > 0 + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_new_image(subdir, subdir_paths, status) -> list: + # prevent error triggered when an image generates before the tab + # has even been selected + subdir_paths = ( + subdir_paths + if len(subdir_paths) > 0 + else [get_generated_imgs_todays_subdir()] + ) + + # only update if the current subdir is the most recent one as + # new images only go there + if subdir_paths[0] == subdir: + new_images = outputgallery_filenames(subdir) + new_label = ( + f"{len(new_images)} images in " + f"{os.path.join(output_dir, subdir)} - {status}" + ) + + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + else: + # otherwise change nothing, + # (only untyped gradio gr.update() does this) + return [gr.update(), gr.update(), gr.update()] + + def on_select_image(images: list[str], evt: gr.SelectData) -> list: + # evt.index is an index into the full list of filenames for + # the current subdirectory + filename = images[evt.index] + params = displayable_metadata(filename) + + if params: + if params["source"] == "missing": + return [ + "Could not find this image file, refresh the gallery and update the images", + [["Status", "File missing"]], + ] + else: + return [ + filename, + list(map(list, params["parameters"].items())), + ] + + return [ + filename, + [["Status", "No parameters found"]], + ] + + def on_outputgallery_filename_change(filename: str) -> list: + exists = filename != "None" and os.path.exists(filename) + return [ + # disable or enable each of the sendto button based on whether + # an image is selected + gr.Button(interactive=exists), + ] + + # The time first our tab is selected we need to do an initial refresh + # to populate the subdirectory select box and the images from the most + # recent subdirectory. + # + # We do it at this point rather than setting this up in the controls' + # definitions as when you refresh the browser you always get what was + # *initially* set, which won't include any new subdirectories or images + # that might have created since the application was started. Doing it + # this way means a browser refresh/reload always gets the most + # up-to-date data. + def on_select_tab(subdir_paths, request: gr.Request): + local_client = request.headers["host"].startswith( + "127.0.0.1:" + ) or request.headers["host"].startswith("localhost:") + + if len(subdir_paths) == 0: + return on_refresh("") + [gr.update(interactive=local_client)] + else: + return ( + # Change nothing, (only untyped gr.update() does this) + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + ) + + # clearing images when we need to completely change what's in the + # gallery avoids current images being shown replacing piecemeal and + # prevents weirdness and errors if the user selects an image during the + # replacement phase. + clear_gallery = dict( + fn=on_clear_gallery, + inputs=None, + outputs=[gallery, logo], + queue=False, + ) + + subdirectories.select(**clear_gallery).then( + on_select_subdir, + [subdirectories], + [gallery_files, gallery, logo], + queue=False, + ) + + open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False) + + refresh.click(**clear_gallery).then( + on_refresh, + [subdirectories], + [subdirectories, subdirectory_paths, gallery_files, gallery, logo], + queue=False, + ) + + image_columns.change( + fn=on_image_columns_change, + inputs=[image_columns], + outputs=[gallery], + queue=False, + ) + + gallery.select( + on_select_image, + [gallery_files], + [outputgallery_filename, image_parameters], + queue=False, + ) + + outputgallery_filename.change( + on_outputgallery_filename_change, + [outputgallery_filename], + [ + outputgallery_sendto_sd, + ], + queue=False, + ) + + # We should have been given the .select function for our tab, so set it up + def outputgallery_tab_select(select): + select( + fn=on_select_tab, + inputs=[subdirectory_paths], + outputs=[ + subdirectories, + subdirectory_paths, + gallery_files, + gallery, + logo, + open_subdir, + ], + queue=False, + ) + + # We should have been passed a list of components on other tabs that update + # when a new image has generated on that tab, so set things up so the user + # will see that new image if they are looking at today's subdirectory + def outputgallery_watch(components: gr.Textbox): + for component in components: + component.change( + on_new_image, + inputs=[subdirectories, subdirectory_paths, component], + outputs=[gallery_files, gallery, logo], + queue=False, + ) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 1be5dc89fe..f26c7967e3 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -24,130 +24,31 @@ ) from apps.shark_studio.api.sd import ( sd_model_map, - StableDiffusion, -) -from apps.shark_studio.api.schedulers import ( - scheduler_model_map, + shark_sd_fn, + cancel_sd, ) from apps.shark_studio.api.controlnet import ( preprocessor_model_map, - control_adapter_model_map, PreprocessorModel, + cnet_preview, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, ) from apps.shark_studio.modules.img_processing import ( resampler_list, resize_stencil, ) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from apps.shark_studio.web.ui.utils import ( - get_generation_text_info, nodlogo_loc, ) +from apps.shark_studio.web.utils.state import ( + get_generation_text_info, + status_label, +) from apps.shark_studio.web.ui.common_events import lora_changed -sd_pipe = None - - -# NOTE: Each `hf_model_id` should have its own starting configuration. - -# model_vmfb_key = "" - -def shark_sd_fn( - prompt: str, - negative_prompt: str, - image_dict, - height: int, - width: int, - steps: int, - strength: float, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - base_model_id: str, - custom_checkpoints: str, - custom_vae: str, - precision: str, - device: str, - lora_weights: str | list, - lora_hf_ids: str | list, - ondemand: bool, - repeatable_seeds: bool, - resample_type: str, - control_mode: str, - stencils: list, - images: list, - preprocessed_hints: list, - progress=gr.Progress(), -): - - # Handling gradio ImageEditor datatypes so we have unified inputs to the SD API - for i, stencil in enumerate(stencils): - if images[i] is None and stencil is not None: - continue - elif stencil is None and any(img is not None for img in [images[i], preprocessed_hints[i]]): - images[i] = None - preprocessed_hints[i] = None - elif images[i] is not None: - if isinstance(images[i], dict): - images[i] = images[i]["composite"] - images[i] = images[i].convert("RGB") - - if isinstance(image_dict, PIL.Image.Image): - image = image_dict.convert("RGB") - elif image_dict: - image = image_dict["image"].convert("RGB") - else: - image = None - if image: - image, _, _, = resize_stencil(image, width, height) - - device_id = None - - from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - - submit_pipe_kwargs = { - base_model_id: base_model_id, - height: height, - width: width, - precision: precision, - device: device, - extra_model_ids: extra_model_ids, - embeddings: lora_hf_ids, - import_ir: cmd_opts.import_ir, - } - submit_prep_kwargs = { - - - - global sd_pipe - global sd_pipe_kwargs - - for key in - - if sd_pipe is None: - history[-1][-1] = "Getting the pipeline ready..." - yield history, "" - - # Initializes the pipeline and retrieves IR based on all - # parameters that are static in the turbine output format, - # which is currently MLIR in the torch dialect. - - sd_pipe = SharkStableDiffusionPipeline( - **submit_pipe_kwargs - ) - sd_pipe.queue_compile() - - for prompt, msg, exec_time in progress.tqdm( - sd_pipe.generate_images( - prompt, - negative_prompt, - ), - desc="Generating Image...", - ): - - return history, "" - def view_json_file(file_obj): content = "" @@ -155,17 +56,33 @@ def view_json_file(file_obj): content = fopen.read() return content -sd_fn_sig = signature(shark_sd_fn) -max_controlnets = 5 + +max_controlnets = 3 max_loras = 5 + def show_loras(k): k = int(k) - return [gr.Dropdown(visible=True)]*k + [gr.Dropdown(visible=False, value="None")]*(max_textboxes-k) + return gr.State( + [gr.Dropdown(visible=True)] * k + + [gr.Dropdown(visible=False, value="None")] * (max_loras - k) + ) + def show_controlnets(k): k = int(k) - return [gr.Row(visible=True)]*k + [gr.Row(visible=False)]*(max_textboxes-k) + return [ + gr.State( + [ + [gr.Row(visible=True, render=True)] * k + + [gr.Row(visible=False)] * (max_controlnets - k) + ] + ), + gr.State([None] * k), + gr.State([None] * k), + gr.State([None] * k), + ] + def create_canvas(width, height): data = Image.fromarray( @@ -182,10 +99,9 @@ def create_canvas(width, height): } return EditorValue(img_dict) + def import_original(original_img, width, height): - resized_img, _, _ = resize_stencil( - original_img, width, height - ) + resized_img, _, _ = resize_stencil(original_img, width, height) img_dict = { "background": resized_img, "layers": [resized_img], @@ -196,6 +112,7 @@ def import_original(original_img, width, height): crop_size=(width, height), ) + def update_cn_input( model, width, @@ -203,7 +120,6 @@ def update_cn_input( stencils, images, preprocessed_hints, - index, ): if model == None: stencils[index] = None @@ -271,80 +187,99 @@ def update_cn_input( images, preprocessed_hints, ] + + +sd_fn_inputs = [] +sd_fn_sig = signature(shark_sd_fn).replace() +for i in sd_fn_sig.parameters: + sd_fn_inputs.append(i) + with gr.Blocks(title="Stable Diffusion") as sd_element: # Get a list of arguments needed for the API call, then # initialize an empty list that will manage the corresponding # gradio values. - inputs_list = gr.State(signature(shark_sd_fn)) - inputs_args = gr.State([None] * len(inputs_list)) with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - elem_id="top_logo", - width=150, - height=50, - show_download_button=False, - ) - save_sd_config = gr.Button(label="Save Config", scale=1) - load_sd_config = gr.FileExplorer("Load Config", scale=1) - clear_sd_config = gr.ClearButton("Clear Config", scale=1) - with gr.Column(elem_if="ui_body"): + nod_logo = Image.open(nodlogo_loc) + with gr.Row(variant="compact", equal_height=True): + with gr.Column( + scale=1, + elem_id="demo_title_outer", + ): + gr.Image( + value=nod_logo, + show_label=False, + interactive=False, + elem_id="top_logo", + width=150, + height=50, + show_download_button=False, + ) + with gr.Column(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): - with gr.Group() - sd_model_info = ( - f"Checkpoint Path: {str(get_checkpoint_path())}" - ) - sd_base = gr.Dropdown( - label="Base Model", - info="Select or enter HF model ID", - elem_id="custom_model", - value="stabilityai/stable-diffusion-2.1-base", - choices=get_base_models(), - ) # base_model_id - sd_checkpoint = gr.Dropdown( - label="Checkpoints (optional)", - info="Select or enter HF model ID", - elem_id="custom_model", - value="None", - choices=get_checkpoints(sd_base), - ) # - sd_vae_info = (str(get_checkpoints_path("vae"))).replace( - "\\", "\n\\" - ) - sd_vae_info = f"VAE Path: {sd_vae_info}" - sd_custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=sd_vae_info, - elem_id="custom_model", - value=os.path.basename(cmd_opts.custom_vae) - if cmd_opts.custom_vae - else "None", - choices=["None"] + get_checkpoints("vae"), - allow_custom_value=True, - scale=1, - ) - + with gr.Row(equal_height=True): + with gr.Column(scale=3): + sd_model_info = ( + f"Checkpoint Path: {str(get_checkpoints_path())}" + ) + sd_base = gr.Dropdown( + label="Base Model", + info="Select or enter HF model ID", + elem_id="custom_model", + value="stabilityai/stable-diffusion-2-1-base", + choices=sd_model_map.keys(), + ) # base_model_id + sd_custom_weights = gr.Dropdown( + label="Weights (Optional)", + info="Select or enter HF model ID", + elem_id="custom_model", + value="None", + allow_custom_value=True, + choices=get_checkpoints(sd_base), + ) # + with gr.Column(scale=2): + sd_vae_info = ( + str(get_checkpoints_path("vae")) + ).replace("\\", "\n\\") + sd_vae_info = f"VAE Path: {sd_vae_info}" + sd_custom_vae = gr.Dropdown( + label=f"Custom VAE Models", + info=sd_vae_info, + elem_id="custom_model", + value=os.path.basename(cmd_opts.custom_vae) + if cmd_opts.custom_vae + else "None", + choices=["None"] + get_checkpoints("vae"), + allow_custom_value=True, + scale=1, + ) + with gr.Column(scale=1): + save_sd_config = gr.Button( + value="Save Config", size="sm" + ) + clear_sd_config = gr.ClearButton( + value="Clear Config", size="sm" + ) + load_sd_config = gr.FileExplorer( + label="Load Config", + root=os.path.basename("./configs"), + ) + with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( label="Prompt", - value=args.prompts[0], + value=cmd_opts.prompts[0], lines=2, elem_id="prompt_box", ) negative_prompt = gr.Textbox( label="Negative Prompt", - value=args.negative_prompts[0], + value=cmd_opts.negative_prompts[0], lines=2, elem_id="negative_prompt_box", ) - - with gr.Accordion(label = "Input Image", open=False): + + with gr.Accordion(label="Input Image", open=False): # TODO: make this import image prompt info if it exists sd_init_image = gr.Image( label="Input Image", @@ -352,41 +287,94 @@ def update_cn_input( height=300, interactive=True, ) - with gr.Accordion(label="Embeddings options", open=False): + with gr.Accordion( + label="Embeddings options", open=False, render=True + ): sd_lora_info = ( str(get_checkpoints_path("loras")) ).replace("\\", "\n\\") - num_loras = gr.Slider(1, max_loras, value=1, step=1, label="LoRA Count") - loras = [] + num_loras = gr.Slider( + 1, max_loras, value=1, step=1, label="LoRA Count" + ) + loras = gr.State([]) for i in range(max_loras): - lora_opt = gr.Dropdown( - allow_custom_value=False, - label=f"Standalone LoRA Weights", - info=sd_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), + with gr.Row(): + lora_opt = gr.Dropdown( + allow_custom_value=True, + label=f"Standalone LoRA Weights", + info=sd_lora_info, + elem_id="lora_weights", + value="None", + choices=["None"] + get_checkpoints("lora"), + ) + with gr.Row(): + lora_tags = gr.HTML( + value="
No LoRA selected
", + elem_classes="lora-tags", + ) + gr.on( + triggers=[lora_opt.change], + fn=lora_changed, + inputs=[lora_opt], + outputs=[lora_tags], + queue=True, ) + loras.value.append(lora_opt) + + num_loras.change(show_loras, [num_loras], [loras]) with gr.Accordion(label="Advanced Options", open=True): with gr.Row(): scheduler = gr.Dropdown( elem_id="scheduler", label="Scheduler", value="EulerDiscrete", - choices=scheduler_list, + choices=scheduler_model_map.keys(), allow_custom_value=False, ) with gr.Row(): height = gr.Slider( - 384, 768, value=cmd_opts.height, step=8, label="Height" + 384, + 768, + value=cmd_opts.height, + step=8, + label="Height", ) width = gr.Slider( - 384, 768, value=cmd_opts.width, step=8, label="Width" + 384, + 768, + value=cmd_opts.width, + step=8, + label="Width", ) with gr.Row(): with gr.Column(scale=3): steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" + 1, + 100, + value=cmd_opts.steps, + step=1, + label="Steps", + ) + batch_count = gr.Slider( + 1, + 100, + value=cmd_opts.batch_count, + step=1, + label="Batch Count", + interactive=True, + ) + batch_size = gr.Slider( + 1, + 4, + value=cmd_opts.batch_size, + step=1, + label="Batch Size", + interactive=True, + visible=True, + ) + repeatable_seeds = gr.Checkbox( + cmd_opts.repeatable_seeds, + label="Repeatable Seeds", ) with gr.Column(scale=3): strength = gr.Slider( @@ -402,6 +390,13 @@ def update_cn_input( label="Resample Type", allow_custom_value=True, ) + guidance_scale = gr.Slider( + 0, + 50, + value=cmd_opts.guidance_scale, + step=0.1, + label="CFG Scale", + ) ondemand = gr.Checkbox( value=cmd_opts.lowvram, label="Low VRAM", @@ -416,38 +411,6 @@ def update_cn_input( ], visible=True, ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=cmd_opts.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=cmd_opts.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - cmd_opts.repeatable_seeds, - label="Repeatable Seeds", - ) - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=cmd_opts.batch_size, - step=1, - label="Batch Size", - interactive=True, - visible=True, - ) with gr.Row(): seed = gr.Textbox( value=cmd_opts.seed, @@ -457,40 +420,53 @@ def update_cn_input( device = gr.Dropdown( elem_id="device", label="Device", - value=get_available_devices[0], - choices=get_available_devices, + value=get_available_devices()[0], + choices=get_available_devices(), allow_custom_value=False, ) - with gr.Accordion(label="Controlnet Options", open=False): + with gr.Accordion( + label="Controlnet Options", open=False, render=False + ): sd_cnet_info = ( str(get_checkpoints_path("controlnet")) ).replace("\\", "\n\\") - num_cnets = gr.Slider(1, max_controlnets, value=1, step=1, label="Controlnet Count") + num_cnets = gr.Slider( + 0, + max_controlnets, + value=0, + step=1, + label="Controlnet Count", + ) cnet_rows = [] - stencils = [] - images = [] - preprocessed_hints = [] + stencils = gr.State([]) + images = gr.State([]) + preprocessed_hints = gr.State([]) + control_mode = gr.Radio( + choices=["Prompt", "Balanced", "Controlnet"], + value="Balanced", + label="Control Mode", + ) + for i in range(max_controlnets): - with gr.Row as cnet_row: + with gr.Row(visible=False) as cnet_row: with gr.Column(): cnet_gen = gr.Button( value="Preprocess controlnet input", ) - cnet_processor = gr.Dropdown( - allow_custom_value=True, - label=f"Controlnet Preprocessor", - info=sd_cnet_info, - elem_id="lora_weights", - value="None", - choices=["None"] + controlnet_list + get_custom_model_files("controlnet"), - ) - cnet_adapter = gr.Dropdown( + cnet_model = gr.Dropdown( allow_custom_value=True, - label=f"Controlnet Adapter", + label=f"Controlnet Model", info=sd_cnet_info, elem_id="lora_weights", value="None", - choices=["None"] + controlnet_list + get_custom_model_files("controlnet"), + choices=[ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] + + get_checkpoints("controlnet"), ) canvas_width = gr.Slider( label="Canvas Width", @@ -529,14 +505,13 @@ def update_cn_input( visible=True, label="Preprocessed Hint", interactive=True, - show_label=True + show_label=True, ) use_input_img.click( import_original, [sd_init_image, canvas_width, canvas_height], - [cnet_image], + [cnet_input], ) - cnet_model.change( fn=update_cn_input, inputs=[ @@ -563,7 +538,7 @@ def update_cn_input( create_canvas, [canvas_width, canvas_height], [ - cnet_image, + cnet_input, ], ) gr.on( @@ -583,12 +558,16 @@ def update_cn_input( preprocessed_hints, ], ) - cnet_rows.append(cnet_row) + cnet_rows.value.append(cnet_row) - num_cnets.change(show_controlnets, num_cnets, cnet_rows) + num_cnets.change( + show_controlnets, + [num_cnets], + [cnet_rows, stencils, images, preprocessed_hints], + ) with gr.Column(scale=1, min_width=600): with gr.Group(): - img2img_gallery = gr.Gallery( + sd_gallery = gr.Gallery( label="Generated images", show_label=False, elem_id="gallery", @@ -596,14 +575,14 @@ def update_cn_input( object_fit="contain", ) std_output = gr.Textbox( - value=f"{i2i_model_info}\n" + value=f"{sd_model_info}\n" f"Images will be saved at " f"{get_generated_imgs_path()}", lines=2, elem_id="std_output", show_label=False, ) - img2img_status = gr.Textbox(visible=False) + sd_status = gr.Textbox(visible=False) with gr.Row(): stable_diffusion = gr.Button("Generate Image(s)") random_seed = gr.Button("Randomize Seed") @@ -631,12 +610,11 @@ def update_cn_input( batch_size, scheduler, sd_base, - sd_checkpoint, + sd_custom_weights, sd_custom_vae, precision, device, - lora_weights, - lora_hf_id, + loras, ondemand, repeatable_seeds, resample_type, @@ -652,13 +630,13 @@ def update_cn_input( stencils, images, ], - show_progress="minimal" if cmd_opts.progress_bar else "none", + show_progress="minimal", ) status_kwargs = dict( - fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs), + fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs), inputs=[batch_count, batch_size], - outputs=img2img_status, + outputs=sd_status, ) prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) @@ -670,10 +648,3 @@ def update_cn_input( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py index 9b588f858a..ba62e5adc0 100644 --- a/apps/shark_studio/web/ui/utils.py +++ b/apps/shark_studio/web/ui/utils.py @@ -1,10 +1,33 @@ -def nodlogo_loc(): - return "foo" +from enum import IntEnum +import math +import sys +import os -def get_checkpoints_path(model_type: str = None): - return "foo" +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) -def get_checkpoints(): - return "foo" +nodlogo_loc = resource_path("logos/nod-logo.png") +nodicon_loc = resource_path("logos/nod-icon.png") + + +class HSLHue(IntEnum): + RED = 0 + YELLOW = 60 + GREEN = 120 + CYAN = 180 + BLUE = 240 + MAGENTA = 300 + + +def hsl_color(alpha: float, start, end): + b = (end - start) * (alpha if alpha > 0 else 0) + result = b + start + + # Return a CSS HSL string + return f"hsl({math.floor(result)}, 80%, 35%)" diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py new file mode 100644 index 0000000000..0b5f54636a --- /dev/null +++ b/apps/shark_studio/web/utils/globals.py @@ -0,0 +1,74 @@ +import gc + +""" +The global objects include SD pipeline and config. +Maintaining the global objects would avoid creating extra pipeline objects when switching modes. +Also we could avoid memory leak when switching models by clearing the cache. +""" + + +def _init(): + global _sd_obj + global _config_obj + global _schedulers + _sd_obj = None + _config_obj = None + _schedulers = None + + +def set_sd_obj(value): + global _sd_obj + _sd_obj = value + + +def set_sd_scheduler(key): + global _sd_obj + _sd_obj.scheduler = _schedulers[key] + + +def set_sd_status(value): + global _sd_obj + _sd_obj.status = value + + +def set_cfg_obj(value): + global _config_obj + _config_obj = value + + +def set_schedulers(value): + global _schedulers + _schedulers = value + + +def get_sd_obj(): + global _sd_obj + return _sd_obj + + +def get_sd_status(): + global _sd_obj + return _sd_obj.status + + +def get_cfg_obj(): + global _config_obj + return _config_obj + + +def get_scheduler(key): + global _schedulers + return _schedulers[key] + + +def clear_cache(): + global _sd_obj + global _config_obj + global _schedulers + del _sd_obj + del _config_obj + del _schedulers + gc.collect() + _sd_obj = None + _config_obj = None + _schedulers = None diff --git a/apps/shark_studio/web/utils/metadata/__init__.py b/apps/shark_studio/web/utils/metadata/__init__.py new file mode 100644 index 0000000000..bcbcf746ca --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/__init__.py @@ -0,0 +1,6 @@ +from .png_metadata import ( + import_png_metadata, +) +from .display import ( + displayable_metadata, +) diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py new file mode 100644 index 0000000000..d617e802bf --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py @@ -0,0 +1,45 @@ +import csv +import os +from .format import humanize, humanizable + + +def csv_path(image_filename: str): + return os.path.join(os.path.dirname(image_filename), "imgs_details.csv") + + +def has_csv(image_filename: str) -> bool: + return os.path.exists(csv_path(image_filename)) + + +def matching_filename(image_filename: str, row): + # we assume the final column of the csv has the original filename with full path and match that + # against the image_filename if we are given a list. Otherwise we assume a dict and and take + # the value of the OUTPUT key + return os.path.basename(image_filename) in ( + row[-1] if isinstance(row, list) else row["OUTPUT"] + ) + + +def parse_csv(image_filename: str): + csv_filename = csv_path(image_filename) + + with open(csv_filename, "r", newline="") as csv_file: + # We use a reader or DictReader here for images_details.csv depending on whether we think it + # has headers or not. Having headers means less guessing of the format. + has_header = csv.Sniffer().has_header(csv_file.read(2048)) + csv_file.seek(0) + + reader = ( + csv.DictReader(csv_file) if has_header else csv.reader(csv_file) + ) + + matches = [ + # we rely on humanize and humanizable to work out the parsing of the individual .csv rows + humanize(row) + for row in reader + if row + and (has_header or humanizable(row)) + and matching_filename(image_filename, row) + ] + + return matches[0] if matches else {} diff --git a/apps/shark_studio/web/utils/metadata/display.py b/apps/shark_studio/web/utils/metadata/display.py new file mode 100644 index 0000000000..26234aab5c --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/display.py @@ -0,0 +1,53 @@ +import json +import os +from PIL import Image +from .png_metadata import parse_generation_parameters +from .exif_metadata import has_exif, parse_exif +from .csv_metadata import has_csv, parse_csv +from .format import compact, humanize + + +def displayable_metadata(image_filename: str) -> dict: + if not os.path.isfile(image_filename): + return {"source": "missing", "parameters": {}} + + pil_image = Image.open(image_filename) + + # we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads, + # and we go via that for SendTo, and is directly tied to the image) + if "parameters" in pil_image.info: + return { + "source": "png", + "parameters": compact( + parse_generation_parameters(pil_image.info["parameters"]) + ), + } + + # we have a matching json file (next most likely to be accurate when it's there) + json_path = os.path.splitext(image_filename)[0] + ".json" + if os.path.isfile(json_path): + with open(json_path) as params_file: + return { + "source": "json", + "parameters": compact( + humanize(json.load(params_file), includes_filename=False) + ), + } + + # we have a CSV file so try that (can be different shapes, and it usually has no + # headers/param names so of the things we we *know* have parameters, it's the + # last resort) + if has_csv(image_filename): + params = parse_csv(image_filename) + if params: # we might not have found the filename in the csv + return { + "source": "csv", + "parameters": compact(params), # already humanized + } + + # EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something* + if has_exif(image_filename): + return {"source": "exif", "parameters": parse_exif(pil_image)} + + # we've got nothing + return None diff --git a/apps/shark_studio/web/utils/metadata/exif_metadata.py b/apps/shark_studio/web/utils/metadata/exif_metadata.py new file mode 100644 index 0000000000..c72da8a935 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/exif_metadata.py @@ -0,0 +1,52 @@ +from PIL import Image +from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS + + +def has_exif(image_filename: str) -> bool: + return True if Image.open(image_filename).getexif() else False + + +def parse_exif(pil_image: Image) -> dict: + img_exif = pil_image.getexif() + + # See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594 + # I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I + # I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a + # dependency + exif_tags = { + TAGS.get(key, key): str(val) + for (key, val) in img_exif.items() + if key in TAGS + and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo) + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + def try_get_ifd(ifd_id): + try: + return img_exif.get_ifd(ifd_id).items() + except KeyError: + return {} + + ifd_tags = { + TAGS.get(key, key): str(val) + for ifd_id in IFD + for (key, val) in try_get_ifd(ifd_id) + if ifd_id != IFD.GPSInfo + and key in TAGS + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + gps_tags = { + GPSTAGS.get(key, key): str(val) + for (key, val) in try_get_ifd(IFD.GPSInfo) + if key in GPSTAGS + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + return {**exif_tags, **ifd_tags, **gps_tags} diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py new file mode 100644 index 0000000000..f097dab54f --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/format.py @@ -0,0 +1,143 @@ +# As SHARK has evolved more columns have been added to images_details.csv. However, since +# no version of the CSV has any headers (yet) we don't actually have anything within the +# file that tells us which parameter each column is for. So this is a list of known patterns +# indexed by length which is what we're going to have to use to guess which columns are the +# right ones for the file we're looking at. + +# The same ordering is used for JSON, but these do have key names, however they are not very +# human friendly, nor do they match up with the what is written to the .png headers + +# So these are functions to try and get something consistent out the raw input from all +# these sources + +PARAMS_FORMATS = { + 9: { + "VARIANT": "Model", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "OUTPUT": "Filename", + }, + 10: { + "MODEL": "Model", + "VARIANT": "Variant", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "OUTPUT": "Filename", + }, + 12: { + "VARIANT": "Model", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "HEIGHT": "Height", + "WIDTH": "Width", + "MAX_LENGTH": "Max Length", + "OUTPUT": "Filename", + }, +} + +PARAMS_FORMAT_CURRENT = { + "VARIANT": "Model", + "VAE": "VAE", + "LORA": "LoRA", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "HEIGHT": "Height", + "WIDTH": "Width", + "MAX_LENGTH": "Max Length", + "OUTPUT": "Filename", +} + + +def compact(metadata: dict) -> dict: + # we don't want to alter the original dictionary + result = dict(metadata) + + # discard the filename because we should already have it + if result.keys() & {"Filename"}: + result.pop("Filename") + + # make showing the sizes more compact by using only one line each + if result.keys() & {"Size-1", "Size-2"}: + result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}" + elif result.keys() & {"Height", "Width"}: + result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}" + + if result.keys() & {"Hires resize-1", "Hires resize-1"}: + hires_y = result.pop("Hires resize-1") + hires_x = result.pop("Hires resize-2") + + if hires_x == 0 and hires_y == 0: + result["Hires resize"] = "None" + else: + result["Hires resize"] = f"{hires_y}x{hires_x}" + + # remove VAE if it exists and is empty + if (result.keys() & {"VAE"}) and ( + not result["VAE"] or result["VAE"] == "None" + ): + result.pop("VAE") + + # remove LoRA if it exists and is empty + if (result.keys() & {"LoRA"}) and ( + not result["LoRA"] or result["LoRA"] == "None" + ): + result.pop("LoRA") + + return result + + +def humanizable(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + return lookup_key in PARAMS_FORMATS.keys() + + +def humanize(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + + # For lists we can only work based on the length, we have no other information + if isinstance(metadata, list): + if humanizable(metadata, includes_filename): + return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata)) + else: + raise KeyError( + f"Humanize could not find the format for a parameter list of length {len(metadata)}" + ) + + # For dictionaries we try to use the matching length parameter format if + # available, otherwise we just use the current format which is assumed to + # have everything currently known about. Then we swap keys in the metadata + # that match keys in the format for the friendlier name that we have set + # in the format value + if isinstance(metadata, dict): + if humanizable(metadata, includes_filename): + format = PARAMS_FORMATS[lookup_key] + else: + format = PARAMS_FORMAT_CURRENT + + return { + format[key]: metadata[key] + for key in format.keys() + if key in metadata.keys() and metadata[key] + } + + raise TypeError("Can only humanize parameter lists or dictionaries") diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py new file mode 100644 index 0000000000..cffc385ab7 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -0,0 +1,222 @@ +import re +from pathlib import Path +from apps.shark_studio.api.utils import ( + get_checkpoint_pathfile, +) +from apps.shark_studio.api.sd import ( + sd_model_map, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, +) + +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") + + +def parse_generation_parameters(x: str): + res = {} + prompt = "" + negative_prompt = "" + done_with_prompt = False + + *lines, lastline = x.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = "" + + for i, line in enumerate(lines): + line = line.strip() + if line.startswith("Negative prompt:"): + done_with_prompt = True + line = line[16:].strip() + + if done_with_prompt: + negative_prompt += ("" if negative_prompt == "" else "\n") + line + else: + prompt += ("" if prompt == "" else "\n") + line + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + + for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v + m = re_imagesize.match(v) + if m is not None: + res[k + "-1"] = m.group(1) + res[k + "-2"] = m.group(2) + else: + res[k] = v + + # Missing CLIP skip means it was set to 1 (the default) + if "Clip skip" not in res: + res["Clip skip"] = "1" + + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res[ + "Prompt" + ] += f"""""" + + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + + return res + + +def try_find_model_base_from_png_metadata( + file: str, folder: str = "models" +) -> str: + custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file(): + custom = file + ".safetensors" + + return custom + + +def find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in sd_model_map: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" + % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> str: + vae_custom = "" + + if key in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + +def import_png_metadata( + pil_data, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, +): + try: + png_info = pil_data.info["parameters"] + metadata = parse_generation_parameters(png_info) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) + + negative_prompt = metadata["Negative prompt"] + steps = int(metadata["Steps"]) + cfg_scale = float(metadata["CFG scale"]) + seed = int(metadata["Seed"]) + width = float(metadata["Size-1"]) + height = float(metadata["Size-2"]) + + if "Model" in metadata and png_custom_model: + custom_model = png_custom_model + elif "Model" in metadata and png_hf_model_id: + custom_model = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + + if "Prompt" in metadata: + prompt = metadata["Prompt"] + if "Sampler" in metadata: + if metadata["Sampler"] in scheduler_model_map: + sampler = metadata["Sampler"] + else: + print( + "Import PNG info: Unable to find a scheduler for %s" + % metadata["Sampler"] + ) + + except Exception as ex: + if pil_data and pil_data.info.get("parameters"): + print("import_png_metadata failed with %s" % ex) + pass + + return ( + None, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, + ) diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py new file mode 100644 index 0000000000..626d4ce53f --- /dev/null +++ b/apps/shark_studio/web/utils/state.py @@ -0,0 +1,41 @@ +import apps.shark_studio.web.utils.globals as global_obj +import gc + + +def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1): + print(f"Getting status label for {tab_name}") + if batch_index < batch_count: + bs = f"x{batch_size}" if batch_size > 1 else "" + return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}" + else: + return f"{tab_name} complete" + + +def get_generation_text_info(seeds, device): + cfg_dump = {} + for cfg in global_obj.get_config_dict(): + cfg_dump[cfg] = cfg + text_output = f"prompt={cfg_dump['prompts']}" + text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}" + text_output += ( + f"\nmodel_id={cfg_dump['hf_model_id']}, " + f"ckpt_loc={cfg_dump['ckpt_loc']}" + ) + text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}" + text_output += ( + f"\nsteps={cfg_dump['steps']}, " + f"guidance_scale={cfg_dump['guidance_scale']}, " + f"seed={seeds}" + ) + text_output += ( + f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, " + if not cfg_dump.use_hiresfix + else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, " + ) + text_output += ( + f"batch_count={cfg_dump['batch_count']}, " + f"batch_size={cfg_dump['batch_size']}, " + f"max_length={cfg_dump['max_length']}" + ) + + return text_output diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py new file mode 100644 index 0000000000..3e6ba46bfe --- /dev/null +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -0,0 +1,77 @@ +import os +import shutil +from time import time + +shark_tmp = os.path.join(os.getcwd(), "shark_tmp/") + + +def clear_tmp_mlir(): + cleanup_start = time() + print( + "Clearing .mlir temporary files from a prior run. This may take some time..." + ) + mlir_files = [ + filename + for filename in os.listdir(shark_tmp) + if os.path.isfile(os.path.join(shark_tmp, filename)) + and filename.endswith(".mlir") + ] + for filename in mlir_files: + os.remove(shark_tmp + filename) + print( + f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds." + ) + + +def clear_tmp_imgs(): + # tell gradio to use a directory under shark_tmp for its temporary + # image files unless somewhere else has been set + if "GRADIO_TEMP_DIR" not in os.environ: + os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio") + + print( + f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. " + + "You may change this by setting the GRADIO_TEMP_DIR environment variable." + ) + + # Clear all gradio tmp images from the last session + if os.path.exists(os.environ["GRADIO_TEMP_DIR"]): + cleanup_start = time() + print( + "Clearing gradio UI temporary image files from a prior run. This may take some time..." + ) + shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True) + print( + f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds." + ) + + # older SHARK versions had to workaround gradio bugs and stored things differently + else: + image_files = [ + filename + for filename in os.listdir(shark_tmp) + if os.path.isfile(os.path.join(shark_tmp, filename)) + and filename.startswith("tmp") + and filename.endswith(".png") + ] + if len(image_files) > 0: + print( + "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..." + ) + cleanup_start = time() + for filename in image_files: + os.remove(shark_tmp + filename) + print( + f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds." + ) + else: + print("No temporary images files to clear.") + + +def config_tmp(): + # create shark_tmp if it does not exist + if not os.path.exists(shark_tmp): + os.mkdir(shark_tmp) + + clear_tmp_mlir() + clear_tmp_imgs()