diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py
new file mode 100644
index 0000000000..ea8cdf0cc9
--- /dev/null
+++ b/apps/shark_studio/api/controlnet.py
@@ -0,0 +1,134 @@
+# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
+
+import numpy as np
+import PIL.Image
+from PIL import Image
+
+# Assumed sources for the detector and editor types used in cnet_preview.
+from controlnet_aux import CannyDetector, OpenposeDetector, ZoeDetector
+from gradio.components.image_editor import EditorValue
+
+
+class control_adapter:
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.model = None
+
+    def export_control_adapter_model(model_keyword):
+        return None
+
+    def export_xl_control_adapter_model(model_keyword):
+        return None
+
+
+class preprocessors:
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.model = None
+
+    def export_controlnet_model(model_keyword):
+        return None
+
+
+control_adapter_map = {
+    "sd15": {
+        "canny": {"initializer": control_adapter.export_control_adapter_model},
+        "openpose": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+        "scribble": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+        "zoedepth": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+    },
+    "sdxl": {
+        "canny": {
+            "initializer": control_adapter.export_xl_control_adapter_model
+        },
+    },
+}
+preprocessor_model_map = {
+    "canny": {"initializer": preprocessors.export_controlnet_model},
+    "openpose": {"initializer": preprocessors.export_controlnet_model},
+    "scribble": {"initializer": preprocessors.export_controlnet_model},
+    "zoedepth": {"initializer": preprocessors.export_controlnet_model},
+}
+
+
+class PreprocessorModel:
+    def __init__(
+        self,
+        hf_model_id,
+        device,
+    ):
+        self.model = None
+
+    def compile(self, device):
+        print("compile not implemented for preprocessor.")
+        return
+
+    def run(self, inputs):
+        print("run not implemented for preprocessor.")
+        return
+
+
+def cnet_preview(model, input_image, index, stencils, images, preprocessed_hints):
+    # `index` selects which ControlNet unit the preview belongs to.
+    if isinstance(input_image, PIL.Image.Image):
+        img_dict = {
+            "background": None,
+            "layers": [None],
+            "composite": input_image,
+        }
+        input_image = EditorValue(img_dict)
+    images[index] = input_image
+    if model:
+        stencils[index] = model
+    match model:
+        case "canny":
+            canny = CannyDetector()
+            result = canny(
+                np.array(input_image["composite"]),
+                100,
+                200,
+            )
+            preprocessed_hints[index] = Image.fromarray(result)
+            return (
+                Image.fromarray(result),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "openpose":
+            openpose = OpenposeDetector()
+            result = openpose(np.array(input_image["composite"]))
+            preprocessed_hints[index] = Image.fromarray(result[0])
+            return (
+                Image.fromarray(result[0]),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "zoedepth":
+            zoedepth = ZoeDetector()
+            result = zoedepth(np.array(input_image["composite"]))
+            preprocessed_hints[index] = Image.fromarray(result)
+            return (
+                Image.fromarray(result),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "scribble":
+            preprocessed_hints[index] = input_image["composite"]
+            return (
+                input_image["composite"],
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case _:
+            preprocessed_hints[index] = None
+            return (
+                None,
+                stencils,
+                images,
+                preprocessed_hints,
+            )
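+
+# Minimal usage sketch (illustrative only; assumes a Gradio event handler
+# owns the per-unit stencil/image/hint lists and passes the active unit
+# index):
+#
+#   stencils, images, hints = [None], [None], [None]
+#   preview, stencils, images, hints = cnet_preview(
+#       "canny", my_pil_image, 0, stencils, images, hints
+#   )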
diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py
new file mode 100644
index 0000000000..bbb273354c
--- /dev/null
+++ b/apps/shark_studio/api/initializers.py
@@ -0,0 +1,87 @@
+import importlib
+import logging
+import os
+import signal
+import sys
+import re
+import warnings
+import json
+from threading import Thread
+
+from apps.shark_studio.modules.timer import startup_timer
+
+
+def imports():
+    import torch  # noqa: F401
+
+    startup_timer.record("import torch")
+    warnings.filterwarnings(
+        action="ignore", category=DeprecationWarning, module="torch"
+    )
+    warnings.filterwarnings(
+        action="ignore", category=UserWarning, module="torchvision"
+    )
+
+    import gradio  # noqa: F401
+
+    startup_timer.record("import gradio")
+
+    import apps.shark_studio.web.utils.globals as global_obj
+
+    global_obj._init()
+    startup_timer.record("initialize globals")
+
+    from apps.shark_studio.modules import (
+        img_processing,
+    )  # noqa: F401
+    from apps.shark_studio.modules.schedulers import scheduler_model_map
+
+    startup_timer.record("other imports")
+
+
+def initialize():
+    configure_sigint_handler()
+
+    # from apps.shark_studio.modules import modelloader
+    # modelloader.cleanup_models()
+
+    # from apps.shark_studio.modules import sd_models
+    # sd_models.setup_model()
+    # startup_timer.record("setup SD model")
+
+    # initialize_rest(reload_script_modules=False)
+
+
+def initialize_rest(*, reload_script_modules=False):
+    """
+    Called both from initialize() and when reloading the webui.
+    """
+    # Keep this for adding reload options to the webUI.
+
+
+def dumpstacks():
+    import threading
+    import traceback
+
+    id2name = {th.ident: th.name for th in threading.enumerate()}
+    code = []
+    for threadId, stack in sys._current_frames().items():
+        code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
+        for filename, lineno, name, line in traceback.extract_stack(stack):
+            code.append(f"""File: "{filename}", line {lineno}, in {name}""")
+            if line:
+                code.append("  " + line.strip())
+
+    print("\n".join(code))
+
+
+def configure_sigint_handler():
+    # Make the program exit immediately on ctrl+c without waiting for anything.
+    def sigint_handler(sig, frame):
+        print(f"Interrupted with signal {sig} in {frame}")
+
+        dumpstacks()
+
+        os._exit(0)
+
+    signal.signal(signal.SIGINT, sigint_handler)
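+
+# Expected startup order, as a sketch (the webui entrypoint that calls
+# these is assumed, not part of this file):
+#
+#   from apps.shark_studio.api import initializers
+#   initializers.imports()     # heavy imports, timed by startup_timer
+#   initializers.initialize()  # installs the SIGINT handler, future model setup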
"max_tokens": 512, + }, + "vae_decode": { + "initializer": vae.export_vae_model, + "max_tokens": 64, + }, + }, + "stabilityai/stable_diffusion-xl-1.0": { + "clip_1": { + "initializer": clip.export_clip_model, + "max_tokens": 64, + }, + "clip_2": { + "initializer": clip.export_clip_model, + "max_tokens": 64, + }, + "vae_encode": { + "initializer": vae.export_vae_model, + "max_tokens": 64, + }, + "unet": { + "initializer": unet.export_unet_model, + "max_tokens": 512, + }, + "vae_decode": { + "initializer": vae.export_vae_model, + "max_tokens": 64, + }, + }, +} + + +class StableDiffusion(SharkPipelineBase): + # This class is responsible for executing image generation and creating + # /managing a set of compiled modules to run Stable Diffusion. The init + # aims to be as general as possible, and the class will infer and compile + # a list of necessary modules or a combined "pipeline module" for a + # specified job based on the inference task. + # + # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels. + # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"} + # + # embeddings: a dict of embedding checkpoints or model IDs to use when + # initializing the compiled modules. + + def __init__( + self, + base_model_id: str = "runwayml/stable-diffusion-v1-5", + height: int = 512, + width: int = 512, + precision: str = "fp16", + device: str = None, + custom_model_map: dict = {}, + embeddings: dict = {}, + import_ir: bool = True, + ): + super().__init__(sd_model_map[base_model_id], device, import_ir) + self.base_model_id = base_model_id + self.device = device + self.precision = precision + self.iree_module_dict = None + self.get_compiled_map() + + def prepare_pipeline(self, scheduler, custom_model_map): + return None + + def generate_images( + self, + prompt, + negative_prompt, + steps, + strength, + guidance_scale, + seed, + ondemand, + repeatable_seeds, + resample_type, + control_mode, + preprocessed_hints, + ): + return None, None, None, None, None + + +# NOTE: Each `hf_model_id` should have its own starting configuration. 
+
+
+class StableDiffusion(SharkPipelineBase):
+    # This class is responsible for executing image generation and creating
+    # /managing a set of compiled modules to run Stable Diffusion. The init
+    # aims to be as general as possible, and the class will infer and compile
+    # a list of necessary modules or a combined "pipeline module" for a
+    # specified job based on the inference task.
+    #
+    # custom_model_map: a dict of submodel + HF ID pairs for custom submodels,
+    # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
+    #
+    # embeddings: a dict of embedding checkpoints or model IDs to use when
+    # initializing the compiled modules.
+
+    def __init__(
+        self,
+        base_model_id: str = "runwayml/stable-diffusion-v1-5",
+        height: int = 512,
+        width: int = 512,
+        precision: str = "fp16",
+        device: str = None,
+        custom_model_map: dict = {},
+        embeddings: dict = {},
+        import_ir: bool = True,
+        is_img2img: bool = False,
+    ):
+        super().__init__(sd_model_map[base_model_id], device, import_ir)
+        self.base_model_id = base_model_id
+        self.device = device
+        self.precision = precision
+        self.is_img2img = is_img2img
+        self.get_compiled_map()
+
+    def prepare_pipeline(self, scheduler, custom_model_map, embeddings):
+        return None
+
+    def generate_images(
+        self,
+        prompt,
+        negative_prompt,
+        steps,
+        strength,
+        guidance_scale,
+        seed,
+        ondemand,
+        repeatable_seeds,
+        resample_type,
+        control_mode,
+        preprocessed_hints,
+    ):
+        return None, None, None, None, None
+
+
+# NOTE: Each `hf_model_id` should have its own starting configuration.
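+
+# A hypothetical custom_model_map entry, for illustration only:
+#
+#   custom_model_map = {
+#       "vae_decode": {"custom_weights": "madebyollin/sdxl-vae-fp16-fix"},
+#   }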
+
+# Pipeline state reused across calls.
+sd_pipe = None
+sd_pipe_kwargs = None
+
+# model_vmfb_key = ""
+
+
+def shark_sd_fn(
+    prompt,
+    negative_prompt,
+    image_dict,
+    height: int,
+    width: int,
+    steps: int,
+    strength: float,
+    guidance_scale: float,
+    seed: str | int,
+    batch_count: int,
+    batch_size: int,
+    scheduler: str,
+    base_model_id: str,
+    custom_weights: str,
+    custom_vae: str,
+    precision: str,
+    device: str,
+    lora_weights: str | list,
+    ondemand: bool,
+    repeatable_seeds: bool,
+    resample_type: str,
+    control_mode: str,
+    stencils: list,
+    images: list,
+    preprocessed_hints: list,
+    progress=gr.Progress(),
+):
+    # Handle gradio ImageEditor datatypes so we have unified inputs to the SD API.
+    for i, stencil in enumerate(stencils):
+        if images[i] is None and stencil is not None:
+            continue
+        elif stencil is None and any(
+            img is not None for img in [images[i], preprocessed_hints[i]]
+        ):
+            images[i] = None
+            preprocessed_hints[i] = None
+        elif images[i] is not None:
+            if isinstance(images[i], dict):
+                images[i] = images[i]["composite"]
+            images[i] = images[i].convert("RGB")
+
+    if isinstance(image_dict, PIL.Image.Image):
+        image = image_dict.convert("RGB")
+    elif image_dict:
+        image = image_dict["image"].convert("RGB")
+    else:
+        image = None
+    is_img2img = False
+    if image:
+        (
+            image,
+            _,
+            _,
+        ) = resize_stencil(image, width, height)
+        is_img2img = True
+    print("Performing Stable Diffusion Pipeline setup...")
+
+    device_id = None
+
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    import apps.shark_studio.web.utils.globals as global_obj
+
+    custom_model_map = {}
+    if custom_weights != "None":
+        custom_model_map["unet"] = {"custom_weights": custom_weights}
+    if custom_vae != "None":
+        custom_model_map["vae"] = {"custom_weights": custom_vae}
+    if stencils:
+        for i, stencil in enumerate(stencils):
+            if "xl" not in base_model_id.lower():
+                custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
+                    "sd15"
+                ][stencil]
+            else:
+                custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
+                    "sdxl"
+                ][stencil]
+
+    submit_pipe_kwargs = {
+        "base_model_id": base_model_id,
+        "height": height,
+        "width": width,
+        "precision": precision,
+        "device": device,
+        "custom_model_map": custom_model_map,
+        "import_ir": cmd_opts.import_mlir,
+        "is_img2img": is_img2img,
+    }
+    submit_prep_kwargs = {
+        "scheduler": scheduler,
+        "custom_model_map": custom_model_map,
+        "embeddings": lora_weights,
+    }
+    submit_run_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "steps": steps,
+        "strength": strength,
+        "guidance_scale": guidance_scale,
+        "seed": seed,
+        "ondemand": ondemand,
+        "repeatable_seeds": repeatable_seeds,
+        "resample_type": resample_type,
+        "control_mode": control_mode,
+        "preprocessed_hints": preprocessed_hints,
+    }
+
+    global sd_pipe
+    global sd_pipe_kwargs
+
+    if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
+        sd_pipe = None
+        sd_pipe_kwargs = submit_pipe_kwargs
+        gc.collect()
+
+    if sd_pipe is None:
+        yield [], "", "Getting the pipeline ready...", stencils, images
+
+        # Initializes the pipeline and retrieves IR based on all
+        # parameters that are static in the turbine output format,
+        # which is currently MLIR in the torch dialect.
+        sd_pipe = StableDiffusion(
+            **submit_pipe_kwargs,
+        )
+
+    sd_pipe.prepare_pipeline(**submit_prep_kwargs)
+
+    generated_imgs = []
+    text_output = ""
+    # TODO: seeds, extra_info, and get_generation_text_info are placeholders
+    # until generate_images is implemented.
+    for current_batch in progress.tqdm(
+        range(batch_count), desc="Generating Image..."
+    ):
+        out_imgs = sd_pipe.generate_images(**submit_run_kwargs)
+        # text_output = get_generation_text_info(
+        #     seeds[: current_batch + 1], device
+        # )
+        # save_output_img(
+        #     out_imgs[0],
+        #     seeds[current_batch],
+        #     extra_info,
+        # )
+        generated_imgs.extend(out_imgs)
+        yield generated_imgs, text_output, status_label(
+            "Stable Diffusion", current_batch + 1, batch_count, batch_size
+        ), stencils, images
+
+    return generated_imgs, text_output, "", stencils, images
+
+
+def cancel_sd():
+    print("Inject call to cancel longer API calls.")
+    return
+
+
+if __name__ == "__main__":
+    sd = StableDiffusion(
+        "runwayml/stable-diffusion-v1-5",
+        device="vulkan",
+    )
+    print("model loaded")
if "vulkan" in cmd_opts.device: + # set runtime flags for vulkan. + set_iree_runtime_flags() + + # set triple flag to avoid multiple calls to get_vulkan_triple_flag + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_vulkan_target_triple: + triple = get_vulkan_target_triple(device_name) + if triple is not None: + cmd_opts.iree_vulkan_target_triple = triple + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_vulkan_target_triple}." + ) + elif "cuda" in cmd_opts.device: + cmd_opts.device = "cuda" + elif "metal" in cmd_opts.device: + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_metal_target_platform: + triple = get_metal_target_triple(device_name) + if triple is not None: + cmd_opts.iree_metal_target_platform = triple.split("-")[-1] + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_metal_target_platform}." + ) + elif "cpu" in cmd_opts.device: + cmd_opts.device = "cpu" + + +def set_iree_runtime_flags(): + # TODO: This function should be device-agnostic and piped properly + # to general runtime driver init. + vulkan_runtime_flags = get_iree_vulkan_runtime_flags() + if cmd_opts.enable_rgp: + vulkan_runtime_flags += [ + f"--enable_rgp=true", + f"--vulkan_debug_utils=true", + ] + if cmd_opts.device_allocator_heap_key: + vulkan_runtime_flags += [ + f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", + ] + set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) + + +def get_all_devices(driver_name): + """ + Inputs: driver_name + Returns a list of all the available devices for a given driver sorted by + the iree path names of the device as in --list_devices option in iree. + """ + from iree.runtime import get_driver + + driver = get_driver(driver_name) + device_list_src = driver.query_available_devices() + device_list_src.sort(key=lambda d: d["path"]) + return device_list_src def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) + + +def get_generated_imgs_path() -> Path: + return Path( + cmd_opts.output_dir + if cmd_opts.output_dir + else get_resource_path("..\web\generated_imgs") + ) + + +def get_generated_imgs_todays_subdir() -> str: + return dt.now().strftime("%Y%m%d") + + +def create_checkpoint_folders(): + dir = ["vae", "lora"] + if not cmd_opts.ckpt_dir: + dir.insert(0, "models") + else: + if not os.path.isdir(cmd_opts.ckpt_dir): + sys.exit( + f"Invalid --ckpt_dir argument, " + f"{args.ckpt_dir} folder does not exists." 
+
+
+def get_checkpoints_path(model=""):
+    return get_resource_path(os.path.join("..", "web", "models", model))
+
+
+def get_checkpoints(model="models"):
+    ckpt_files = []
+    file_types = checkpoints_filetypes
+    if model == "lora":
+        file_types = file_types + ("*.pt", "*.bin")
+    for extn in file_types:
+        files = [
+            os.path.basename(x)
+            for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
+        ]
+        ckpt_files.extend(files)
+    return sorted(ckpt_files, key=str.casefold)
+
+
+def get_checkpoint_pathfile(checkpoint_name, model="models"):
+    return os.path.join(get_checkpoints_path(model), checkpoint_name)
+
+
+def get_device_mapping(driver, key_combination=3):
+    """This method ensures consistent device ordering when choosing
+    specific devices for execution.
+    Args:
+        driver (str): execution driver (vulkan, cuda, rocm, etc)
+        key_combination (int, optional): choice of mapping value for
+            the device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Returns:
+        dict: map of possible device names a user can input to the desired
+        combination of name/path.
+    """
+    from shark.iree_utils._common import iree_device_map
+
+    driver = iree_device_map(driver)
+    device_list = get_all_devices(driver)
+    device_map = dict()
+
+    def get_output_value(dev_dict):
+        if key_combination == 1:
+            return f"{driver}://{dev_dict['path']}"
+        if key_combination == 2:
+            return dev_dict["name"]
+        if key_combination == 3:
+            return dev_dict["name"], f"{driver}://{dev_dict['path']}"
+
+    # Map the driver name to the default device (driver://0).
+    device_map[f"{driver}"] = get_output_value(device_list[0])
+    for i, device in enumerate(device_list):
+        # Map with index.
+        device_map[f"{driver}://{i}"] = get_output_value(device)
+        # Map with full path.
+        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
+    return device_map
+
+
+def get_opt_flags(model, precision="fp16"):
+    iree_flags = []
+    if len(cmd_opts.iree_vulkan_target_triple) > 0:
+        iree_flags.append(
+            f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
+        )
+    if "rocm" in cmd_opts.device:
+        # NOTE: get_iree_rocm_args is assumed to come from the shark gpu
+        # utilities; confirm the import path before relying on it.
+        from shark.iree_utils.gpu_utils import get_iree_rocm_args
+
+        rocm_args = get_iree_rocm_args()
+        iree_flags.extend(rocm_args)
+    if not cmd_opts.iree_constant_folding:
+        iree_flags.append("--iree-opt-const-expr-hoisting=False")
+        iree_flags.append(
+            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+        )
+    if not cmd_opts.data_tiling:
+        iree_flags.append("--iree-opt-data-tiling=False")
+
+    if "vae" not in model:
+        # Due to lack of support for multi-reduce, we always collapse reduction
+        # dims before dispatch formation right now.
+        iree_flags += ["--iree-flow-collapse-reduction-dims"]
+    return iree_flags
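+
+# Example return value of get_opt_flags("unet") with a vulkan triple set
+# (illustrative values only):
+#
+#   ["-iree-vulkan-target-triple=rdna3-unknown-linux",
+#    "--iree-flow-collapse-reduction-dims"]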
+
+
+def map_device_to_name_path(device, key_combination=3):
+    """Gives the appropriate device data (supported name/path) for a user
+    selected execution device.
+    Args:
+        device (str): user-selected device
+        key_combination (int, optional): choice of mapping value for
+            the device name.
+            1 : path
+            2 : name
+            3 : (name, path)
+            Defaults to 3.
+    Raises:
+        ValueError: if the device is not a valid device.
+    Returns:
+        str / tuple: the mapping str, or tuple of mapping strs, for
+        the device depending on the key_combination value
+    """
+    driver = device.split("://")[0]
+    device_map = get_device_mapping(driver, key_combination)
+    try:
+        device_mapping = device_map[device]
+    except KeyError:
+        raise ValueError(f"Device '{device}' is not a valid device.")
+    return device_mapping
+
+
+# Take a seed expression in an input format and convert it to
+# a list of integers, where possible.
+def parse_seed_input(seed_input: str | list | int):
+    if isinstance(seed_input, str):
+        try:
+            seed_input = json.loads(seed_input)
+        except (ValueError, TypeError):
+            seed_input = None
+
+    if isinstance(seed_input, int):
+        return [seed_input]
+
+    if isinstance(seed_input, list) and all(
+        type(seed) is int for seed in seed_input
+    ):
+        return seed_input
+
+    raise TypeError(
+        "Seed input must be an integer or an array of integers in JSON format"
+    )
+
+
+# Generate and return a new seed if the provided one is not in the
+# supported range (including -1).
+def sanitize_seed(seed: int | str):
+    seed = int(seed)
+    uint32_info = np.iinfo(np.uint32)
+    uint32_min, uint32_max = uint32_info.min, uint32_info.max
+    if seed < uint32_min or seed > uint32_max:
+        seed = randint(uint32_min, uint32_max)
+    return seed
+ return seed_input + + raise TypeError( + "Seed input must be an integer or an array of integers in JSON format" + ) diff --git a/apps/shark_studio/modules/checkpoint_proc.py b/apps/shark_studio/modules/checkpoint_proc.py new file mode 100644 index 0000000000..e924de4640 --- /dev/null +++ b/apps/shark_studio/modules/checkpoint_proc.py @@ -0,0 +1,66 @@ +import os +import json +import re +from pathlib import Path +from omegaconf import OmegaConf + + +def get_path_to_diffusers_checkpoint(custom_weights): + path = Path(custom_weights) + diffusers_path = path.parent.absolute() + diffusers_directory_name = os.path.join("diffusers", path.stem) + complete_path_to_diffusers = diffusers_path / diffusers_directory_name + complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) + path_to_diffusers = complete_path_to_diffusers.as_posix() + return path_to_diffusers + + +def preprocessCKPT(custom_weights, is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) + if next(Path(path_to_diffusers).iterdir(), None): + print("Checkpoint already loaded at : ", path_to_diffusers) + return + else: + print( + "Diffusers' checkpoint will be identified here : ", + path_to_diffusers, + ) + from_safetensors = ( + True if custom_weights.lower().endswith(".safetensors") else False + ) + # EMA weights usually yield higher quality images for inference but + # non-EMA weights have been yielding better results in our case. + # TODO: Add an option `--ema` (`--no-ema`) for users to specify if + # they want to go for EMA weight extraction or not. + extract_ema = False + print( + "Loading diffusers' pipeline from original stable diffusion checkpoint" + ) + num_in_channels = 9 if is_inpaint else 4 + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=custom_weights, + extract_ema=extract_ema, + from_safetensors=from_safetensors, + num_in_channels=num_in_channels, + ) + pipe.save_pretrained(path_to_diffusers) + print("Loading complete") + + +def convert_original_vae(vae_checkpoint): + vae_state_dict = {} + for key in list(vae_checkpoint.keys()): + vae_state_dict["first_stage_model." 
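+
+# Usage sketch (the path is hypothetical):
+#
+#   preprocessCKPT("models/foo.safetensors")
+#   # -> writes a diffusers layout under models/diffusers/foo/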
diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py
new file mode 100644
index 0000000000..d8cf544f81
--- /dev/null
+++ b/apps/shark_studio/modules/embeddings.py
@@ -0,0 +1,171 @@
+import os
+import sys
+import torch
+import json
+import safetensors
+from pathlib import Path
+from safetensors.torch import load_file
+from apps.shark_studio.api.utils import get_checkpoint_pathfile
+
+
+def processLoRA(model, use_lora, splitting_prefix):
+    state_dict = ""
+    if ".safetensors" in use_lora:
+        state_dict = load_file(use_lora)
+    else:
+        state_dict = torch.load(use_lora)
+    alpha = 0.75
+    visited = []
+
+    # Directly update the weights in the model.
+    process_unet = "te" not in splitting_prefix
+    for key in state_dict:
+        if ".alpha" in key or key in visited:
+            continue
+
+        curr_layer = model
+        if ("text" not in key and process_unet) or (
+            "text" in key and not process_unet
+        ):
+            layer_infos = (
+                key.split(".")[0].split(splitting_prefix)[-1].split("_")
+            )
+        else:
+            continue
+
+        # Find the target layer.
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # Update the weight.
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = (
+                state_dict[pair_keys[0]]
+                .squeeze(3)
+                .squeeze(2)
+                .to(torch.float32)
+            )
+            weight_down = (
+                state_dict[pair_keys[1]]
+                .squeeze(3)
+                .squeeze(2)
+                .to(torch.float32)
+            )
+            curr_layer.weight.data += alpha * torch.mm(
+                weight_up, weight_down
+            ).unsqueeze(2).unsqueeze(3)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
+        # Update the visited list.
+        for item in pair_keys:
+            visited.append(item)
+    return model
+
+
+def update_lora_weight_for_unet(unet, use_lora):
+    extensions = [".bin", ".safetensors", ".pt"]
+    if not any([extension in use_lora for extension in extensions]):
+        # We assume it is a HF ID for standalone LoRA weights.
+        unet.load_attn_procs(use_lora)
+        return unet
+
+    main_file_name = Path(use_lora).stem
+    if ".bin" in use_lora:
+        main_file_name += ".bin"
+    elif ".safetensors" in use_lora:
+        main_file_name += ".safetensors"
+    elif ".pt" in use_lora:
+        main_file_name += ".pt"
+    else:
+        sys.exit("Only .bin, .safetensors, and .pt LoRA formats are supported")
+
+    try:
+        dir_name = os.path.dirname(use_lora)
+        unet.load_attn_procs(dir_name, weight_name=main_file_name)
+        return unet
+    except Exception:
+        return processLoRA(unet, use_lora, "lora_unet_")
+
+
+def update_lora_weight(model, use_lora, model_name):
+    if "unet" in model_name:
+        return update_lora_weight_for_unet(model, use_lora)
+    try:
+        return processLoRA(model, use_lora, "lora_te_")
+    except Exception:
+        return None
+
+
+def get_lora_metadata(lora_filename):
+    # Get the metadata from the file.
+    filename = get_checkpoint_pathfile(lora_filename, "lora")
+    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+
+    # Guard clause for when there isn't any metadata.
+    if not metadata:
+        return None
+
+    # Metadata is a dictionary of strings; the values of the keys we're
+    # interested in are actually json, and need to be loaded as such.
+    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
+    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))
+    tag_dirs = [dir for dir in tag_frequencies.keys()]
+
+    # Gather the tag frequency information for all the datasets trained.
+    all_frequencies = {}
+    for dataset in tag_dirs:
+        frequencies = sorted(
+            [entry for entry in tag_frequencies[dataset].items()],
+            reverse=True,
+            key=lambda x: x[1],
+        )
+
+        # Get a figure for the total number of images processed for this
+        # dataset: either the number listed in its dataset_dirs entry, or
+        # the highest tag frequency count if that doesn't exist.
+        img_count = dataset_dirs.get(dataset, {}).get(
+            "img_count", frequencies[0][1]
+        )
+
+        # Add the dataset frequencies to the overall frequencies, replacing
+        # the frequency counts on the tags with a percentage/ratio.
+        all_frequencies.update(
+            [(entry[0], entry[1] / img_count) for entry in frequencies]
+        )
+
+    trained_model_id = " ".join(
+        [
+            metadata.get("ss_sd_model_hash", ""),
+            metadata.get("ss_sd_model_name", ""),
+            metadata.get("ss_base_model_version", ""),
+        ]
+    ).strip()
+
+    # Return the topmost of all frequencies across all datasets.
+    return {
+        "model": trained_model_id,
+        "frequencies": sorted(
+            all_frequencies.items(), reverse=True, key=lambda x: x[1]
+        ),
+    }
diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py
new file mode 100644
index 0000000000..b5cf28ce47
--- /dev/null
+++ b/apps/shark_studio/modules/img_processing.py
@@ -0,0 +1,168 @@
+import os
+import re
+import json
+import sys
+from csv import DictWriter
+from datetime import datetime as dt
+from pathlib import Path
+
+from PIL import Image, PngImagePlugin
+
+
+# Save output images and the inputs corresponding to them.
+def save_output_img(output_img, img_seed, extra_info=None):
+    # Imported here to avoid a circular import: shared_cmd_opts imports this
+    # module for resampler_list.
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.shark_studio.api.utils import (
+        get_generated_imgs_path,
+        get_generated_imgs_todays_subdir,
+    )
+
+    if extra_info is None:
+        extra_info = {}
+    generated_imgs_path = Path(
+        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
+    )
+    generated_imgs_path.mkdir(parents=True, exist_ok=True)
+    csv_path = Path(generated_imgs_path, "imgs_details.csv")
+
+    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15])
+    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
+
+    img_model = cmd_opts.hf_model_id
+    if cmd_opts.ckpt_loc:
+        img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem
+
+    img_vae = None
+    if cmd_opts.custom_vae:
+        img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem
+
+    img_lora = None
+    if cmd_opts.use_lora:
+        img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem
+
+    if cmd_opts.output_img_format == "jpg":
+        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
+        output_img.save(out_img_path, quality=95, subsampling=0)
+    else:
+        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
+        pngInfo = PngImagePlugin.PngInfo()
+
+        if cmd_opts.write_metadata_to_png:
+            # Using a conditional expression caused problems, so setting a new
+            # variable for now.
+            if cmd_opts.use_hiresfix:
+                png_size_text = (
+                    f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
+                )
+            else:
+                png_size_text = f"{cmd_opts.width}x{cmd_opts.height}"
+
+            pngInfo.add_text(
+                "parameters",
+                f"{cmd_opts.prompts[0]}"
+                f"\nNegative prompt: {cmd_opts.negative_prompts[0]}"
+                f"\nSteps: {cmd_opts.steps}, "
+                f"Sampler: {cmd_opts.scheduler}, "
+                f"CFG scale: {cmd_opts.guidance_scale}, "
+                f"Seed: {img_seed}, "
+                f"Size: {png_size_text}, "
+                f"Model: {img_model}, "
+                f"VAE: {img_vae}, "
+                f"LoRA: {img_lora}",
+            )
+
+        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
+
+        if cmd_opts.output_img_format not in ["png", "jpg"]:
+            print(
+                f"[ERROR] Format {cmd_opts.output_img_format} is not "
+                f"supported yet. Image saved as png instead. "
+                f"Supported formats: png / jpg"
+            )
+
+    # To be as low-impact as possible to the existing CSV format, we append
+    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
+    # importance for each data point. Something to consider.
+    new_entry = {
+        "VARIANT": img_model,
+        "SCHEDULER": cmd_opts.scheduler,
+        "PROMPT": cmd_opts.prompts[0],
+        "NEG_PROMPT": cmd_opts.negative_prompts[0],
+        "SEED": img_seed,
+        "CFG_SCALE": cmd_opts.guidance_scale,
+        "PRECISION": cmd_opts.precision,
+        "STEPS": cmd_opts.steps,
+        "HEIGHT": cmd_opts.height
+        if not cmd_opts.use_hiresfix
+        else cmd_opts.hiresfix_height,
+        "WIDTH": cmd_opts.width
+        if not cmd_opts.use_hiresfix
+        else cmd_opts.hiresfix_width,
+        "MAX_LENGTH": cmd_opts.max_length,
+        "OUTPUT": out_img_path,
+        "VAE": img_vae,
+        "LORA": img_lora,
+    }
+
+    new_entry.update(extra_info)
+
+    csv_mode = "a" if os.path.isfile(csv_path) else "w"
+    with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
+        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
+        if csv_mode == "w":
+            dictwriter_obj.writeheader()
+        dictwriter_obj.writerow(new_entry)
+
+    if cmd_opts.save_metadata_to_json:
+        del new_entry["OUTPUT"]
+        json_path = Path(generated_imgs_path, f"{out_img_name}.json")
+        with open(json_path, "w") as f:
+            json.dump(new_entry, f, indent=4)
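+
+# Note: alongside the image itself this writes a row to imgs_details.csv
+# and, when --save_metadata_to_json is set, a sidecar .json with the same
+# generation parameters.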
+
+
+resamplers = {
+    "Lanczos": Image.Resampling.LANCZOS,
+    "Nearest Neighbor": Image.Resampling.NEAREST,
+    "Bilinear": Image.Resampling.BILINEAR,
+    "Bicubic": Image.Resampling.BICUBIC,
+    "Hamming": Image.Resampling.HAMMING,
+    "Box": Image.Resampling.BOX,
+}
+
+resampler_list = resamplers.keys()
+
+
+# For stencils, the input image can be of any size, but we need to ensure
+# that it conforms with our model constraints:
+# Both width and height should be in the range [128, 768] and multiples of 8.
+# This utility function performs that transformation on the input image,
+# while also maintaining the aspect ratio, before sending it to the stencil
+# pipeline.
+def resize_stencil(image: Image.Image, width, height, resampler_type=None):
+    aspect_ratio = width / height
+    min_size = min(width, height)
+    if min_size < 128:
+        n_size = 128
+        if width == min_size:
+            width = n_size
+            height = n_size / aspect_ratio
+        else:
+            height = n_size
+            width = n_size * aspect_ratio
+    width = int(width)
+    height = int(height)
+    n_width = width // 8
+    n_height = height // 8
+    n_width *= 8
+    n_height *= 8
+
+    min_size = min(width, height)
+    if min_size > 768:
+        n_size = 768
+        if width == min_size:
+            height = n_size
+            width = n_size * aspect_ratio
+        else:
+            width = n_size
+            height = n_size / aspect_ratio
+    width = int(width)
+    height = int(height)
+    n_width = width // 8
+    n_height = height // 8
+    n_width *= 8
+    n_height *= 8
+    if resampler_type in resamplers:
+        resampler = resamplers[resampler_type]
+    else:
+        resampler = resamplers["Nearest Neighbor"]
+    new_image = image.resize((n_width, n_height), resample=resampler)
+    return new_image, n_width, n_height
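+
+# Usage sketch (the input path is hypothetical):
+#
+#   img, w, h = resize_stencil(Image.open("input.png"), 1000, 600)
+#   # w and h land in [128, 768] and are multiples of 8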
diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py
new file mode 100644
index 0000000000..c087175de4
--- /dev/null
+++ b/apps/shark_studio/modules/pipeline.py
@@ -0,0 +1,71 @@
+import gc
+
+from shark.iree_utils.compile_utils import get_iree_compiled_module
+from apps.shark_studio.api.utils import get_resource_path
+
+
+class SharkPipelineBase:
+    # This class is a lightweight base for managing an inference API class.
+    # It should provide methods for:
+    # - compiling a set (model map) of torch IR modules
+    # - preparing weights for an inference job
+    # - loading weights for an inference job
+    # - utilities like benchmarks and tests
+
+    def __init__(
+        self,
+        model_map: dict,
+        device: str,
+        import_mlir: bool = True,
+    ):
+        self.model_map = model_map
+        self.device = device
+        self.import_mlir = import_mlir
+        self.iree_module_dict = {}
+
+    def import_torch_ir(self, base_model_id):
+        for submodel, spec in self.model_map.items():
+            hf_id = spec.get("custom_hf_id") or base_model_id
+            torch_ir = spec["initializer"](
+                hf_id, **spec.get("init_kwargs", {}), compile_to="torch"
+            )
+            spec["tempfile_name"] = get_resource_path(
+                f"{submodel}.torch.tempfile"
+            )
+            with open(spec["tempfile_name"], "w+") as f:
+                f.write(torch_ir)
+            del torch_ir
+            gc.collect()
+
+    def load_vmfb(self, submodel):
+        if self.iree_module_dict.get(submodel):
+            print(
+                f".vmfb for {submodel} found at "
+                f"{self.iree_module_dict[submodel]['vmfb']}"
+            )
+        elif self.model_map[submodel].get("tempfile_name"):
+            # TODO: compile the cached torch IR for this submodel to a .vmfb.
+            pass
+        return self.model_map[submodel].get("vmfb")
+
+    def merge_custom_map(self, custom_model_map):
+        for submodel in custom_model_map:
+            for key in custom_model_map[submodel]:
+                self.model_map[submodel][key] = custom_model_map[submodel][key]
+        print(self.model_map)
+
+    def get_compiled_map(self) -> None:
+        # get_iree_compiled_module returns a dict with keys:
+        # "vmfb", "config", and "temp_file_to_unlink".
+        for submodel, spec in self.model_map.items():
+            if not self.iree_module_dict.get(submodel, {}).get("vmfb"):
+                self.iree_module_dict[submodel] = get_iree_compiled_module(
+                    spec["tempfile_name"],
+                    device=self.device,
+                    frontend="torch",
+                )
+            # TODO: delete the temp file
+
+    def run(self, submodel, inputs):
+        return
+
+    @staticmethod
+    def safe_name(name):
+        return name.replace("/", "_").replace("-", "_")
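+
+# Subclass usage sketch (mirrors how StableDiffusion in api/sd.py builds on
+# this base; illustrative only):
+#
+#   class MyPipeline(SharkPipelineBase):
+#       def __init__(self, model_map, device):
+#           super().__init__(model_map, device, import_mlir=True)
+#           self.get_compiled_map()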
diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py
new file mode 100644
index 0000000000..c62646f69c
--- /dev/null
+++ b/apps/shark_studio/modules/schedulers.py
@@ -0,0 +1,30 @@
+# from shark_turbine.turbine_models.schedulers import export_scheduler_model
+
+
+def export_scheduler_model(model):
+    return "None", "None"
+
+
+scheduler_model_map = {
+    "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
+    "EulerAncestralDiscrete": export_scheduler_model(
+        "EulerAncestralDiscreteScheduler"
+    ),
+    "LCM": export_scheduler_model("LCMScheduler"),
+    "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
+    "PNDM": export_scheduler_model("PNDMScheduler"),
+    "DDPM": export_scheduler_model("DDPMScheduler"),
+    "DDIM": export_scheduler_model("DDIMScheduler"),
+    "DPMSolverMultistep": export_scheduler_model(
+        "DPMSolverMultistepScheduler"
+    ),
+    "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
+    "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
+    "DPMSolverSinglestep": export_scheduler_model(
+        "DPMSolverSinglestepScheduler"
+    ),
+    "KDPM2AncestralDiscrete": export_scheduler_model(
+        "KDPM2AncestralDiscreteScheduler"
+    ),
+    "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
+}
diff --git a/apps/shark_studio/modules/shared.py b/apps/shark_studio/modules/shared.py
new file mode 100644
index 0000000000..d9dc3ea26e
--- /dev/null
+++ b/apps/shark_studio/modules/shared.py
@@ -0,0 +1,69 @@
+import sys
+
+import gradio as gr
+
+from modules import (
+    shared_cmd_options,
+    shared_gradio,
+    options,
+    shared_items,
+    sd_models_types,
+)
+from modules.paths_internal import (
+    models_path,
+    script_path,
+    data_path,
+    sd_configs_path,
+    sd_default_config,
+    sd_model_file,
+    default_sd_model_file,
+    extensions_dir,
+    extensions_builtin_dir,
+)  # noqa: F401
+from modules import util
+
+cmd_opts = shared_cmd_options.cmd_opts
+parser = shared_cmd_options.parser
+
+parallel_processing_allowed = True
+styles_filename = cmd_opts.styles_file
+config_filename = cmd_opts.ui_settings_file
+
+demo = None
+
+device = None
+
+weight_load_location = None
+
+state = None
+
+prompt_styles = None
+
+options_templates = None
+opts = None
+restricted_opts = None
+
+sd_model: sd_models_types.WebuiSdModel = None
+
+settings_components = None
+"""assigned from ui.py, a mapping of setting names to the gradio components responsible for those settings"""
+
+tab_names = []
+
+sd_upscalers = []
+
+clip_model = None
+
+progress_print_out = sys.stdout
+
+gradio_theme = gr.themes.Base()
+
+total_tqdm = None
+
+mem_mon = None
+
+reload_gradio_theme = shared_gradio.reload_gradio_theme
+
+list_checkpoint_tiles = shared_items.list_checkpoint_tiles
+refresh_checkpoints = shared_items.refresh_checkpoints
+list_samplers = shared_items.list_samplers
diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py
new file mode 100644
index 0000000000..dfb166a52e
--- /dev/null
+++ b/apps/shark_studio/modules/shared_cmd_opts.py
@@ -0,0 +1,763 @@
+import argparse
+import os
+from pathlib import Path
+
+from apps.shark_studio.modules.img_processing import resampler_list
+
+
+def path_expand(s):
+    return Path(s).expanduser().resolve()
+
+
+def is_valid_file(arg):
+    if not os.path.exists(arg):
+        return None
+    else:
+        return arg
+
+
+p = argparse.ArgumentParser(
+    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+
+##############################################################################
+# Stable Diffusion Params
+##############################################################################
+
+p.add_argument(
+    "-a",
+    "--app",
+    default="txt2img",
+    help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
+)
+p.add_argument(
+    "-p",
+    "--prompts",
+    nargs="+",
+    default=[
+        "a photo taken of the front of a super-car drifting on a road near "
+        "mountains at high speeds with smoke coming off the tires, front "
+        "angle, front point of view, trees in the mountains of the "
+        "background, ((sharp focus))"
+    ],
+    help="Text describing the images to generate.",
+)
+
+p.add_argument(
+    "--negative_prompts",
+    nargs="+",
+    default=[
+        "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
+        "blurry, ugly, blur, oversaturated, cropped"
+    ],
+    help="Text you don't want to see in the generated image.",
+)
+
+p.add_argument(
+    "--img_path",
+    type=str,
+    help="Path to the image input for img2img/inpainting.",
+)
+
+p.add_argument(
+    "--steps",
+    type=int,
+    default=50,
+    help="The number of steps to do the sampling.",
+)
+
+p.add_argument(
+    "--seed",
+    type=str,
+    default=-1,
+    help="The seed or list of seeds to use. -1 for a random one.",
+)
+
+p.add_argument(
+    "--batch_size",
+    type=int,
+    default=1,
+    choices=range(1, 4),
+    help="The number of inferences to be made in a single `batch_count`.",
+)
+
+p.add_argument(
+    "--height",
+    type=int,
+    default=512,
+    choices=range(128, 1025, 8),
+    help="The height of the output image.",
+)
+
+p.add_argument(
+    "--width",
+    type=int,
+    default=512,
+    choices=range(128, 1025, 8),
+    help="The width of the output image.",
+)
+
+p.add_argument(
+    "--guidance_scale",
+    type=float,
+    default=7.5,
+    help="The value to be used for guidance scaling.",
+)
+
+p.add_argument(
+    "--noise_level",
+    type=int,
+    default=20,
+    help="The value to be used for the noise level of the upscaler.",
+)
+
+p.add_argument(
+    "--max_length",
+    type=int,
+    default=64,
+    help="Max length of the tokenizer output, options are 64 and 77.",
+)
+
+p.add_argument(
+    "--max_embeddings_multiples",
+    type=int,
+    default=5,
+    help="The max multiple length of prompt embeddings compared to the max "
+    "output length of text encoder.",
+)
+
+p.add_argument(
+    "--strength",
+    type=float,
+    default=0.8,
+    help="The strength of change applied on the given input image for "
+    "img2img.",
+)
+
+p.add_argument(
+    "--use_hiresfix",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Use Hires Fix to do higher resolution images, while trying to "
+    "avoid the issues that come with it. This is accomplished by first "
+    "generating an image using txt2img, then running it through img2img.",
+)
+
+p.add_argument(
+    "--hiresfix_height",
+    type=int,
+    default=768,
+    choices=range(128, 769, 8),
+    help="The height of the Hires Fix image.",
+)
+
+p.add_argument(
+    "--hiresfix_width",
+    type=int,
+    default=768,
+    choices=range(128, 769, 8),
+    help="The width of the Hires Fix image.",
+)
+
+p.add_argument(
+    "--hiresfix_strength",
+    type=float,
+    default=0.6,
+    help="The denoising strength to apply for the Hires Fix.",
+)
+
+p.add_argument(
+    "--resample_type",
+    type=str,
+    default="Nearest Neighbor",
+    choices=resampler_list,
+    help="The resample type to use when resizing an image before being run "
+    "through stable diffusion.",
+)
+
+##############################################################################
+# Stable Diffusion Training Params
+##############################################################################
+
+p.add_argument(
+    "--lora_save_dir",
+    type=str,
+    default="models/lora/",
+    help="Directory to save the lora fine tuned model.",
+)
+
+p.add_argument(
+    "--training_images_dir",
+    type=str,
+    default="models/lora/training_images/",
+    help="Directory containing images that are an example of the prompt.",
+)
+
+p.add_argument(
+    "--training_steps",
+    type=int,
+    default=2000,
+    help="The number of steps to train.",
+)
+
+##############################################################################
+# Inpainting and Outpainting Params
+##############################################################################
+
+p.add_argument(
+    "--mask_path",
+    type=str,
+    help="Path to the mask image input for inpainting.",
+)
+
+p.add_argument(
+    "--inpaint_full_res",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="If inpaint only masked area or whole picture.",
+)
+
+p.add_argument(
+    "--inpaint_full_res_padding",
+    type=int,
+    default=32,
+    choices=range(0, 257, 4),
+    help="Number of pixels for only masked padding.",
+)
+
+p.add_argument(
+    "--pixels",
+    type=int,
+    default=128,
+    choices=range(8, 257, 8),
+    help="Number of expanded pixels in one direction for outpainting.",
+)
"--mask_blur", + type=int, + default=8, + choices=range(0, 65), + help="Number of blur pixels for outpainting.", +) + +p.add_argument( + "--left", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend left for outpainting.", +) + +p.add_argument( + "--right", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend right for outpainting.", +) + +p.add_argument( + "--up", + "--top", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend top for outpainting.", +) + +p.add_argument( + "--down", + "--bottom", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend bottom for outpainting.", +) + +p.add_argument( + "--noise_q", + type=float, + default=1.0, + help="Fall-off exponent for outpainting (lower=higher detail) " + "(min=0.0, max=4.0).", +) + +p.add_argument( + "--color_variation", + type=float, + default=0.05, + help="Color variation for outpainting (min=0.0, max=1.0).", +) + +############################################################################## +# Model Config and Usage Params +############################################################################## + +p.add_argument( + "--device", type=str, default="vulkan", help="Device to run the model." +) + +p.add_argument( + "--precision", type=str, default="fp16", help="Precision to run the model." +) + +p.add_argument( + "--import_mlir", + default=True, + action=argparse.BooleanOptionalAction, + help="Imports the model from torch module to shark_module otherwise " + "downloads the model from shark_tank.", +) + +p.add_argument( + "--use_tuned", + default=False, + action=argparse.BooleanOptionalAction, + help="Download and use the tuned version of the model if available.", +) + +p.add_argument( + "--use_base_vae", + default=False, + action=argparse.BooleanOptionalAction, + help="Do conversion from the VAE output to pixel space on cpu.", +) + +p.add_argument( + "--scheduler", + type=str, + default="SharkEulerDiscrete", + help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, " + "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, " + "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, " + "DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, " + "HeunDiscrete].", +) + +p.add_argument( + "--output_img_format", + type=str, + default="png", + help="Specify the format in which output image is save. 
" + "Supported options: jpg / png.", +) + +p.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory path to save the output images and json.", +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to be generated with random seeds in " + "single execution.", +) + +p.add_argument( + "--repeatable_seeds", + default=False, + action=argparse.BooleanOptionalAction, + help="The seed of the first batch will be used as the rng seed to " + "generate the subsequent seeds for subsequent batches in that run.", +) + +p.add_argument( + "--ckpt_loc", + type=str, + default="", + help="Path to SD's .ckpt file.", +) + +p.add_argument( + "--custom_vae", + type=str, + default="", + help="HuggingFace repo-id or path to SD model's checkpoint whose VAE " + "needs to be plugged in.", +) + +p.add_argument( + "--hf_model_id", + type=str, + default="stabilityai/stable-diffusion-2-1-base", + help="The repo-id of hugging face.", +) + +p.add_argument( + "--low_cpu_mem_usage", + default=False, + action=argparse.BooleanOptionalAction, + help="Use the accelerate package to reduce cpu memory consumption.", +) + +p.add_argument( + "--attention_slicing", + type=str, + default="none", + help="Amount of attention slicing to use (one of 'max', 'auto', 'none', " + "or an integer).", +) + +p.add_argument( + "--use_stencil", + choices=["canny", "openpose", "scribble", "zoedepth"], + help="Enable the stencil feature.", +) + +p.add_argument( + "--control_mode", + choices=["Prompt", "Balanced", "Controlnet"], + default="Balanced", + help="How Controlnet injection should be prioritized.", +) + +p.add_argument( + "--use_lora", + type=str, + default="", + help="Use standalone LoRA weight using a HF ID or a checkpoint " + "file (~3 MB).", +) + +p.add_argument( + "--use_quantize", + type=str, + default="none", + help="Runs the quantized version of stable diffusion model. " + "This is currently in experimental phase. " + "Currently, only runs the stable-diffusion-2-1-base model in " + "int8 quantization.", +) + +p.add_argument( + "--lowvram", + default=False, + action=argparse.BooleanOptionalAction, + help="Load and unload models for low VRAM.", +) + +p.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="Specify your own huggingface authentication tokens for models like Llama2.", +) + +p.add_argument( + "--device_allocator_heap_key", + type=str, + default="", + help="Specify heap key for device caching allocator." + "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count" + "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", +) + +p.add_argument( + "--custom_model_map", + type=str, + default="", + help="path to custom model map to import. This should be a .json file", +) +############################################################################## +# IREE - Vulkan supported flags +############################################################################## + +p.add_argument( + "--iree_vulkan_target_triple", + type=str, + default="", + help="Specify target triple for vulkan.", +) + +p.add_argument( + "--iree_metal_target_platform", + type=str, + default="", + help="Specify target triple for metal.", +) + +############################################################################## +# Misc. 
+
+p.add_argument(
+    "--use_compiled_scheduler",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="Use the default scheduler precompiled into the model if available.",
+)
+
+p.add_argument(
+    "--local_tank_cache",
+    default="",
+    help="Specify where to save downloaded shark_tank artifacts. "
+    "If this is not set, the default is ~/.local/shark_tank/.",
+)
+
+p.add_argument(
+    "--dump_isa",
+    default=False,
+    action="store_true",
+    help="When enabled call amdllpc to get ISA dumps. "
+    "Use with dispatch benchmarks.",
+)
+
+p.add_argument(
+    "--dispatch_benchmarks",
+    default=None,
+    help="Dispatches to return benchmark data on. "
+    'Use "All" for all, and None for none.',
+)
+
+p.add_argument(
+    "--dispatch_benchmarks_dir",
+    default="temp_dispatch_benchmarks",
+    help="Directory where you want to store dispatch data "
+    'generated with "--dispatch_benchmarks".',
+)
+
+p.add_argument(
+    "--enable_rgp",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Flag for inserting debug frames between iterations "
+    "for use with rgp.",
+)
+
+p.add_argument(
+    "--hide_steps",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="Flag for hiding the details of iteration/sec for each step.",
+)
+
+p.add_argument(
+    "--warmup_count",
+    type=int,
+    default=0,
+    help="Flag setting warmup count for CLIP and VAE [>= 0].",
+)
+
+p.add_argument(
+    "--clear_all",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Flag to clear all mlir and vmfb from common locations. "
+    "Recompiling will take several minutes.",
+)
+
+p.add_argument(
+    "--save_metadata_to_json",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Flag for whether or not to save a generation information "
+    "json file with the image.",
+)
+
+p.add_argument(
+    "--write_metadata_to_png",
+    default=True,
+    action=argparse.BooleanOptionalAction,
+    help="Flag for whether or not to save generation information in "
+    "PNG chunk text to generated images.",
+)
+
+p.add_argument(
+    "--import_debug",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="If import_mlir is True, saves mlir via the debug option "
+    "in shark importer. Does nothing if import_mlir is false (the default).",
+)
+
+p.add_argument(
+    "--compile_debug",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Flag to toggle debug assert/verify flags for imported IR in the "
+    "iree-compiler. Defaults to False.",
+)
Default to false.", +) + +p.add_argument( + "--iree_constant_folding", + default=True, + action=argparse.BooleanOptionalAction, + help="Controls constant folding in iree-compile for all SD models.", +) + +p.add_argument( + "--data_tiling", + default=False, + action=argparse.BooleanOptionalAction, + help="Controls data tiling in iree-compile for all SD models.", +) + +############################################################################## +# Web UI flags +############################################################################## + +p.add_argument( + "--webui", + default=True, + action=argparse.BooleanOptionalAction, + help="controls whether the webui is launched.", +) + +p.add_argument( + "--progress_bar", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for removing the progress bar animation during " + "image generation.", +) + +p.add_argument( + "--ckpt_dir", + type=str, + default="", + help="Path to directory where all .ckpts are stored in order to populate " + "them in the web UI.", +) +# TODO: replace API flag when these can be run together +p.add_argument( + "--ui", + type=str, + default="app" if os.name == "nt" else "web", + help="One of: [api, app, web].", +) + +p.add_argument( + "--share", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for generating a public URL.", +) + +p.add_argument( + "--server_port", + type=int, + default=8080, + help="Flag for setting server port.", +) + +p.add_argument( + "--api", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for enabling rest API.", +) + +p.add_argument( + "--api_accept_origin", + action="append", + type=str, + help="An origin to be accepted by the REST api for Cross Origin" + "Resource Sharing (CORS). Use multiple times for multiple origins, " + 'or use --api_accept_origin="*" to accept all origins. If no origins ' + "are set no CORS headers will be returned by the api. 
Use, for " + "instance, if you need to access the REST api from Javascript running " + "in a web browser.", +) + +p.add_argument( + "--debug", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for enabling debugging log in WebUI.", +) + +p.add_argument( + "--output_gallery", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for removing the output gallery tab, and avoid exposing " + "images under --output_dir in the UI.", +) + +p.add_argument( + "--output_gallery_followlinks", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for whether the output gallery tab in the UI should " + "follow symlinks when listing subdirectories under --output_dir.", +) + + +############################################################################## +# SD model auto-annotation flags +############################################################################## + +p.add_argument( + "--annotation_output", + type=path_expand, + default="./", + help="Directory to save the annotated mlir file.", +) + +p.add_argument( + "--annotation_model", + type=str, + default="unet", + help="Options are unet and vae.", +) + +p.add_argument( + "--save_annotation", + default=False, + action=argparse.BooleanOptionalAction, + help="Save annotated mlir file.", +) +############################################################################## +# SD model auto-tuner flags +############################################################################## + +p.add_argument( + "--tuned_config_dir", + type=path_expand, + default="./", + help="Directory to save the tuned config file.", +) + +p.add_argument( + "--num_iters", + type=int, + default=400, + help="Number of iterations for tuning.", +) + +p.add_argument( + "--search_op", + type=str, + default="all", + help="Op to be optimized, options are matmul, bmm, conv and all.", +) + +############################################################################## +# DocuChat Flags +############################################################################## + +p.add_argument( + "--run_docuchat_web", + default=False, + action=argparse.BooleanOptionalAction, + help="Specifies whether the docuchat's web version is running or not.", +) + +############################################################################## +# rocm Flags +############################################################################## + +p.add_argument( + "--iree_rocm_target_chip", + type=str, + default="", + help="Add the rocm device architecture ex gfx1100, gfx90a, etc. 
+cmd_opts, unknown = p.parse_known_args()
+if cmd_opts.import_debug:
+    os.environ["IREE_SAVE_TEMPS"] = os.path.join(
+        os.getcwd(), cmd_opts.hf_model_id.replace("/", "_")
+    )
diff --git a/apps/shark_studio/modules/timer.py b/apps/shark_studio/modules/timer.py
new file mode 100644
index 0000000000..8fd1e6a7df
--- /dev/null
+++ b/apps/shark_studio/modules/timer.py
@@ -0,0 +1,111 @@
+import time
+import argparse
+
+
+class TimerSubcategory:
+    def __init__(self, timer, category):
+        self.timer = timer
+        self.category = category
+        self.start = None
+        self.original_base_category = timer.base_category
+
+    def __enter__(self):
+        self.start = time.time()
+        self.timer.base_category = (
+            self.original_base_category + self.category + "/"
+        )
+        self.timer.subcategory_level += 1
+
+        if self.timer.print_log:
+            print(f"{' ' * self.timer.subcategory_level}{self.category}:")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        elapsed_for_subcategory = time.time() - self.start
+        self.timer.base_category = self.original_base_category
+        self.timer.add_time_to_record(
+            self.original_base_category + self.category,
+            elapsed_for_subcategory,
+        )
+        self.timer.subcategory_level -= 1
+        self.timer.record(self.category, disable_log=True)
+
+
+class Timer:
+    def __init__(self, print_log=False):
+        self.start = time.time()
+        self.records = {}
+        self.total = 0
+        self.base_category = ""
+        self.print_log = print_log
+        self.subcategory_level = 0
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+    def add_time_to_record(self, category, amount):
+        if category not in self.records:
+            self.records[category] = 0
+
+        self.records[category] += amount
+
+    def record(self, category, extra_time=0, disable_log=False):
+        e = self.elapsed()
+
+        self.add_time_to_record(self.base_category + category, e + extra_time)
+
+        self.total += e + extra_time
+
+        if self.print_log and not disable_log:
+            print(
+                f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s"
+            )
+
+    def subcategory(self, name):
+        self.elapsed()
+
+        subcat = TimerSubcategory(self, name)
+        return subcat
+
+    def summary(self):
+        res = f"{self.total:.1f}s"
+
+        additions = [
+            (category, time_taken)
+            for category, time_taken in self.records.items()
+            if time_taken >= 0.1 and "/" not in category
+        ]
+        if not additions:
+            return res
+
+        res += " ("
+        res += ", ".join(
+            [
+                f"{category}: {time_taken:.1f}s"
+                for category, time_taken in additions
+            ]
+        )
+        res += ")"
+
+        return res
+
+    def dump(self):
+        return {"total": self.total, "records": self.records}
+
+    def reset(self):
+        # re-initialize in place, preserving the print_log setting
+        self.__init__(self.print_log)
+
+
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument(
+    "--log-startup",
+    action="store_true",
+    help="print a detailed log of what's happening at startup",
+)
+args = parser.parse_known_args()[0]
+
+startup_timer = Timer(print_log=args.log_startup)
+
+startup_record = None
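+
+
+# A minimal usage sketch for the Timer/TimerSubcategory pair above (the
+# sleeps and category names are illustrative placeholders, not part of the
+# real startup path):
+#
+#   t = Timer(print_log=True)
+#   with t.subcategory("load models"):
+#       time.sleep(0.2)            # stands in for real work
+#       t.record("read weights")   # recorded as "load models/read weights"
+#   t.record("compile")
+#   print(t.summary())             # e.g. "0.4s (load models: 0.2s, compile: 0.2s)"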
diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py
new file mode 100644
index 0000000000..80399505c4
--- /dev/null
+++ b/apps/shark_studio/web/api/compat.py
@@ -0,0 +1,310 @@
+import base64
+import io
+import os
+import time
+import datetime
+import uvicorn
+import ipaddress
+import requests
+import gradio as gr
+from threading import Lock
+from io import BytesIO
+from fastapi import APIRouter, Depends, FastAPI, Request, Response
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.exceptions import HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.encoders import jsonable_encoder
+
+# PIL and piexif are used by the image encode/decode helpers below
+from PIL import Image, PngImagePlugin
+import piexif
+import piexif.helper
+
+from apps.shark_studio.modules.img_processing import sampler_list
+from sdapi_v1 import shark_sd_api
+from api.llm import chat_api
+
+# NOTE: several names used in this compatibility shim (`opts`, `verify_url`,
+# `shared`, `errors`, `studio`, `models`, `restart`, `studio_data`) are
+# A1111-style globals that the hosting application is expected to provide.
+
+
+def decode_base64_to_image(encoding):
+    if encoding.startswith("http://") or encoding.startswith("https://"):
+        if not opts.api_enable_requests:
+            raise HTTPException(status_code=500, detail="Requests not allowed")
+
+        if opts.api_forbid_local_requests and not verify_url(encoding):
+            raise HTTPException(
+                status_code=500, detail="Request to local resource not allowed"
+            )
+
+        headers = (
+            {"user-agent": opts.api_useragent} if opts.api_useragent else {}
+        )
+        response = requests.get(encoding, timeout=30, headers=headers)
+        try:
+            image = Image.open(BytesIO(response.content))
+            return image
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail="Invalid image url"
+            ) from e
+
+    if encoding.startswith("data:image/"):
+        encoding = encoding.split(";")[1].split(",")[1]
+    # fall through to plain base64 decoding, so both data URIs and raw
+    # base64 payloads are accepted instead of silently returning None
+    try:
+        image = Image.open(BytesIO(base64.b64decode(encoding)))
+        return image
+    except Exception as e:
+        raise HTTPException(
+            status_code=500, detail="Invalid encoded image"
+        ) from e
+
+
+def encode_pil_to_base64(image):
+    with io.BytesIO() as output_bytes:
+        if opts.samples_format.lower() == "png":
+            use_metadata = False
+            metadata = PngImagePlugin.PngInfo()
+            for key, value in image.info.items():
+                if isinstance(key, str) and isinstance(value, str):
+                    metadata.add_text(key, value)
+                    use_metadata = True
+            image.save(
+                output_bytes,
+                format="PNG",
+                pnginfo=(metadata if use_metadata else None),
+                quality=opts.jpeg_quality,
+            )
+
+        elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+            if image.mode == "RGBA":
+                image = image.convert("RGB")
+            parameters = image.info.get("parameters", None)
+            exif_bytes = piexif.dump(
+                {
+                    "Exif": {
+                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(
+                            parameters or "", encoding="unicode"
+                        )
+                    }
+                }
+            )
+            if opts.samples_format.lower() in ("jpg", "jpeg"):
+                image.save(
+                    output_bytes,
+                    format="JPEG",
+                    exif=exif_bytes,
+                    quality=opts.jpeg_quality,
+                )
+            else:
+                image.save(
+                    output_bytes,
+                    format="WEBP",
+                    exif=exif_bytes,
+                    quality=opts.jpeg_quality,
+                )
+
+        else:
+            raise HTTPException(status_code=500, detail="Invalid image format")
+
+        bytes_data = output_bytes.getvalue()
+
+    return base64.b64encode(bytes_data)
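+
+
+# Round-trip sketch for the two helpers above (assumes the hosting app has
+# installed an `opts` object with `samples_format` and `jpeg_quality`; the
+# raw-base64 path relies on the fall-through decoding above):
+#
+#   img = Image.new("RGB", (64, 64), "white")
+#   payload = encode_pil_to_base64(img)               # base64-encoded bytes
+#   restored = decode_base64_to_image(payload.decode())  # a PIL image again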
+
+
+def api_middleware(app: FastAPI):
+    rich_available = False
+    try:
+        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
+            import anyio  # importing just so it can be placed on silent list
+            import starlette  # importing just so it can be placed on silent list
+            from rich.console import Console
+
+            console = Console()
+            rich_available = True
+    except Exception:
+        pass
+
+    @app.middleware("http")
+    async def log_and_time(req: Request, call_next):
+        ts = time.time()
+        res: Response = await call_next(req)
+        duration = str(round(time.time() - ts, 4))
+        res.headers["X-Process-Time"] = duration
+        endpoint = req.scope.get("path", "err")
+        if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"):
+            print(
+                "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
+                    t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+                    code=res.status_code,
+                    ver=req.scope.get("http_version", "0.0"),
+                    cli=req.scope.get("client", ("0:0.0.0", 0))[0],
+                    prot=req.scope.get("scheme", "err"),
+                    method=req.scope.get("method", "err"),
+                    endpoint=endpoint,
+                    duration=duration,
+                )
+            )
+        return res
+
+    def handle_exception(request: Request, e: Exception):
+        err = {
+            "error": type(e).__name__,
+            "detail": vars(e).get("detail", ""),
+            "body": vars(e).get("body", ""),
+            "errors": str(e),
+        }
+        if not isinstance(
+            e, HTTPException
+        ):  # do not print backtrace on known httpexceptions
+            message = f"API error: {request.method}: {request.url} {err}"
+            if rich_available:
+                print(message)
+                console.print_exception(
+                    show_locals=True,
+                    max_frames=2,
+                    extra_lines=1,
+                    suppress=[anyio, starlette],
+                    word_wrap=False,
+                    width=min([console.width, 200]),
+                )
+            else:
+                errors.report(message, exc_info=True)
+        return JSONResponse(
+            status_code=vars(e).get("status_code", 500),
+            content=jsonable_encoder(err),
+        )
+
+    @app.middleware("http")
+    async def exception_handling(request: Request, call_next):
+        try:
+            return await call_next(request)
+        except Exception as e:
+            return handle_exception(request, e)
+
+    @app.exception_handler(Exception)
+    async def fastapi_exception_handler(request: Request, e: Exception):
+        return handle_exception(request, e)
+
+    @app.exception_handler(HTTPException)
+    async def http_exception_handler(request: Request, e: HTTPException):
+        return handle_exception(request, e)
+
+
+class ApiCompat:
+    def __init__(self, app: FastAPI, queue_lock: Lock):
+        # use the FastAPI app handed to us by the launcher (index.py creates
+        # it via create_api(app)) rather than building a second, unused one
+        self.router = APIRouter()
+        self.app = app
+        self.queue_lock = queue_lock
+        api_middleware(self.app)
+        self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"])
+        self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"])
+        # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
+        # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+        # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+        # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+        # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
+        # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
+        # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+        # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+        # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        # 
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) + # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) + # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) + # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) + # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) + # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) + # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + + # chat APIs needed for compatibility with multiple extensions using OpenAI API + self.add_api_route("/v1/chat/completions", chat_api, methods=["post"]) + self.add_api_route("/v1/completions", chat_api, methods=["post"]) + self.add_api_route("/chat/completions", chat_api, methods=["post"]) + self.add_api_route("/completions", chat_api, methods=["post"]) + self.add_api_route( + "/v1/engines/codegen/completions", chat_api, methods=["post"] + ) + if studio.cmd_opts.api_server_stop: + self.add_api_route( + "/sdapi/v1/server-kill", self.kill_studio, methods=["POST"] + ) + self.add_api_route( + "/sdapi/v1/server-restart", + self.restart_studio, + methods=["POST"], + ) + self.add_api_route( + "/sdapi/v1/server-stop", self.stop_studio, methods=["POST"] + ) + + self.default_script_arg_txt2img = [] + self.default_script_arg_img2img = [] + + def add_api_route(self, path: str, endpoint, **kwargs): + if studio.cmd_opts.api_auth: + return self.app.add_api_route( + path, endpoint, dependencies=[Depends(self.auth)], **kwargs + ) + return self.app.add_api_route(path, endpoint, **kwargs) + + def refresh_checkpoints(self): + with self.queue_lock: + studio_data.refresh_checkpoints() + + def refresh_vae(self): + with self.queue_lock: + studio_data.refresh_vae_list() + + def unloadapi(self): + unload_model_weights() + + return {} + + def reloadapi(self): + reload_model_weights() + + return {} + + def skip(self): + studio.state.skip() + + def launch(self, server_name, port, root_path): + self.app.include_router(self.router) + uvicorn.run( + self.app, + host=server_name, + port=port, + 
+            timeout_keep_alive=studio.cmd_opts.timeout_keep_alive,
+            root_path=root_path,
+        )
+
+    def kill_studio(self):
+        restart.stop_program()
+
+    def restart_studio(self):
+        if restart.is_restartable():
+            restart.restart_program()
+        return Response(status_code=501)
+
+    def preprocess(self, args: dict):
+        try:
+            studio.state.begin(job="preprocess")
+            preprocess(**args)
+            studio.state.end()
+            return models.PreprocessResponse(info="preprocess complete")
+        except Exception as e:
+            # end the job and surface the failure instead of swallowing it
+            # with a bare `except:` and implicitly returning None
+            studio.state.end()
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
+
+    def stop_studio(self, request: Request):
+        studio.state.server_command = "stop"
+        return Response("Stopping.")
diff --git a/apps/shark_studio/web/api/sd.py b/apps/shark_studio/web/api/sd.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/apps/shark_studio/web/api/sd.py
@@ -0,0 +1 @@
+
diff --git a/apps/shark_studio/web/configs/foo.json b/apps/shark_studio/web/configs/foo.json
new file mode 100644
index 0000000000..0967ef424b
--- /dev/null
+++ b/apps/shark_studio/web/configs/foo.json
@@ -0,0 +1 @@
+{}
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
index 3ef6bc5739..58b0c6c00b 100644
--- a/apps/shark_studio/web/index.py
+++ b/apps/shark_studio/web/index.py
@@ -1,20 +1,58 @@
 from multiprocessing import Process, freeze_support
 import os
+import time
 import sys
 import logging
+import apps.shark_studio.api.initializers as initialize
+
 from ui.chat import chat_element
+from ui.sd import sd_element
+from ui.outputgallery import outputgallery_element
+
+from apps.shark_studio.modules import timer
+
+startup_timer = timer.startup_timer
+startup_timer.record("launcher")
+
+initialize.imports()
 
 if sys.platform == "darwin":
     os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
     # import before IREE to avoid MLIR library issues
 import torch_mlir
 
-# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
-# from apps.stable_diffusion.src import args, clear_all
-# import apps.stable_diffusion.web.utils.global_obj as global_obj
 
+def create_api(app):
+    # the compat shim lives at apps/shark_studio/web/api/compat.py, so it is
+    # imported from there
+    from apps.shark_studio.web.api.compat import ApiCompat
+    from modules.call_queue import queue_lock
+
+    api = ApiCompat(app, queue_lock)
+    return api
+
+
+def api_only():
+    from fastapi import FastAPI
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+
+    initialize.initialize()
+
+    app = FastAPI()
+    initialize.setup_middleware(app)
+    api = create_api(app)
+
+    # from modules import script_callbacks
+    # script_callbacks.before_ui_callback()
+    # script_callbacks.app_started_callback(None, app)
 
-def launch_app(address):
+    print(f"Startup time: {startup_timer.summary()}.")
+    api.launch(
+        server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+        port=cmd_opts.port if cmd_opts.port else 8080,
+        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
+    )
+
+
+def launch_webui(address):
     from tkinter import Tk
     import webview
@@ -34,62 +72,81 @@ def launch_app(address):
     webview.start(private_mode=False, storage_path=os.getcwd())
 
-if __name__ == "__main__":
-    # if args.debug:
+def webui():
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+
     logging.basicConfig(level=logging.DEBUG)
+
+    launch_api = cmd_opts.api
+    initialize.initialize()
+
     # required to do multiprocessing in a pyinstaller freeze
     freeze_support()
-    # if args.api or "api" in args.ui.split(","):
-    #     from apps.stable_diffusion.web.ui import (
-    #         txt2img_api,
-    #         img2img_api,
-    #         upscaler_api,
-    #         inpaint_api,
-    #         outpaint_api,
-    #         llm_chat_api,
-    #     )
+
+    # if args.api or "api" in args.ui.split(","):
+    #     from apps.shark_studio.api.llm 
import ( + # chat, + # ) + # from apps.shark_studio.web.api import sdapi + # + # from fastapi import FastAPI, APIRouter + # from fastapi.middleware.cors import CORSMiddleware + # import uvicorn # - # from fastapi import FastAPI, APIRouter - # import uvicorn + # # init global sd pipeline and config + # global_obj._init() # - # # init global sd pipeline and config - # global_obj._init() + # api = FastAPI() + # api.mount("/sdapi/", sdapi) # - # app = FastAPI() - # app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"]) + # # chat APIs needed for compatibility with multiple extensions using OpenAI API + # api.add_api_route( + # "/v1/chat/completions", llm_chat_api, methods=["post"] + # ) + # api.add_api_route("/v1/completions", llm_chat_api, methods=["post"]) + # api.add_api_route("/chat/completions", llm_chat_api, methods=["post"]) + # api.add_api_route("/completions", llm_chat_api, methods=["post"]) + # api.add_api_route( + # "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] + # ) + # api.include_router(APIRouter()) # - # # chat APIs needed for compatibility with multiple extensions using OpenAI API - # app.add_api_route( - # "/v1/chat/completions", llm_chat_api, methods=["post"] - # ) - # app.add_api_route("/v1/completions", llm_chat_api, methods=["post"]) - # app.add_api_route("/chat/completions", llm_chat_api, methods=["post"]) - # app.add_api_route("/completions", llm_chat_api, methods=["post"]) - # app.add_api_route( - # "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] - # ) - # app.include_router(APIRouter()) - # uvicorn.run(app, host="0.0.0.0", port=args.server_port) - # sys.exit(0) + # # deal with CORS requests if CORS accept origins are set + # if args.api_accept_origin: + # print( + # f"API Configured for CORS. Accepting origins: { args.api_accept_origin }" + # ) + # api.add_middleware( + # CORSMiddleware, + # allow_origins=args.api_accept_origin, + # allow_methods=["GET", "POST"], + # allow_headers=["*"], + # ) + # else: + # print("API not configured for CORS") # + # uvicorn.run(api, host="0.0.0.0", port=args.server_port) + # sys.exit(0) # Setup to use shark_tmp for gradio's temporary image files and clear any # existing temporary images there if they exist. Then we can import gradio. # It has to be in this order or gradio ignores what we've set up. 
- # from apps.stable_diffusion.web.utils.gradio_configs import ( - # config_gradio_tmp_imgs_folder, - # ) + from apps.shark_studio.web.utils.tmp_configs import ( + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, + ) + from apps.shark_studio.api.utils import ( + create_checkpoint_folders, + ) - # config_gradio_tmp_imgs_folder() import gradio as gr - # Create custom models folders if they don't exist - # from apps.stable_diffusion.web.ui.utils import create_custom_models_folders + config_tmp() + clear_tmp_mlir() + clear_tmp_imgs() - # create_custom_models_folders() + # Create custom models folders if they don't exist + create_checkpoint_folders() def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" @@ -98,74 +155,7 @@ def resource_path(relative_path): dark_theme = resource_path("ui/css/sd_dark_theme.css") - # from apps.stable_diffusion.web.ui import ( - # txt2img_web, - # txt2img_custom_model, - # txt2img_gallery, - # txt2img_png_info_img, - # txt2img_status, - # txt2img_sendto_img2img, - # txt2img_sendto_inpaint, - # txt2img_sendto_outpaint, - # txt2img_sendto_upscaler, - ## h2ogpt_upload, - ## h2ogpt_web, - # img2img_web, - # img2img_custom_model, - # img2img_gallery, - # img2img_init_image, - # img2img_status, - # img2img_sendto_inpaint, - # img2img_sendto_outpaint, - # img2img_sendto_upscaler, - # inpaint_web, - # inpaint_custom_model, - # inpaint_gallery, - # inpaint_init_image, - # inpaint_status, - # inpaint_sendto_img2img, - # inpaint_sendto_outpaint, - # inpaint_sendto_upscaler, - # outpaint_web, - # outpaint_custom_model, - # outpaint_gallery, - # outpaint_init_image, - # outpaint_status, - # outpaint_sendto_img2img, - # outpaint_sendto_inpaint, - # outpaint_sendto_upscaler, - # upscaler_web, - # upscaler_custom_model, - # upscaler_gallery, - # upscaler_init_image, - # upscaler_status, - # upscaler_sendto_img2img, - # upscaler_sendto_inpaint, - # upscaler_sendto_outpaint, - ## lora_train_web, - ## model_web, - ## model_config_web, - # hf_models, - # modelmanager_sendto_txt2img, - # modelmanager_sendto_img2img, - # modelmanager_sendto_inpaint, - # modelmanager_sendto_outpaint, - # modelmanager_sendto_upscaler, - # stablelm_chat, - # minigpt4_web, - # outputgallery_web, - # outputgallery_tab_select, - # outputgallery_watch, - # outputgallery_filename, - # outputgallery_sendto_txt2img, - # outputgallery_sendto_img2img, - # outputgallery_sendto_inpaint, - # outputgallery_sendto_outpaint, - # outputgallery_sendto_upscaler, - # ) - - # init global sd pipeline and config - # global_obj._init() + # from apps.shark_studio.web.ui import load_ui_from_script def register_button_click(button, selectedid, inputs, outputs): button.click( @@ -177,17 +167,6 @@ def register_button_click(button, selectedid, inputs, outputs): outputs, ) - def register_modelmanager_button(button, selectedid, inputs, outputs): - button.click( - lambda x: ( - "None", - x, - gr.Tabs.update(selected=selectedid), - ), - inputs, - outputs, - ) - def register_outputgallery_button(button, selectedid, inputs, outputs): button.click( lambda x: ( @@ -200,7 +179,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): with gr.Blocks( css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta" - ) as sd_web: + ) as studio_web: with gr.Tabs() as tabs: # NOTE: If adding, removing, or re-ordering tabs, make sure that they # have a unique id that doesn't clash with any of the other tabs, @@ -211,216 +190,31 @@ def register_outputgallery_button(button, 
selectedid, inputs, outputs): # destination of one of the 'send to' buttons. If you do have to change # that id, make sure you update the relevant register_button_click calls # further down with the new id. - # with gr.TabItem(label="Text-to-Image", id=0): - # txt2img_web.render() - # with gr.TabItem(label="Image-to-Image", id=1): - # img2img_web.render() - # with gr.TabItem(label="Inpainting", id=2): - # inpaint_web.render() - # with gr.TabItem(label="Outpainting", id=3): - # outpaint_web.render() - # with gr.TabItem(label="Upscaler", id=4): - # upscaler_web.render() - # if args.output_gallery: - # with gr.TabItem(label="Output Gallery", id=5) as og_tab: - # outputgallery_web.render() - - # # extra output gallery configuration - # outputgallery_tab_select(og_tab.select) - # outputgallery_watch( - # [ - # txt2img_status, - # img2img_status, - # inpaint_status, - # outpaint_status, - # upscaler_status, - # ] - # ) - ## with gr.TabItem(label="Model Manager", id=6): - ## model_web.render() - ## with gr.TabItem(label="LoRA Training (Experimental)", id=7): - ## lora_train_web.render() - with gr.TabItem(label="Chat Bot", id=0): + with gr.TabItem(label="Stable Diffusion", id=0): + sd_element.render() + with gr.TabItem(label="Output Gallery", id=1): + outputgallery_element.render() + with gr.TabItem(label="Chat Bot", id=2): chat_element.render() - ## with gr.TabItem( - ## label="Generate Sharding Config (Experimental)", id=9 - ## ): - ## model_config_web.render() - # with gr.TabItem(label="MultiModal (Experimental)", id=10): - # minigpt4_web.render() - # with gr.TabItem(label="DocuChat Upload", id=11): - # h2ogpt_upload.render() - # with gr.TabItem(label="DocuChat(Experimental)", id=12): - # h2ogpt_web.render() - - # send to buttons - # register_button_click( - # txt2img_sendto_img2img, - # 1, - # [txt2img_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_inpaint, - # 2, - # [txt2img_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_outpaint, - # 3, - # [txt2img_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_upscaler, - # 4, - # [txt2img_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_inpaint, - # 2, - # [img2img_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_outpaint, - # 3, - # [img2img_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_upscaler, - # 4, - # [img2img_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_img2img, - # 1, - # [inpaint_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_outpaint, - # 3, - # [inpaint_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_upscaler, - # 4, - # [inpaint_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_img2img, - # 1, - # [outpaint_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_inpaint, - # 2, - # [outpaint_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_upscaler, - # 4, - # [outpaint_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # upscaler_sendto_img2img, - # 1, - # [upscaler_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # upscaler_sendto_inpaint, 
- # 2, - # [upscaler_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # upscaler_sendto_outpaint, - # 3, - # [upscaler_gallery], - # [outpaint_init_image, tabs], - # ) - # if args.output_gallery: - # register_outputgallery_button( - # outputgallery_sendto_txt2img, - # 0, - # [outputgallery_filename], - # [txt2img_png_info_img, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_img2img, - # 1, - # [outputgallery_filename], - # [img2img_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_inpaint, - # 2, - # [outputgallery_filename], - # [inpaint_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_outpaint, - # 3, - # [outputgallery_filename], - # [outpaint_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_upscaler, - # 4, - # [outputgallery_filename], - # [upscaler_init_image, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_txt2img, - # 0, - # [hf_models], - # [txt2img_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_img2img, - # 1, - # [hf_models], - # [img2img_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_inpaint, - # 2, - # [hf_models], - # [inpaint_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_outpaint, - # 3, - # [hf_models], - # [outpaint_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_upscaler, - # 4, - # [hf_models], - # [upscaler_custom_model, tabs], - # ) - - sd_web.queue() + + studio_web.queue() # if args.ui == "app": # t = Process( # target=launch_app, args=[f"http://localhost:{args.server_port}"] # ) # t.start() - sd_web.launch( + studio_web.launch( share=True, inbrowser=True, server_name="0.0.0.0", server_port=11911, # args.server_port, ) + + +if __name__ == "__main__": + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + + if cmd_opts.webui == False: + api_only() + else: + webui() diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 4726eef6e8..3a374eb5e2 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -240,9 +240,11 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File(label="Upload sharding configuration", visible=False) - json_view_button = gr.Button(label="View as JSON", visible=False) - json_view = gr.JSON(interactive=True, visible=False) + config_file = gr.File( + label="Upload sharding configuration", visible=False + ) + json_view_button = gr.Button("View as JSON", visible=False) + json_view = gr.JSON(visible=False) json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py new file mode 100644 index 0000000000..37555ed7ee --- /dev/null +++ b/apps/shark_studio/web/ui/common_events.py @@ -0,0 +1,55 @@ +from apps.shark_studio.web.ui.utils import ( + HSLHue, + hsl_color, +) +from apps.shark_studio.modules.embeddings import get_lora_metadata + + +# Answers HTML to show the most frequent tags used when a LoRA was trained, +# taken from the metadata of its .safetensors file. 
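+# Roughly speaking: a tag whose frequency is above the 0.65 display
+# threshold below is rendered as a coloured pill, and the pill colour is
+# interpolated from red toward green as the frequency approaches 1.0
+# (frequencies at the 0.55 colour threshold sit at the red end).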
+def lora_changed(lora_file):
+    # tag frequency percentage that gets the maximum amount of the starting hue
+    TAG_COLOR_THRESHOLD = 0.55
+    # tag frequency percentage above which a tag is displayed
+    TAG_DISPLAY_THRESHOLD = 0.65
+    # template for the html used to display a tag
+    # (markup reconstructed; the lora-tag / lora-model CSS classes are
+    # defined in sd_dark_theme.css below)
+    TAG_HTML_TEMPLATE = '<span class="lora-tag" style="background-color: {color}">{tag}</span>'
+
+    if lora_file == "None":
+        return ["<div><i>No LoRA selected</i></div>"]
+    elif not lora_file.lower().endswith(".safetensors"):
+        return [
+            "<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
+        ]
+    else:
+        metadata = get_lora_metadata(lora_file)
+        if metadata:
+            frequencies = metadata["frequencies"]
+            return [
+                "".join(
+                    [
+                        f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
+                    ]
+                    + [
+                        TAG_HTML_TEMPLATE.format(
+                            color=hsl_color(
+                                (tag[1] - TAG_COLOR_THRESHOLD)
+                                / (1 - TAG_COLOR_THRESHOLD),
+                                start=HSLHue.RED,
+                                end=HSLHue.GREEN,
+                            ),
+                            tag=tag[0],
+                        )
+                        for tag in frequencies
+                        if tag[1] > TAG_DISPLAY_THRESHOLD
+                    ],
+                )
+            ]
+        elif metadata is None:
+            return [
+                "<div><i>This LoRA does not publish tag frequency metadata</i></div>"
+            ]
+        else:
+            return [
+                "<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>
" + ] diff --git a/apps/shark_studio/web/ui/css/sd_dark_theme.css b/apps/shark_studio/web/ui/css/sd_dark_theme.css new file mode 100644 index 0000000000..5686f0868c --- /dev/null +++ b/apps/shark_studio/web/ui/css/sd_dark_theme.css @@ -0,0 +1,324 @@ +/* +Apply Gradio dark theme to the default Gradio theme. +Procedure to upgrade the dark theme: +- Using your browser, visit http://localhost:8080/?__theme=dark +- Open your browser inspector, search for the .dark css class +- Copy .dark class declarations, apply them here into :root +*/ + +:root { + --body-background-fill: var(--background-fill-primary); + --body-text-color: var(--neutral-100); + --color-accent-soft: var(--neutral-700); + --background-fill-primary: var(--neutral-950); + --background-fill-secondary: var(--neutral-900); + --border-color-accent: var(--neutral-600); + --border-color-primary: var(--neutral-700); + --link-text-color-active: var(--secondary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --body-text-color-subdued: var(--neutral-400); + --shadow-spread: 1px; + --block-background-fill: var(--neutral-800); + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --block-title-text-color: var(--neutral-200); + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--neutral-800); + --checkbox-background-color-focus: var(--checkbox-background-color); + --checkbox-background-color-hover: var(--checkbox-background-color); + --checkbox-background-color-selected: var(--secondary-600); + --checkbox-border-color: var(--neutral-700); + --checkbox-border-color-focus: var(--secondary-500); + --checkbox-border-color-hover: var(--neutral-600); + --checkbox-border-color-selected: var(--secondary-600); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error_border_width: None; + --error-text-color: #ef4444; + --input-background-fill: var(--neutral-800); + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--border-color-primary); + --input-border-color-focus: var(--neutral-700); + --input-border-color-hover: var(--input-border-color); + --input_border_width: None; + 
--input-placeholder-color: var(--neutral-500); + --input_shadow: None; + --input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset); + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--neutral-950); + --table-odd-background-fill: var(--neutral-900); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600)); + --button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500)); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700)); + --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600)); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color: white; + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --block-border-width: 1px; + --block-label-border-width: 1px; + --form-gap-width: 1px; + --error-border-width: 1px; + --input-border-width: 1px; +} + +/* SHARK theme */ +body { + background-color: var(--background-fill-primary); +} + +.generating.svelte-zlszon.svelte-zlszon { + border: none; +} + +.generating { + border: none !important; +} + +#chatbot { + height: 100% !important; +} + +/* display in full width for desktop devices */ +@media (min-width: 1536px) +{ + .gradio-container { + max-width: var(--size-full) !important; + } +} + +.gradio-container .contain { + padding: 0 var(--size-4) !important; +} + +#top_logo { + color: transparent; + background-color: transparent; + border-radius: 0 !important; + border: 0; +} + +#ui_title { + padding: var(--size-2) 0 0 var(--size-1); +} + +#demo_title_outer { + border-radius: 0; +} + +#prompt_box_outer div:first-child { + border-radius: 0 !important +} + +#prompt_box textarea, #negative_prompt_box textarea { + background-color: var(--background-fill-primary) !important; +} + +#prompt_examples { + margin: 0 !important; +} + +#prompt_examples svg { + display: none !important; +} + +#ui_body { + padding: var(--size-2) !important; + border-radius: 0.5em !important; +} + +#img_result+div { + display: none !important; +} + +footer { + display: none !important; +} + +#gallery + div { + border-radius: 0 !important; +} + +/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */ +#gallery .thumbnail-item.thumbnail-lg { + aspect-ratio: unset; + max-height: calc(55vh - (2 * var(--spacing-lg))); +} +@media (min-width: 1921px) { + /* Force a 768px_height + 
4px_margin_height + navbar_height for the gallery */ + #gallery .grid-wrap, #gallery .preview{ + min-height: calc(768px + 4px + var(--size-14)); + max-height: calc(768px + 4px + var(--size-14)); + } + /* Limit height to 768px_height + 2px_margin_height for the thumbnails */ + #gallery .thumbnail-item.thumbnail-lg { + max-height: 770px !important; + } +} +/* Don't upscale when viewing in solo image mode */ +#gallery .preview img { + object-fit: scale-down; +} +/* Navbar images in cover mode*/ +#gallery .preview .thumbnail-item img { + object-fit: cover; +} + +/* Limit the stable diffusion text output height */ +#std_output textarea { + max-height: 215px; +} + +/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */ +#gallery .wrap.default { + pointer-events: none; +} + +/* Import Png info box */ +#txt2img_prompt_image { + height: var(--size-32) !important; +} + +/* Hide "remove buttons" from ui dropdowns */ +#custom_model .token-remove.remove-all, +#lora_weights .token-remove.remove-all, +#scheduler .token-remove.remove-all, +#device .token-remove.remove-all, +#stencil_model .token-remove.remove-all { + display: none; +} + +/* Hide selected items from ui dropdowns */ +#custom_model .options .item .inner-item, +#scheduler .options .item .inner-item, +#device .options .item .inner-item, +#stencil_model .options .item .inner-item { + display:none; +} + +/* workarounds for container=false not currently working for dropdowns */ +.dropdown_no_container { + padding: 0 !important; +} + +#output_subdir_container :first-child { + border: none; +} + +/* reduced animation load when generating */ +.generating { + animation-play-state: paused !important; +} + +/* better clarity when progress bars are minimal */ +.meta-text { + background-color: var(--block-label-background-fill); +} + +/* lora tag pills */ +.lora-tags { + border: 1px solid var(--border-color-primary); + color: var(--block-info-text-color) !important; + padding: var(--block-padding); +} + +.lora-tag { + display: inline-block; + height: 2em; + color: rgb(212 212 212) !important; + margin-right: 5pt; + margin-bottom: 5pt; + padding: 2pt 5pt; + border-radius: 5pt; + white-space: nowrap; +} + +.lora-model { + margin-bottom: var(--spacing-lg); + color: var(--block-info-text-color) !important; + line-height: var(--line-sm); +} + +/* output gallery tab */ +.output_parameters_dataframe table.table { + /* works around a gradio bug that always shows scrollbars */ + overflow: clip auto; +} + +.output_parameters_dataframe tbody td { + font-size: small; + line-height: var(--line-xs); +} + +.output_icon_button { + max-width: 30px; + align-self: end; + padding-bottom: 8px; +} + +.outputgallery_sendto { + min-width: 7em !important; +} + +/* output gallery should take up most of the viewport height regardless of image size/number */ +#outputgallery_gallery .fixed-height { + min-height: 89vh !important; +} + +/* don't stretch non-square images to be square, breaking their aspect ratio */ +#outputgallery_gallery .thumbnail-item.thumbnail-lg > img { + object-fit: contain !important; +} + +/* centered logo for when there are no images */ +#top_logo.logo_centered { + height: 100%; + width: 100%; +} + +#top_logo.logo_centered img{ + object-fit: scale-down; + position: absolute; + width: 80%; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); +} diff --git a/apps/shark_studio/web/ui/logos/nod-icon.png b/apps/shark_studio/web/ui/logos/nod-icon.png new file mode 100644 index 0000000000..29f7e32220 Binary files 
/dev/null and b/apps/shark_studio/web/ui/logos/nod-icon.png differ
diff --git a/apps/shark_studio/web/ui/logos/nod-logo.png b/apps/shark_studio/web/ui/logos/nod-logo.png
new file mode 100644
index 0000000000..4727e15a19
Binary files /dev/null and b/apps/shark_studio/web/ui/logos/nod-logo.png differ
diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py
new file mode 100644
index 0000000000..dd58541aae
--- /dev/null
+++ b/apps/shark_studio/web/ui/outputgallery.py
@@ -0,0 +1,416 @@
+import glob
+import gradio as gr
+import os
+import subprocess
+import sys
+from PIL import Image
+
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.shark_studio.api.utils import (
+    get_generated_imgs_path,
+    get_generated_imgs_todays_subdir,
+)
+from apps.shark_studio.web.ui.utils import nodlogo_loc
+from apps.shark_studio.web.utils.metadata import displayable_metadata
+
+# -- Functions for file, directory and image info querying
+
+output_dir = get_generated_imgs_path()
+
+
+def outputgallery_filenames(subdir) -> list[str]:
+    new_dir_path = os.path.join(output_dir, subdir)
+    if os.path.exists(new_dir_path):
+        filenames = [
+            glob.glob(new_dir_path + "/" + ext)
+            for ext in ("*.png", "*.jpg", "*.jpeg")
+        ]
+
+        return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
+    else:
+        return []
+
+
+def output_subdirs() -> list[str]:
+    # Gets a list of subdirectories of output_dir and below, as relative paths.
+    relative_paths = [
+        os.path.relpath(entry[0], output_dir)
+        for entry in os.walk(
+            output_dir, followlinks=cmd_opts.output_gallery_followlinks
+        )
+    ]
+
+    # It is less confusing to always include the subdir that will take any
+    # images generated today, even if it doesn't exist yet
+    if get_generated_imgs_todays_subdir() not in relative_paths:
+        relative_paths.append(get_generated_imgs_todays_subdir())
+
+    # Sort subdirectories so that the date-named ones we probably
+    # created in this or previous sessions come first, sorted with the most
+    # recent first. Other subdirs are listed after.
+    generated_paths = sorted(
+        [path for path in relative_paths if path.isnumeric()], reverse=True
+    )
+    result_paths = generated_paths + sorted(
+        [
+            path
+            for path in relative_paths
+            if (not path.isnumeric()) and path != "."
+ ] + ) + + return result_paths + + +# --- Define UI layout for Gradio + +with gr.Blocks() as outputgallery_element: + nod_logo = Image.open(nodlogo_loc) + + with gr.Row(elem_id="outputgallery_gallery"): + # needed to workaround gradio issue: + # https://github.com/gradio-app/gradio/issues/2907 + dev_null = gr.Textbox("", visible=False) + + gallery_files = gr.State(value=[]) + subdirectory_paths = gr.State(value=[]) + + with gr.Column(scale=6): + logo = gr.Image( + label="Getting subdirectories...", + value=nod_logo, + interactive=False, + visible=True, + show_label=True, + elem_id="top_logo", + elem_classes="logo_centered", + show_download_button=False, + ) + + gallery = gr.Gallery( + label="", + value=gallery_files.value, + visible=False, + show_label=True, + columns=4, + ) + + with gr.Column(scale=4): + with gr.Group(): + with gr.Row(): + with gr.Column( + scale=15, + min_width=160, + elem_id="output_subdir_container", + ): + subdirectories = gr.Dropdown( + label=f"Subdirectories of {output_dir}", + type="value", + choices=subdirectory_paths.value, + value="", + interactive=True, + elem_classes="dropdown_no_container", + allow_custom_value=True, + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + open_subdir = gr.Button( + variant="secondary", + value="\U0001F5C1", # unicode open folder + interactive=False, + size="sm", + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + refresh = gr.Button( + variant="secondary", + value="\u21BB", # unicode clockwise arrow circle + size="sm", + ) + + image_columns = gr.Slider( + label="Columns shown", value=4, minimum=1, maximum=16, step=1 + ) + outputgallery_filename = gr.Textbox( + label="Filename", + value="None", + interactive=False, + show_copy_button=True, + ) + + with gr.Accordion( + label="Parameter Information", open=False + ) as parameters_accordian: + image_parameters = gr.DataFrame( + headers=["Parameter", "Value"], + col_count=2, + wrap=True, + elem_classes="output_parameters_dataframe", + value=[["Status", "No image selected"]], + interactive=True, + ) + + with gr.Accordion(label="Send To", open=True): + with gr.Row(): + outputgallery_sendto_sd = gr.Button( + value="Stable Diffusion", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) + + # --- Event handlers + + def on_clear_gallery(): + return [ + gr.Gallery( + value=[], + visible=False, + ), + gr.Image( + visible=True, + ), + ] + + def on_image_columns_change(columns): + return gr.Gallery(columns=columns) + + def on_select_subdir(subdir) -> list: + # evt.value is the subdirectory name + new_images = outputgallery_filenames(subdir) + new_label = ( + f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" + ) + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_open_subdir(subdir): + subdir_path = os.path.normpath(os.path.join(output_dir, subdir)) + + if os.path.isdir(subdir_path): + if sys.platform == "linux": + subprocess.run(["xdg-open", subdir_path]) + elif sys.platform == "darwin": + subprocess.run(["open", subdir_path]) + elif sys.platform == "win32": + os.startfile(subdir_path) + + def on_refresh(current_subdir: str) -> list: + # get an up-to-date subdirectory list + refreshed_subdirs = output_subdirs() + # get the images using either the current subdirectory or the most + # recent valid one + new_subdir = ( + 
current_subdir
+            if current_subdir in refreshed_subdirs
+            else refreshed_subdirs[0]
+        )
+        new_images = outputgallery_filenames(new_subdir)
+        new_label = (
+            f"{len(new_images)} images in "
+            f"{os.path.join(output_dir, new_subdir)}"
+        )
+
+        return [
+            gr.Dropdown(
+                choices=refreshed_subdirs,
+                value=new_subdir,
+            ),
+            refreshed_subdirs,
+            new_images,
+            gr.Gallery(
+                value=new_images, label=new_label, visible=len(new_images) > 0
+            ),
+            gr.Image(
+                label=new_label,
+                visible=len(new_images) == 0,
+            ),
+        ]
+
+    def on_new_image(subdir, subdir_paths, status) -> list:
+        # prevent an error triggered when an image generates before the tab
+        # has even been selected
+        subdir_paths = (
+            subdir_paths
+            if len(subdir_paths) > 0
+            else [get_generated_imgs_todays_subdir()]
+        )
+
+        # only update if the current subdir is the most recent one, as
+        # new images only go there
+        if subdir_paths[0] == subdir:
+            new_images = outputgallery_filenames(subdir)
+            new_label = (
+                f"{len(new_images)} images in "
+                f"{os.path.join(output_dir, subdir)} - {status}"
+            )
+
+            return [
+                new_images,
+                gr.Gallery(
+                    value=new_images,
+                    label=new_label,
+                    visible=len(new_images) > 0,
+                ),
+                gr.Image(
+                    label=new_label,
+                    visible=len(new_images) == 0,
+                ),
+            ]
+        else:
+            # otherwise change nothing,
+            # (only untyped gradio gr.update() does this)
+            return [gr.update(), gr.update(), gr.update()]
+
+    def on_select_image(images: list[str], evt: gr.SelectData) -> list:
+        # evt.index is an index into the full list of filenames for
+        # the current subdirectory
+        filename = images[evt.index]
+        params = displayable_metadata(filename)
+
+        if params:
+            if params["source"] == "missing":
+                return [
+                    "Could not find this image file, refresh the gallery and update the images",
+                    [["Status", "File missing"]],
+                ]
+            else:
+                return [
+                    filename,
+                    list(map(list, params["parameters"].items())),
+                ]
+
+        return [
+            filename,
+            [["Status", "No parameters found"]],
+        ]
+
+    def on_outputgallery_filename_change(filename: str) -> list:
+        exists = filename != "None" and os.path.exists(filename)
+        return [
+            # disable or enable each of the sendto buttons based on whether
+            # an image is selected
+            gr.Button(interactive=exists),
+        ]
+
+    # The first time our tab is selected we need to do an initial refresh
+    # to populate the subdirectory select box and the images from the most
+    # recent subdirectory.
+    #
+    # We do it at this point rather than setting this up in the controls'
+    # definitions as when you refresh the browser you always get what was
+    # *initially* set, which won't include any new subdirectories or images
+    # that might have been created since the application was started. Doing
+    # it this way means a browser refresh/reload always gets the most
+    # up-to-date data.
+    def on_select_tab(subdir_paths, request: gr.Request):
+        local_client = request.headers["host"].startswith(
+            "127.0.0.1:"
+        ) or request.headers["host"].startswith("localhost:")
+
+        if len(subdir_paths) == 0:
+            return on_refresh("") + [gr.update(interactive=local_client)]
+        else:
+            return (
+                # Change nothing, (only untyped gr.update() does this)
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+
+    # Clearing images when we need to completely change what's in the
+    # gallery avoids the current images being shown while they are replaced
+    # piecemeal, and prevents weirdness and errors if the user selects an
+    # image during the replacement phase.
+ clear_gallery = dict( + fn=on_clear_gallery, + inputs=None, + outputs=[gallery, logo], + queue=False, + ) + + subdirectories.select(**clear_gallery).then( + on_select_subdir, + [subdirectories], + [gallery_files, gallery, logo], + queue=False, + ) + + open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False) + + refresh.click(**clear_gallery).then( + on_refresh, + [subdirectories], + [subdirectories, subdirectory_paths, gallery_files, gallery, logo], + queue=False, + ) + + image_columns.change( + fn=on_image_columns_change, + inputs=[image_columns], + outputs=[gallery], + queue=False, + ) + + gallery.select( + on_select_image, + [gallery_files], + [outputgallery_filename, image_parameters], + queue=False, + ) + + outputgallery_filename.change( + on_outputgallery_filename_change, + [outputgallery_filename], + [ + outputgallery_sendto_sd, + ], + queue=False, + ) + + # We should have been given the .select function for our tab, so set it up + def outputgallery_tab_select(select): + select( + fn=on_select_tab, + inputs=[subdirectory_paths], + outputs=[ + subdirectories, + subdirectory_paths, + gallery_files, + gallery, + logo, + open_subdir, + ], + queue=False, + ) + + # We should have been passed a list of components on other tabs that update + # when a new image has generated on that tab, so set things up so the user + # will see that new image if they are looking at today's subdirectory + def outputgallery_watch(components: gr.Textbox): + for component in components: + component.change( + on_new_image, + inputs=[subdirectories, subdirectory_paths, component], + outputs=[gallery_files, gallery, logo], + queue=False, + ) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py new file mode 100644 index 0000000000..f26c7967e3 --- /dev/null +++ b/apps/shark_studio/web/ui/sd.py @@ -0,0 +1,650 @@ +import os +import time +import gradio as gr +import PIL +import json +import sys + +from math import ceil +from inspect import signature +from PIL import Image +from pathlib import Path +from datetime import datetime as dt +from gradio.components.image_editor import ( + Brush, + Eraser, + EditorValue, +) + +from apps.shark_studio.api.utils import ( + get_available_devices, + get_generated_imgs_path, + get_checkpoints_path, + get_checkpoints, +) +from apps.shark_studio.api.sd import ( + sd_model_map, + shark_sd_fn, + cancel_sd, +) +from apps.shark_studio.api.controlnet import ( + preprocessor_model_map, + PreprocessorModel, + cnet_preview, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, +) +from apps.shark_studio.modules.img_processing import ( + resampler_list, + resize_stencil, +) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from apps.shark_studio.web.ui.utils import ( + nodlogo_loc, +) +from apps.shark_studio.web.utils.state import ( + get_generation_text_info, + status_label, +) +from apps.shark_studio.web.ui.common_events import lora_changed + + +def view_json_file(file_obj): + content = "" + with open(file_obj.name, "r") as fopen: + content = fopen.read() + return content + + +max_controlnets = 3 +max_loras = 5 + + +def show_loras(k): + k = int(k) + return gr.State( + [gr.Dropdown(visible=True)] * k + + [gr.Dropdown(visible=False, value="None")] * (max_loras - k) + ) + + +def show_controlnets(k): + k = int(k) + return [ + gr.State( + [ + [gr.Row(visible=True, render=True)] * k + + [gr.Row(visible=False)] * (max_controlnets - k) + ] + ), + gr.State([None] * k), + gr.State([None] * k), + gr.State([None] * k), + ] + 
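+# A rough wiring sketch for show_loras above (the layout further down
+# defines a `num_loras` slider and a `loras` state; wiring of this exact
+# shape is assumed here, not shown in this hunk):
+#
+#   num_loras.change(show_loras, inputs=[num_loras], outputs=[loras])
+#
+# i.e. moving the slider marks the first k LoRA dropdowns visible and
+# hides (and resets) the rest.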
+
+import numpy as np  # used by create_canvas below; missing from the import block above
+
+
+def create_canvas(width, height):
+    # A blank, all-white RGB canvas of the requested size.
+    data = Image.fromarray(
+        np.zeros(
+            shape=(height, width, 3),
+            dtype=np.uint8,
+        )
+        + 255
+    )
+    img_dict = {
+        "background": data,
+        "layers": [data],
+        "composite": None,
+    }
+    return EditorValue(img_dict)
+
+
+def import_original(original_img, width, height):
+    # Reuse the main input image as a controlnet hint canvas, resized to
+    # the requested stencil dimensions.
+    resized_img, _, _ = resize_stencil(original_img, width, height)
+    img_dict = {
+        "background": resized_img,
+        "layers": [resized_img],
+        "composite": None,
+    }
+    return gr.ImageEditor(
+        value=EditorValue(img_dict),
+        crop_size=(width, height),
+    )
+
+
+def update_cn_input(
+    model,
+    width,
+    height,
+    stencils,
+    images,
+    preprocessed_hints,
+):
+    if model is None or model == "None":
+        # Nothing selected for this row: leave every component as-is and
+        # hand the per-row state lists back unchanged (show_controlnets
+        # rebuilds them whenever the row count changes).
+        return [
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            stencils,
+            images,
+            preprocessed_hints,
+        ]
+    elif model == "scribble":
+        # Scribble hints are drawn by hand, so expose a fixed black brush.
+        return [
+            gr.ImageEditor(
+                visible=True,
+                interactive=True,
+                show_label=False,
+                image_mode="RGB",
+                type="pil",
+                brush=Brush(
+                    colors=["#000000"],
+                    color_mode="fixed",
+                    default_size=5,
+                ),
+            ),
+            gr.Image(
+                visible=True,
+                show_label=False,
+                interactive=True,
+                show_download_button=False,
+            ),
+            gr.Slider(visible=True, label="Canvas Width"),
+            gr.Slider(visible=True, label="Canvas Height"),
+            gr.Button(visible=True),
+            gr.Button(visible=False),
+            stencils,
+            images,
+            preprocessed_hints,
+        ]
+    else:
+        # Every other preprocessor takes an ordinary image upload/edit.
+        return [
+            gr.ImageEditor(
+                visible=True,
+                interactive=True,
+                show_label=False,
+                image_mode="RGB",
+                type="pil",
+            ),
+            gr.Image(
+                visible=True,
+                show_label=False,
+                interactive=True,
+                show_download_button=False,
+            ),
+            gr.Slider(visible=True, label="Canvas Width"),
+            gr.Slider(visible=True, label="Canvas Height"),
+            gr.Button(visible=True),
+            gr.Button(visible=False),
+            stencils,
+            images,
+            preprocessed_hints,
+        ]
+
+
+# Get the list of argument names shark_sd_fn expects, so the UI below can
+# keep a matching list of gradio values.
+sd_fn_inputs = list(signature(shark_sd_fn).parameters.keys())
+
+with gr.Blocks(title="Stable Diffusion") as sd_element:
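+    # Layout: a title row with the logo, then a two-column body with model
+    # selection, prompts, and options on the left and the output gallery
+    # plus generate/stop controls on the right.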
+ with gr.Row(elem_id="ui_title"): + nod_logo = Image.open(nodlogo_loc) + with gr.Row(variant="compact", equal_height=True): + with gr.Column( + scale=1, + elem_id="demo_title_outer", + ): + gr.Image( + value=nod_logo, + show_label=False, + interactive=False, + elem_id="top_logo", + width=150, + height=50, + show_download_button=False, + ) + with gr.Column(elem_id="ui_body"): + with gr.Row(): + with gr.Column(scale=1, min_width=600): + with gr.Row(equal_height=True): + with gr.Column(scale=3): + sd_model_info = ( + f"Checkpoint Path: {str(get_checkpoints_path())}" + ) + sd_base = gr.Dropdown( + label="Base Model", + info="Select or enter HF model ID", + elem_id="custom_model", + value="stabilityai/stable-diffusion-2-1-base", + choices=sd_model_map.keys(), + ) # base_model_id + sd_custom_weights = gr.Dropdown( + label="Weights (Optional)", + info="Select or enter HF model ID", + elem_id="custom_model", + value="None", + allow_custom_value=True, + choices=get_checkpoints(sd_base), + ) # + with gr.Column(scale=2): + sd_vae_info = ( + str(get_checkpoints_path("vae")) + ).replace("\\", "\n\\") + sd_vae_info = f"VAE Path: {sd_vae_info}" + sd_custom_vae = gr.Dropdown( + label=f"Custom VAE Models", + info=sd_vae_info, + elem_id="custom_model", + value=os.path.basename(cmd_opts.custom_vae) + if cmd_opts.custom_vae + else "None", + choices=["None"] + get_checkpoints("vae"), + allow_custom_value=True, + scale=1, + ) + with gr.Column(scale=1): + save_sd_config = gr.Button( + value="Save Config", size="sm" + ) + clear_sd_config = gr.ClearButton( + value="Clear Config", size="sm" + ) + load_sd_config = gr.FileExplorer( + label="Load Config", + root=os.path.basename("./configs"), + ) + + with gr.Group(elem_id="prompt_box_outer"): + prompt = gr.Textbox( + label="Prompt", + value=cmd_opts.prompts[0], + lines=2, + elem_id="prompt_box", + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + value=cmd_opts.negative_prompts[0], + lines=2, + elem_id="negative_prompt_box", + ) + + with gr.Accordion(label="Input Image", open=False): + # TODO: make this import image prompt info if it exists + sd_init_image = gr.Image( + label="Input Image", + type="pil", + height=300, + interactive=True, + ) + with gr.Accordion( + label="Embeddings options", open=False, render=True + ): + sd_lora_info = ( + str(get_checkpoints_path("loras")) + ).replace("\\", "\n\\") + num_loras = gr.Slider( + 1, max_loras, value=1, step=1, label="LoRA Count" + ) + loras = gr.State([]) + for i in range(max_loras): + with gr.Row(): + lora_opt = gr.Dropdown( + allow_custom_value=True, + label=f"Standalone LoRA Weights", + info=sd_lora_info, + elem_id="lora_weights", + value="None", + choices=["None"] + get_checkpoints("lora"), + ) + with gr.Row(): + lora_tags = gr.HTML( + value="
No LoRA selected
", + elem_classes="lora-tags", + ) + gr.on( + triggers=[lora_opt.change], + fn=lora_changed, + inputs=[lora_opt], + outputs=[lora_tags], + queue=True, + ) + loras.value.append(lora_opt) + + num_loras.change(show_loras, [num_loras], [loras]) + with gr.Accordion(label="Advanced Options", open=True): + with gr.Row(): + scheduler = gr.Dropdown( + elem_id="scheduler", + label="Scheduler", + value="EulerDiscrete", + choices=scheduler_model_map.keys(), + allow_custom_value=False, + ) + with gr.Row(): + height = gr.Slider( + 384, + 768, + value=cmd_opts.height, + step=8, + label="Height", + ) + width = gr.Slider( + 384, + 768, + value=cmd_opts.width, + step=8, + label="Width", + ) + with gr.Row(): + with gr.Column(scale=3): + steps = gr.Slider( + 1, + 100, + value=cmd_opts.steps, + step=1, + label="Steps", + ) + batch_count = gr.Slider( + 1, + 100, + value=cmd_opts.batch_count, + step=1, + label="Batch Count", + interactive=True, + ) + batch_size = gr.Slider( + 1, + 4, + value=cmd_opts.batch_size, + step=1, + label="Batch Size", + interactive=True, + visible=True, + ) + repeatable_seeds = gr.Checkbox( + cmd_opts.repeatable_seeds, + label="Repeatable Seeds", + ) + with gr.Column(scale=3): + strength = gr.Slider( + 0, + 1, + value=cmd_opts.strength, + step=0.01, + label="Denoising Strength", + ) + resample_type = gr.Dropdown( + value=cmd_opts.resample_type, + choices=resampler_list, + label="Resample Type", + allow_custom_value=True, + ) + guidance_scale = gr.Slider( + 0, + 50, + value=cmd_opts.guidance_scale, + step=0.1, + label="CFG Scale", + ) + ondemand = gr.Checkbox( + value=cmd_opts.lowvram, + label="Low VRAM", + interactive=True, + ) + precision = gr.Radio( + label="Precision", + value=cmd_opts.precision, + choices=[ + "fp16", + "fp32", + ], + visible=True, + ) + with gr.Row(): + seed = gr.Textbox( + value=cmd_opts.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", + ) + device = gr.Dropdown( + elem_id="device", + label="Device", + value=get_available_devices()[0], + choices=get_available_devices(), + allow_custom_value=False, + ) + with gr.Accordion( + label="Controlnet Options", open=False, render=False + ): + sd_cnet_info = ( + str(get_checkpoints_path("controlnet")) + ).replace("\\", "\n\\") + num_cnets = gr.Slider( + 0, + max_controlnets, + value=0, + step=1, + label="Controlnet Count", + ) + cnet_rows = [] + stencils = gr.State([]) + images = gr.State([]) + preprocessed_hints = gr.State([]) + control_mode = gr.Radio( + choices=["Prompt", "Balanced", "Controlnet"], + value="Balanced", + label="Control Mode", + ) + + for i in range(max_controlnets): + with gr.Row(visible=False) as cnet_row: + with gr.Column(): + cnet_gen = gr.Button( + value="Preprocess controlnet input", + ) + cnet_model = gr.Dropdown( + allow_custom_value=True, + label=f"Controlnet Model", + info=sd_cnet_info, + elem_id="lora_weights", + value="None", + choices=[ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] + + get_checkpoints("controlnet"), + ) + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + make_canvas = gr.Button( + value="Make Canvas!", + visible=False, + ) + use_input_img = gr.Button( + value="Use Original Image", + visible=False, + ) + cnet_input = gr.ImageEditor( + visible=True, + image_mode="RGB", + interactive=True, + show_label=True, + label="Input 
Image", + type="pil", + ) + cnet_output = gr.Image( + value=None, + visible=True, + label="Preprocessed Hint", + interactive=True, + show_label=True, + ) + use_input_img.click( + import_original, + [sd_init_image, canvas_width, canvas_height], + [cnet_input], + ) + cnet_model.change( + fn=update_cn_input, + inputs=[ + cnet_model, + canvas_width, + canvas_height, + stencils, + images, + preprocessed_hints, + ], + outputs=[ + cnet_input, + cnet_output, + canvas_width, + canvas_height, + make_canvas, + use_input_img, + stencils, + images, + preprocessed_hints, + ], + ) + make_canvas.click( + create_canvas, + [canvas_width, canvas_height], + [ + cnet_input, + ], + ) + gr.on( + triggers=[cnet_gen.click], + fn=cnet_preview, + inputs=[ + cnet_model, + cnet_input, + stencils, + images, + preprocessed_hints, + ], + outputs=[ + cnet_output, + stencils, + images, + preprocessed_hints, + ], + ) + cnet_rows.value.append(cnet_row) + + num_cnets.change( + show_controlnets, + [num_cnets], + [cnet_rows, stencils, images, preprocessed_hints], + ) + with gr.Column(scale=1, min_width=600): + with gr.Group(): + sd_gallery = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery", + columns=2, + object_fit="contain", + ) + std_output = gr.Textbox( + value=f"{sd_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=2, + elem_id="std_output", + show_label=False, + ) + sd_status = gr.Textbox(visible=False) + with gr.Row(): + stable_diffusion = gr.Button("Generate Image(s)") + random_seed = gr.Button("Randomize Seed") + random_seed.click( + lambda: -1, + inputs=[], + outputs=[seed], + queue=False, + ) + stop_batch = gr.Button("Stop Batch") + + kwargs = dict( + fn=shark_sd_fn, + inputs=[ + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + sd_base, + sd_custom_weights, + sd_custom_vae, + precision, + device, + loras, + ondemand, + repeatable_seeds, + resample_type, + control_mode, + stencils, + images, + preprocessed_hints, + ], + outputs=[ + sd_gallery, + std_output, + sd_status, + stencils, + images, + ], + show_progress="minimal", + ) + + status_kwargs = dict( + fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs), + inputs=[batch_count, batch_size], + outputs=sd_status, + ) + + prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) + neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( + **kwargs + ) + generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) + stop_batch.click( + fn=cancel_sd, + cancels=[prompt_submit, neg_prompt_submit, generate_click], + ) diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py new file mode 100644 index 0000000000..ba62e5adc0 --- /dev/null +++ b/apps/shark_studio/web/ui/utils.py @@ -0,0 +1,33 @@ +from enum import IntEnum +import math +import sys +import os + + +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) + + +nodlogo_loc = resource_path("logos/nod-logo.png") +nodicon_loc = resource_path("logos/nod-icon.png") + + +class HSLHue(IntEnum): + RED = 0 + YELLOW = 60 + GREEN = 120 + CYAN = 180 + BLUE = 240 + MAGENTA = 300 + + +def hsl_color(alpha: float, start, end): + b = (end - start) * (alpha if alpha > 0 else 0) + result = b + start + + # Return a CSS HSL 
string
+    return f"hsl({math.floor(result)}, 80%, 35%)"
diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py
new file mode 100644
index 0000000000..0b5f54636a
--- /dev/null
+++ b/apps/shark_studio/web/utils/globals.py
@@ -0,0 +1,74 @@
+import gc
+
+"""
+Global objects for the SD pipeline, its config, and the scheduler set.
+Keeping these at module scope avoids creating extra pipeline objects when
+switching modes, and clearing them (see clear_cache) helps avoid memory
+leaks when switching models.
+"""
+
+
+def _init():
+    global _sd_obj
+    global _config_obj
+    global _schedulers
+    _sd_obj = None
+    _config_obj = None
+    _schedulers = None
+
+
+def set_sd_obj(value):
+    global _sd_obj
+    _sd_obj = value
+
+
+def set_sd_scheduler(key):
+    global _sd_obj
+    _sd_obj.scheduler = _schedulers[key]
+
+
+def set_sd_status(value):
+    global _sd_obj
+    _sd_obj.status = value
+
+
+def set_cfg_obj(value):
+    global _config_obj
+    _config_obj = value
+
+
+def set_schedulers(value):
+    global _schedulers
+    _schedulers = value
+
+
+def get_sd_obj():
+    global _sd_obj
+    return _sd_obj
+
+
+def get_sd_status():
+    global _sd_obj
+    return _sd_obj.status
+
+
+def get_cfg_obj():
+    global _config_obj
+    return _config_obj
+
+
+def get_scheduler(key):
+    global _schedulers
+    return _schedulers[key]
+
+
+def clear_cache():
+    global _sd_obj
+    global _config_obj
+    global _schedulers
+    del _sd_obj
+    del _config_obj
+    del _schedulers
+    gc.collect()
+    _sd_obj = None
+    _config_obj = None
+    _schedulers = None
diff --git a/apps/shark_studio/web/utils/metadata/__init__.py b/apps/shark_studio/web/utils/metadata/__init__.py
new file mode 100644
index 0000000000..bcbcf746ca
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/__init__.py
@@ -0,0 +1,6 @@
+from .png_metadata import (
+    import_png_metadata,
+)
+from .display import (
+    displayable_metadata,
+)
diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py
new file mode 100644
index 0000000000..d617e802bf
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py
@@ -0,0 +1,45 @@
+import csv
+import os
+from .format import humanize, humanizable
+
+
+def csv_path(image_filename: str):
+    return os.path.join(os.path.dirname(image_filename), "imgs_details.csv")
+
+
+def has_csv(image_filename: str) -> bool:
+    return os.path.exists(csv_path(image_filename))
+
+
+def matching_filename(image_filename: str, row):
+    # we assume the final column of the csv has the original filename with full path, and match
+    # that against the image_filename if we are given a list. Otherwise we assume a dict and
+    # take the value of the OUTPUT key
+    return os.path.basename(image_filename) in (
+        row[-1] if isinstance(row, list) else row["OUTPUT"]
+    )
+
+
+def parse_csv(image_filename: str):
+    csv_filename = csv_path(image_filename)
+
+    with open(csv_filename, "r", newline="") as csv_file:
+        # We use a reader or DictReader here depending on whether we think the file
+        # has headers or not. Having headers means less guessing of the format.
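+        # csv.Sniffer().has_header() guesses from a sample of the file (the
+        # first 2048 bytes here) whether row one looks like a header row;
+        # seek(0) then rewinds so whichever reader we pick starts at the top.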
+        has_header = csv.Sniffer().has_header(csv_file.read(2048))
+        csv_file.seek(0)
+
+        reader = (
+            csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
+        )
+
+        matches = [
+            # we rely on humanize and humanizable to work out the parsing of the individual .csv rows
+            humanize(row)
+            for row in reader
+            if row
+            and (has_header or humanizable(row))
+            and matching_filename(image_filename, row)
+        ]
+
+        return matches[0] if matches else {}
diff --git a/apps/shark_studio/web/utils/metadata/display.py b/apps/shark_studio/web/utils/metadata/display.py
new file mode 100644
index 0000000000..26234aab5c
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/display.py
@@ -0,0 +1,53 @@
+import json
+import os
+from PIL import Image
+from .png_metadata import parse_generation_parameters
+from .exif_metadata import has_exif, parse_exif
+from .csv_metadata import has_csv, parse_csv
+from .format import compact, humanize
+
+
+def displayable_metadata(image_filename: str) -> dict | None:
+    if not os.path.isfile(image_filename):
+        return {"source": "missing", "parameters": {}}
+
+    pil_image = Image.open(image_filename)
+
+    # we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
+    # and we go via that for SendTo, and it is directly tied to the image)
+    if "parameters" in pil_image.info:
+        return {
+            "source": "png",
+            "parameters": compact(
+                parse_generation_parameters(pil_image.info["parameters"])
+            ),
+        }
+
+    # we have a matching json file (next most likely to be accurate when it's there)
+    json_path = os.path.splitext(image_filename)[0] + ".json"
+    if os.path.isfile(json_path):
+        with open(json_path) as params_file:
+            return {
+                "source": "json",
+                "parameters": compact(
+                    humanize(json.load(params_file), includes_filename=False)
+                ),
+            }
+
+    # we have a CSV file, so try that. It can be different shapes and usually has no
+    # headers/param names, so of the things we *know* have parameters it's the
+    # last resort
+    if has_csv(image_filename):
+        params = parse_csv(image_filename)
+        if params:  # we might not have found the filename in the csv
+            return {
+                "source": "csv",
+                "parameters": compact(params),  # already humanized
+            }
+
+    # EXIF data, probably a .jpeg; it may well not include parameters, but at least it's *something*
+    if has_exif(image_filename):
+        return {"source": "exif", "parameters": parse_exif(pil_image)}
+
+    # we've got nothing
+    return None
diff --git a/apps/shark_studio/web/utils/metadata/exif_metadata.py b/apps/shark_studio/web/utils/metadata/exif_metadata.py
new file mode 100644
index 0000000000..c72da8a935
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/exif_metadata.py
@@ -0,0 +1,52 @@
+from PIL import Image
+from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
+
+
+def has_exif(image_filename: str) -> bool:
+    return bool(Image.open(image_filename).getexif())
+
+
+def parse_exif(pil_image: Image.Image) -> dict:
+    img_exif = pil_image.getexif()
+
+    # See this stackoverflow answer for where most of this comes from: https://stackoverflow.com/a/75357594
+    # I did try to use the exif library, but it broke just as much as my initial attempt at this
+    # (albeit I was probably using it wrong), so I reverted back to using PIL with more filtering
+    # and saved a dependency
+    exif_tags = {
+        TAGS.get(key, key): str(val)
+        for (key, val) in img_exif.items()
+        if key in TAGS
+        and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    def try_get_ifd(ifd_id):
+        try:
+            return img_exif.get_ifd(ifd_id).items()
+        except KeyError:
+            return {}
+
+    ifd_tags = {
+        TAGS.get(key, key): str(val)
+        for ifd_id in IFD
+        for (key, val) in try_get_ifd(ifd_id)
+        if ifd_id != IFD.GPSInfo
+        and key in TAGS
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    gps_tags = {
+        GPSTAGS.get(key, key): str(val)
+        for (key, val) in try_get_ifd(IFD.GPSInfo)
+        if key in GPSTAGS
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    return {**exif_tags, **ifd_tags, **gps_tags}
diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py
new file mode 100644
index 0000000000..f097dab54f
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/format.py
@@ -0,0 +1,143 @@
+# As SHARK has evolved, more columns have been added to images_details.csv. However, since
+# no version of the CSV has any headers (yet), we don't actually have anything within the
+# file that tells us which parameter each column is for. So this is a list of known patterns,
+# indexed by length, which is what we have to use to guess which columns are the right
+# ones for the file we're looking at.
+
+# The same ordering is used for JSON. Those files do have key names, but they are not very
+# human friendly, nor do they match up with what is written to the .png headers.
+
+# So these are functions to try to get something consistent out of the raw input from all
+# these sources.
+
+PARAMS_FORMATS = {
+    9: {
+        "VARIANT": "Model",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "OUTPUT": "Filename",
+    },
+    10: {
+        "MODEL": "Model",
+        "VARIANT": "Variant",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "OUTPUT": "Filename",
+    },
+    12: {
+        "VARIANT": "Model",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "HEIGHT": "Height",
+        "WIDTH": "Width",
+        "MAX_LENGTH": "Max Length",
+        "OUTPUT": "Filename",
+    },
+}
+
+PARAMS_FORMAT_CURRENT = {
+    "VARIANT": "Model",
+    "VAE": "VAE",
+    "LORA": "LoRA",
+    "SCHEDULER": "Sampler",
+    "PROMPT": "Prompt",
+    "NEG_PROMPT": "Negative prompt",
+    "SEED": "Seed",
+    "CFG_SCALE": "CFG scale",
+    "PRECISION": "Precision",
+    "STEPS": "Steps",
+    "HEIGHT": "Height",
+    "WIDTH": "Width",
+    "MAX_LENGTH": "Max Length",
+    "OUTPUT": "Filename",
+}
+
+
+def compact(metadata: dict) -> dict:
+    # we don't want to alter the original dictionary
+    result = dict(metadata)
+
+    # discard the filename because we should already have it
+    if result.keys() & {"Filename"}:
+        result.pop("Filename")
+
+    # make showing the sizes more compact by using only one line each
+    if result.keys() & {"Size-1", "Size-2"}:
+        result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
+    elif result.keys() & {"Height", "Width"}:
+        result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
+
+    if result.keys() & {"Hires resize-1", "Hires resize-2"}:
+        hires_y = result.pop("Hires resize-1")
+        hires_x = result.pop("Hires resize-2")
+
+        if hires_x == 0 and hires_y == 0:
+            result["Hires resize"] = "None"
+        else:
+            result["Hires resize"] = f"{hires_y}x{hires_x}"
+
+    # remove VAE if it exists and is empty
+    if (result.keys() & {"VAE"}) and (
+        not result["VAE"] or result["VAE"] == "None"
+    ):
+        result.pop("VAE")
+
+    # remove LoRA if it exists and is empty
+    if (result.keys() & {"LoRA"}) and (
+        not result["LoRA"] or result["LoRA"] == "None"
+    ):
+        result.pop("LoRA")
+
+    return result
+
+
+def humanizable(metadata: dict | list[str], includes_filename=True) -> bool:
+    lookup_key = len(metadata) + (0 if includes_filename else 1)
+    return lookup_key in PARAMS_FORMATS.keys()
+
+
+def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
+    lookup_key = len(metadata) + (0 if includes_filename else 1)
+
+    # For lists we can only work based on the length, we have no other information
+    if isinstance(metadata, list):
+        if humanizable(metadata, includes_filename):
+            return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata))
+        else:
+            raise KeyError(
+                f"Humanize could not find the format for a parameter list of length {len(metadata)}"
+            )
+
+    # For dictionaries we try to use the matching length parameter format if
+    # available, otherwise we just use the current format, which is assumed to
+    # cover everything currently known about. Then we swap keys in the metadata
+    # that match keys in the format for the friendlier name that we have set
+    # in the format value
+    if isinstance(metadata, dict):
+        if humanizable(metadata, includes_filename):
+            params_format = PARAMS_FORMATS[lookup_key]
+        else:
+            params_format = PARAMS_FORMAT_CURRENT
+
+        return {
+            params_format[key]: metadata[key]
+            for key in params_format.keys()
+            if key in metadata.keys() and metadata[key]
+        }
+
+    raise TypeError("Can only humanize parameter lists or dictionaries")
diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py
new file mode 100644
index 0000000000..cffc385ab7
--- /dev/null
+++ b/apps/shark_studio/web/utils/metadata/png_metadata.py
@@ -0,0 +1,222 @@
+import re
+from pathlib import Path
+from apps.shark_studio.api.utils import (
+    get_checkpoint_pathfile,
+)
+from apps.shark_studio.api.sd import (
+    sd_model_map,
+)
+from apps.shark_studio.modules.schedulers import (
+    scheduler_model_map,
+)
+
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
+re_param = re.compile(re_param_code)
+re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+
+
+def parse_generation_parameters(x: str):
+    res = {}
+    prompt = ""
+    negative_prompt = ""
+    done_with_prompt = False
+
+    *lines, lastline = x.strip().split("\n")
+    if len(re_param.findall(lastline)) < 3:
+        lines.append(lastline)
+        lastline = ""
+
+    for line in lines:
+        line = line.strip()
+        if line.startswith("Negative prompt:"):
+            done_with_prompt = True
+            line = line[16:].strip()
+
+        if done_with_prompt:
+            negative_prompt += ("" if negative_prompt == "" else "\n") + line
+        else:
+            prompt += ("" if prompt == "" else "\n") + line
+
+    res["Prompt"] = prompt
+    res["Negative prompt"] = negative_prompt
+
+    for k, v in re_param.findall(lastline):
+        v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
+        m = re_imagesize.match(v)
+        if m is not None:
+            res[k + "-1"] = m.group(1)
+            res[k + "-2"] = m.group(2)
+        else:
+            res[k] = v
+
+    # Missing CLIP skip means it was set to 1 (the default)
+    if "Clip skip" not in res:
+        res["Clip skip"] = "1"
+
+    # Re-append any hypernetwork reference to the prompt in its
+    # <hypernet:name:strength> form
+    hypernet = res.get("Hypernet", None)
+    if hypernet is not None:
+        res[
+            "Prompt"
+        ] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
+
+    if "Hires resize-1" not in res:
+        res["Hires resize-1"] = 0
+        res["Hires resize-2"] = 0
+
+    return res
+
+
+def try_find_model_base_from_png_metadata(
+    file: str, folder: str =
"models" +) -> str: + custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file(): + custom = file + ".safetensors" + + return custom + + +def find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in sd_model_map: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" + % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> str: + vae_custom = "" + + if key in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + +def import_png_metadata( + pil_data, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, +): + try: + png_info = pil_data.info["parameters"] + metadata = parse_generation_parameters(png_info) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) + + negative_prompt = metadata["Negative prompt"] + steps = int(metadata["Steps"]) + cfg_scale = float(metadata["CFG scale"]) + seed = int(metadata["Seed"]) + width = float(metadata["Size-1"]) + height = float(metadata["Size-2"]) + + if "Model" in metadata and png_custom_model: + custom_model = png_custom_model + elif "Model" in metadata and png_hf_model_id: + custom_model = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + + if "Prompt" in metadata: + prompt = metadata["Prompt"] + if "Sampler" in metadata: + if metadata["Sampler"] in scheduler_model_map: + sampler = metadata["Sampler"] + else: + print( + "Import PNG info: Unable to find a scheduler for %s" + % metadata["Sampler"] + ) + + except Exception as ex: + if pil_data 
and pil_data.info.get("parameters"):
+            print("import_png_metadata failed with %s" % ex)
+
+    return (
+        None,
+        prompt,
+        negative_prompt,
+        steps,
+        sampler,
+        cfg_scale,
+        seed,
+        width,
+        height,
+        custom_model,
+        custom_lora,
+        hf_lora_id,
+        custom_vae,
+    )
diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py
new file mode 100644
index 0000000000..626d4ce53f
--- /dev/null
+++ b/apps/shark_studio/web/utils/state.py
@@ -0,0 +1,41 @@
+import apps.shark_studio.web.utils.globals as global_obj
+import gc
+
+
+def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
+    print(f"Getting status label for {tab_name}")
+    if batch_index < batch_count:
+        bs = f"x{batch_size}" if batch_size > 1 else ""
+        return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
+    else:
+        return f"{tab_name} complete"
+
+
+def get_generation_text_info(seeds, device):
+    # Copy the values (not just the keys) out of the global config so the
+    # f-strings below can read them; get_config_dict is expected to expose
+    # the active generation settings as a dict.
+    config = global_obj.get_config_dict()
+    cfg_dump = {key: config[key] for key in config}
+    text_output = f"prompt={cfg_dump['prompts']}"
+    text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
+    text_output += (
+        f"\nmodel_id={cfg_dump['hf_model_id']}, "
+        f"ckpt_loc={cfg_dump['ckpt_loc']}"
+    )
+    text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
+    text_output += (
+        f"\nsteps={cfg_dump['steps']}, "
+        f"guidance_scale={cfg_dump['guidance_scale']}, "
+        f"seed={seeds}"
+    )
+    text_output += (
+        f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
+        if not cfg_dump["use_hiresfix"]
+        else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
+    )
+    text_output += (
+        f"batch_count={cfg_dump['batch_count']}, "
+        f"batch_size={cfg_dump['batch_size']}, "
+        f"max_length={cfg_dump['max_length']}"
+    )
+
+    return text_output
diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py
new file mode 100644
index 0000000000..3e6ba46bfe
--- /dev/null
+++ b/apps/shark_studio/web/utils/tmp_configs.py
@@ -0,0 +1,77 @@
+import os
+import shutil
+from time import time
+
+shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
+
+
+def clear_tmp_mlir():
+    cleanup_start = time()
+    print(
+        "Clearing .mlir temporary files from a prior run. This may take some time..."
+    )
+    mlir_files = [
+        filename
+        for filename in os.listdir(shark_tmp)
+        if os.path.isfile(os.path.join(shark_tmp, filename))
+        and filename.endswith(".mlir")
+    ]
+    for filename in mlir_files:
+        os.remove(shark_tmp + filename)
+    print(
+        f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
+    )
+
+
+def clear_tmp_imgs():
+    # tell gradio to use a directory under shark_tmp for its temporary
+    # image files unless somewhere else has been set
+    if "GRADIO_TEMP_DIR" not in os.environ:
+        os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
+
+    print(
+        f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+        + "You may change this by setting the GRADIO_TEMP_DIR environment variable."
+    )
+
+    # Clear all gradio tmp images from the last session
+    if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
+        cleanup_start = time()
+        print(
+            "Clearing gradio UI temporary image files from a prior run. This may take some time..."
+        )
+        shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
+        print(
+            f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
+        )
+
+    # older SHARK versions had to work around gradio bugs and stored things differently
+    else:
+        image_files = [
+            filename
+            for filename in os.listdir(shark_tmp)
+            if os.path.isfile(os.path.join(shark_tmp, filename))
+            and filename.startswith("tmp")
+            and filename.endswith(".png")
+        ]
+        if len(image_files) > 0:
+            print(
+                "Clearing temporary image files from a prior run of a previous SHARK version. This may take some time..."
+            )
+            cleanup_start = time()
+            for filename in image_files:
+                os.remove(shark_tmp + filename)
+            print(
+                f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
+            )
+        else:
+            print("No temporary image files to clear.")
+
+
+def config_tmp():
+    # create shark_tmp if it does not exist
+    if not os.path.exists(shark_tmp):
+        os.mkdir(shark_tmp)
+
+    clear_tmp_mlir()
+    clear_tmp_imgs()
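+
+
+# A minimal usage sketch (an assumption, not wiring that exists in this
+# diff): the web entry point is expected to call config_tmp() once at
+# startup, before the UI launches, e.g.
+#
+#   from apps.shark_studio.web.utils.tmp_configs import config_tmp
+#   config_tmp()  # ensures shark_tmp/ exists and clears stale .mlir and image files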