diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py
new file mode 100644
index 0000000000..ea8cdf0cc9
--- /dev/null
+++ b/apps/shark_studio/api/controlnet.py
@@ -0,0 +1,134 @@
+# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
+# NOTE: the preview path below assumes the controlnet_aux detectors and
+# gradio's ImageEditor value type.
+import numpy as np
+import PIL
+from PIL import Image
+from gradio.components.image_editor import EditorValue
+from controlnet_aux import CannyDetector, OpenposeDetector, ZoeDetector
+
+
+# Placeholder stubs until the turbine_models controlnet exporters land.
+class control_adapter:
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.model = None
+
+    @staticmethod
+    def export_control_adapter_model(model_keyword):
+        return None
+
+    @staticmethod
+    def export_xl_control_adapter_model(model_keyword):
+        return None
+
+
+class preprocessors:
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.model = None
+
+    @staticmethod
+    def export_controlnet_model(model_keyword):
+        return None
+
+
+control_adapter_map = {
+    "sd15": {
+        "canny": {"initializer": control_adapter.export_control_adapter_model},
+        "openpose": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+        "scribble": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+        "zoedepth": {
+            "initializer": control_adapter.export_control_adapter_model
+        },
+    },
+    "sdxl": {
+        "canny": {
+            "initializer": control_adapter.export_xl_control_adapter_model
+        },
+    },
+}
+preprocessor_model_map = {
+    "canny": {"initializer": preprocessors.export_controlnet_model},
+    "openpose": {"initializer": preprocessors.export_controlnet_model},
+    "scribble": {"initializer": preprocessors.export_controlnet_model},
+    "zoedepth": {"initializer": preprocessors.export_controlnet_model},
+}
+
+
+class PreprocessorModel:
+    def __init__(
+        self,
+        hf_model_id,
+        device,
+    ):
+        self.model = None
+
+    def compile(self, device):
+        print("compile not implemented for preprocessor.")
+        return
+
+    def run(self, inputs):
+        print("run not implemented for preprocessor.")
+        return
+
+
+def cnet_preview(model, input_image, index, stencils, images, preprocessed_hints):
+    if isinstance(input_image, PIL.Image.Image):
+        img_dict = {
+            "background": None,
+            "layers": [None],
+            "composite": input_image,
+        }
+        input_image = EditorValue(img_dict)
+    images[index] = input_image
+    if model:
+        stencils[index] = model
+    match model:
+        case "canny":
+            canny = CannyDetector()
+            result = canny(
+                np.array(input_image["composite"]),
+                100,
+                200,
+            )
+            preprocessed_hints[index] = Image.fromarray(result)
+            return (
+                Image.fromarray(result),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "openpose":
+            openpose = OpenposeDetector()
+            result = openpose(np.array(input_image["composite"]))
+            preprocessed_hints[index] = Image.fromarray(result[0])
+            return (
+                Image.fromarray(result[0]),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "zoedepth":
+            zoedepth = ZoeDetector()
+            result = zoedepth(np.array(input_image["composite"]))
+            preprocessed_hints[index] = Image.fromarray(result)
+            return (
+                Image.fromarray(result),
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case "scribble":
+            preprocessed_hints[index] = input_image["composite"]
+            return (
+                input_image["composite"],
+                stencils,
+                images,
+                preprocessed_hints,
+            )
+        case _:
+            preprocessed_hints[index] = None
+            return (
+                None,
+                stencils,
+                images,
+                preprocessed_hints,
+            )
diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py
index 432eaf5331..bbb273354c 100644
--- a/apps/shark_studio/api/initializers.py
+++ b/apps/shark_studio/api/initializers.py
@@ -26,22 +26,21 @@ def imports():
     startup_timer.record("import gradio")
 
-    # from apps.shark_studio.modules import shared_init
-    # shared_init.initialize()
-    # startup_timer.record("initialize shared")
+    import
apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + startup_timer.record("initialize globals") from apps.shark_studio.modules import ( - processing, - gradio_extensons, - ui, + img_processing, ) # noqa: F401 + from apps.shark_studio.modules.schedulers import scheduler_model_map startup_timer.record("other imports") def initialize(): configure_sigint_handler() - configure_opts_onchange() # from apps.shark_studio.modules import modelloader # modelloader.cleanup_models() diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index f4b979d1fc..a601a068f7 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -1,9 +1,13 @@ from turbine_models.custom_models.sd_inference import clip, unet, vae from shark.iree_utils.compile_utils import get_iree_compiled_module from apps.shark_studio.api.utils import get_resource_path +from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.web.utils.state import status_label +from apps.shark_studio.modules.pipeline import SharkPipelineBase import iree.runtime as ireert import gc import torch +import gradio as gr sd_model_map = { "CompVis/stable-diffusion-v1-4": { @@ -86,16 +90,15 @@ class StableDiffusion(SharkPipelineBase): - # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. The init # aims to be as general as possible, and the class will infer and compile # a list of necessary modules or a combined "pipeline module" for a # specified job based on the inference task. - # + # # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels. # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"} - # + # # embeddings: a dict of embedding checkpoints or model IDs to use when # initializing the compiled modules. @@ -107,7 +110,6 @@ def __init__( precision: str = "fp16", device: str = None, custom_model_map: dict = {}, - custom_weights_map: dict = {}, embeddings: dict = {}, import_ir: bool = True, ): @@ -118,12 +120,185 @@ def __init__( self.iree_module_dict = None self.get_compiled_map() + def prepare_pipeline(self, scheduler, custom_model_map): + return None def generate_images( - self, - prompt, - ): - return result_output, + self, + prompt, + negative_prompt, + steps, + strength, + guidance_scale, + seed, + ondemand, + repeatable_seeds, + resample_type, + control_mode, + preprocessed_hints, + ): + return None, None, None, None, None + + +# NOTE: Each `hf_model_id` should have its own starting configuration. 
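+# For illustration only (field names assumed, not a final schema), such a
+# starting configuration entry might look like:
+#   "stabilityai/stable-diffusion-2-1-base": {
+#       "clip": {"initializer": clip.export_clip_model},
+#       "unet": {"initializer": unet.export_unet_model},
+#       "vae_decode": {"initializer": vae.export_vae_model},
+#   },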
+
+# model_vmfb_key = ""
+sd_pipe = None
+sd_pipe_kwargs = None
+
+
+def shark_sd_fn(
+    prompt,
+    negative_prompt,
+    image_dict,
+    height: int,
+    width: int,
+    steps: int,
+    strength: float,
+    guidance_scale: float,
+    seed: str | int,
+    batch_count: int,
+    batch_size: int,
+    scheduler: str,
+    base_model_id: str,
+    custom_weights: str,
+    custom_vae: str,
+    precision: str,
+    device: str,
+    lora_weights: str | list,
+    ondemand: bool,
+    repeatable_seeds: bool,
+    resample_type: str,
+    control_mode: str,
+    stencils: list,
+    images: list,
+    preprocessed_hints: list,
+    progress=gr.Progress(),
+):
+    import PIL
+
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.shark_studio.modules.img_processing import (
+        resize_stencil,
+        save_output_img,
+    )
+    from apps.shark_studio.api.utils import parse_seed_input
+    import apps.shark_studio.web.utils.globals as global_obj
+
+    # Handle gradio ImageEditor datatypes so we have unified inputs to the SD API
+    for i, stencil in enumerate(stencils):
+        if images[i] is None and stencil is not None:
+            continue
+        elif stencil is None and any(
+            img is not None for img in [images[i], preprocessed_hints[i]]
+        ):
+            images[i] = None
+            preprocessed_hints[i] = None
+        elif images[i] is not None:
+            if isinstance(images[i], dict):
+                images[i] = images[i]["composite"]
+            images[i] = images[i].convert("RGB")
+
+    if isinstance(image_dict, PIL.Image.Image):
+        image = image_dict.convert("RGB")
+    elif image_dict:
+        image = image_dict["image"].convert("RGB")
+    else:
+        image = None
+    is_img2img = False
+    if image:
+        (
+            image,
+            _,
+            _,
+        ) = resize_stencil(image, width, height)
+        is_img2img = True
+    print("Performing Stable Diffusion Pipeline setup...")
+
+    device_id = None
+
+    custom_model_map = {}
+    if custom_weights != "None":
+        custom_model_map["unet"] = {"custom_weights": custom_weights}
+    if custom_vae != "None":
+        custom_model_map["vae"] = {"custom_weights": custom_vae}
+    if stencils:
+        for i, stencil in enumerate(stencils):
+            if "xl" not in base_model_id.lower():
+                custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
+                    "sd15"
+                ][stencil]
+            else:
+                custom_model_map[f"control_adapter_{i}"] = control_adapter_map[
+                    "sdxl"
+                ][stencil]
+
+    submit_pipe_kwargs = {
+        "base_model_id": base_model_id,
+        "height": height,
+        "width": width,
+        "precision": precision,
+        "device": device,
+        "custom_model_map": custom_model_map,
+        "import_ir": cmd_opts.import_mlir,
+        "is_img2img": is_img2img,
+    }
+    submit_prep_kwargs = {
+        "scheduler": scheduler,
+        "custom_model_map": custom_model_map,
+        "embeddings": lora_weights,
+    }
+    submit_run_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "steps": steps,
+        "strength": strength,
+        "guidance_scale": guidance_scale,
+        "seed": seed,
+        "ondemand": ondemand,
+        "repeatable_seeds": repeatable_seeds,
+        "resample_type": resample_type,
+        "control_mode": control_mode,
+        "preprocessed_hints": preprocessed_hints,
+    }
+
+    global sd_pipe
+    global sd_pipe_kwargs
+
+    if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
+        sd_pipe = None
+        gc.collect()
+
+    if sd_pipe is None:
+        yield [], "", "Getting the pipeline ready...", stencils, images
+
+        # Initializes the pipeline and retrieves IR based on all
+        # parameters that are static in the turbine output format,
+        # which is currently MLIR in the torch dialect.
+        sd_pipe = StableDiffusion(
+            **submit_pipe_kwargs,
+        )
+        sd_pipe_kwargs = submit_pipe_kwargs
+
+    sd_pipe.prepare_pipeline(**submit_prep_kwargs)
+
+    seeds = parse_seed_input(seed)
+    generated_imgs = []
+    text_output = ""
+    extra_info = {}
+    for current_batch in progress.tqdm(
+        range(batch_count),
+        desc="Generating Image...",
+    ):
+        out_imgs = sd_pipe.generate_images(**submit_run_kwargs)
+        text_output = (
+            f"prompt={prompt}"
+            f"\nnegative prompt={negative_prompt}"
+            f"\nmodel_id={base_model_id}, scheduler={scheduler}, "
+            f"device={device}"
+            f"\nsteps={steps}, guidance_scale={guidance_scale}, "
+            f"seed={seeds[: current_batch + 1]}"
+        )
+        save_output_img(
+            out_imgs[0],
+            seeds[current_batch % len(seeds)],
+            extra_info,
+        )
+        generated_imgs.extend(out_imgs)
+        yield generated_imgs, text_output, status_label(
+            "Stable Diffusion", current_batch + 1, batch_count, batch_size
+        ), stencils, images
+
+    return generated_imgs, text_output, "", stencils, images
+
+
+def cancel_sd():
+    print("Inject call to cancel longer API calls.")
+    return
+
 
 if __name__ == "__main__":
     sd = StableDiffusion(
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index 8139ed9cb5..120ec3adfa 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -2,6 +2,7 @@
 import sys
 import os
 import numpy as np
+import glob
 from random import (
     randint,
     seed as seed_random,
@@ -12,6 +13,19 @@
 from pathlib import Path
 from safetensors.torch import load_file
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from cpuinfo import get_cpu_info
+
+# TODO: migrate these utils to studio
+from shark.iree_utils.vulkan_utils import (
+    set_iree_vulkan_runtime_flags,
+    get_vulkan_target_triple,
+    get_iree_vulkan_runtime_flags,
+)
+from shark.iree_utils.metal_utils import get_metal_target_triple
+from shark.iree_utils.gpu_utils import get_iree_rocm_args
+
+checkpoints_filetypes = (
+    "*.ckpt",
+    "*.safetensors",
+)
 
 
 def get_available_devices():
@@ -75,6 +89,67 @@ def get_devices_by_name(driver_name):
     return available_devices
 
 
+def set_init_device_flags():
+    if "vulkan" in cmd_opts.device:
+        # set runtime flags for vulkan.
+        set_iree_runtime_flags()
+
+        # set triple flag to avoid multiple calls to get_vulkan_triple_flag
+        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
+        if not cmd_opts.iree_vulkan_target_triple:
+            triple = get_vulkan_target_triple(device_name)
+            if triple is not None:
+                cmd_opts.iree_vulkan_target_triple = triple
+        print(
+            f"Found device {device_name}. Using target triple "
+            f"{cmd_opts.iree_vulkan_target_triple}."
+        )
+    elif "cuda" in cmd_opts.device:
+        cmd_opts.device = "cuda"
+    elif "metal" in cmd_opts.device:
+        device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
+        if not cmd_opts.iree_metal_target_platform:
+            triple = get_metal_target_triple(device_name)
+            if triple is not None:
+                cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
+        print(
+            f"Found device {device_name}. Using target triple "
+            f"{cmd_opts.iree_metal_target_platform}."
+        )
+    elif "cpu" in cmd_opts.device:
+        cmd_opts.device = "cpu"
+
+
+def set_iree_runtime_flags():
+    # TODO: This function should be device-agnostic and piped properly
+    # to general runtime driver init.
+    vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
+    if cmd_opts.enable_rgp:
+        vulkan_runtime_flags += [
+            "--enable_rgp=true",
+            "--vulkan_debug_utils=true",
+        ]
+    if cmd_opts.device_allocator_heap_key:
+        vulkan_runtime_flags += [
+            f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
+        ]
+    set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
+
+
+def get_all_devices(driver_name):
+    """
+    Inputs: driver_name
+    Returns a list of all the available devices for a given driver sorted by
+    the iree path names of the device as in --list_devices option in iree.
+ """ + from iree.runtime import get_driver + + driver = get_driver(driver_name) + device_list_src = driver.query_available_devices() + device_list_src.sort(key=lambda d: d["path"]) + return device_list_src + + def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" base_path = getattr( @@ -83,26 +158,52 @@ def get_resource_path(relative_path): return os.path.join(base_path, relative_path) - def get_generated_imgs_path() -> Path: return Path( - cmd_opts.output_dir - if cmd_opts.output_dir + cmd_opts.output_dir + if cmd_opts.output_dir else get_resource_path("..\web\generated_imgs") -) + ) def get_generated_imgs_todays_subdir() -> str: return dt.now().strftime("%Y%m%d") -def get_checkpoints_path(model = ""): +def create_checkpoint_folders(): + dir = ["vae", "lora"] + if not cmd_opts.ckpt_dir: + dir.insert(0, "models") + else: + if not os.path.isdir(cmd_opts.ckpt_dir): + sys.exit( + f"Invalid --ckpt_dir argument, " + f"{args.ckpt_dir} folder does not exists." + ) + for root in dir: + Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True) + + +def get_checkpoints_path(model=""): return get_resource_path(f"..\web\models\{model}") -def get_checkpoints(path): - files = [] - for file in +def get_checkpoints(model="models"): + ckpt_files = [] + file_types = checkpoints_filetypes + if model == "lora": + file_types = file_types + ("*.pt", "*.bin") + for extn in file_types: + files = [ + os.path.basename(x) + for x in glob.glob(os.path.join(get_checkpoints_path(model), extn)) + ] + ckpt_files.extend(files) + return sorted(ckpt_files, key=str.casefold) + + +def get_checkpoint_pathfile(checkpoint_name, model="models"): + return os.path.join(get_checkpoints_path(model), checkpoint_name) def get_device_mapping(driver, key_combination=3): @@ -144,6 +245,30 @@ def get_output_value(dev_dict): return device_map +def get_opt_flags(model, precision="fp16"): + iree_flags = [] + if len(cmd_opts.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" + ) + if "rocm" in cmd_opts.device: + rocm_args = get_iree_rocm_args() + iree_flags.extend(rocm_args) + if cmd_opts.iree_constant_folding == False: + iree_flags.append("--iree-opt-const-expr-hoisting=False") + iree_flags.append( + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" + ) + if cmd_opts.data_tiling == False: + iree_flags.append("--iree-opt-data-tiling=False") + + if "vae" not in model: + # Due to lack of support for multi-reduce, we always collapse reduction + # dims before dispatch formation right now. 
+ iree_flags += ["--iree-flow-collapse-reduction-dims"] + return iree_flags + + def map_device_to_name_path(device, key_combination=3): """Gives the appropriate device data (supported name/path) for user selected execution device @@ -250,6 +375,7 @@ def parse_seed_input(seed_input: str | list | int): "Seed input must be an integer or an array of integers in JSON format" ) + # Generate and return a new seed if the provided one is not in the # supported range (including -1) def sanitize_seed(seed: int | str): @@ -260,6 +386,7 @@ def sanitize_seed(seed: int | str): seed = randint(uint32_min, uint32_max) return seed + # take a seed expression in an input format and convert it to # a list of integers, where possible def parse_seed_input(seed_input: str | list | int): diff --git a/apps/shark_studio/modules/checkpoint_proc.py b/apps/shark_studio/modules/checkpoint_proc.py new file mode 100644 index 0000000000..e924de4640 --- /dev/null +++ b/apps/shark_studio/modules/checkpoint_proc.py @@ -0,0 +1,66 @@ +import os +import json +import re +from pathlib import Path +from omegaconf import OmegaConf + + +def get_path_to_diffusers_checkpoint(custom_weights): + path = Path(custom_weights) + diffusers_path = path.parent.absolute() + diffusers_directory_name = os.path.join("diffusers", path.stem) + complete_path_to_diffusers = diffusers_path / diffusers_directory_name + complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) + path_to_diffusers = complete_path_to_diffusers.as_posix() + return path_to_diffusers + + +def preprocessCKPT(custom_weights, is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) + if next(Path(path_to_diffusers).iterdir(), None): + print("Checkpoint already loaded at : ", path_to_diffusers) + return + else: + print( + "Diffusers' checkpoint will be identified here : ", + path_to_diffusers, + ) + from_safetensors = ( + True if custom_weights.lower().endswith(".safetensors") else False + ) + # EMA weights usually yield higher quality images for inference but + # non-EMA weights have been yielding better results in our case. + # TODO: Add an option `--ema` (`--no-ema`) for users to specify if + # they want to go for EMA weight extraction or not. + extract_ema = False + print( + "Loading diffusers' pipeline from original stable diffusion checkpoint" + ) + num_in_channels = 9 if is_inpaint else 4 + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=custom_weights, + extract_ema=extract_ema, + from_safetensors=from_safetensors, + num_in_channels=num_in_channels, + ) + pipe.save_pretrained(path_to_diffusers) + print("Loading complete") + + +def convert_original_vae(vae_checkpoint): + vae_state_dict = {} + for key in list(vae_checkpoint.keys()): + vae_state_dict["first_stage_model." 
+ key] = vae_checkpoint.get(key)
+
+    config_url = (
+        "https://raw.githubusercontent.com/CompVis/stable-diffusion/"
+        "main/configs/stable-diffusion/v1-inference.yaml"
+    )
+    original_config_file = BytesIO(requests.get(config_url).content)
+    original_config = OmegaConf.load(original_config_file)
+    vae_config = create_vae_diffusers_config(original_config, image_size=512)
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        vae_state_dict, vae_config
+    )
+    return converted_vae_checkpoint
diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py
index 5fc64c0ccc..d8cf544f81 100644
--- a/apps/shark_studio/modules/embeddings.py
+++ b/apps/shark_studio/modules/embeddings.py
@@ -1,5 +1,10 @@
+import os
+import sys
 import torch
+import json
+import safetensors
 from safetensors.torch import load_file
+from apps.shark_studio.api.utils import get_checkpoint_pathfile
 
 
 def processLoRA(model, use_lora, splitting_prefix):
@@ -109,3 +114,58 @@ def update_lora_weight(model, use_lora, model_name):
         return processLoRA(model, use_lora, "lora_te_")
     except:
         return None
+
+
+def get_lora_metadata(lora_filename):
+    # get the metadata from the file
+    filename = get_checkpoint_pathfile(lora_filename, "lora")
+    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+
+    # guard clause for if there isn't any metadata
+    if not metadata:
+        return None
+
+    # metadata is a dictionary of strings; the values of the keys we're
+    # interested in are actually json, and need to be loaded as such
+    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
+    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
+    tag_dirs = [dir for dir in tag_frequencies.keys()]
+
+    # gather the tag frequency information for all the datasets trained
+    all_frequencies = {}
+    for dataset in tag_dirs:
+        frequencies = sorted(
+            [entry for entry in tag_frequencies[dataset].items()],
+            reverse=True,
+            key=lambda x: x[1],
+        )
+
+        # get a figure for the total number of images processed for this
+        # dataset: either the number listed in its dataset_dirs entry, or
+        # the highest tag frequency if that entry doesn't exist
+        img_count = dataset_dirs.get(dataset, {}).get(
+            "img_count", frequencies[0][1]
+        )
+
+        # add the dataset frequencies to the overall frequencies replacing the
+        # frequency counts on the tags with a percentage/ratio
+        all_frequencies.update(
+            [(entry[0], entry[1] / img_count) for entry in frequencies]
+        )
+
+    trained_model_id = " ".join(
+        [
+            metadata.get("ss_sd_model_hash", ""),
+            metadata.get("ss_sd_model_name", ""),
+            metadata.get("ss_base_model_version", ""),
+        ]
+    ).strip()
+
+    # return the topmost of all frequencies in all datasets
+    return {
+        "model": trained_model_id,
+        "frequencies": sorted(
+            all_frequencies.items(), reverse=True, key=lambda x: x[1]
+        ),
+    }
diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py
index e709facbbf..b5cf28ce47 100644
--- a/apps/shark_studio/modules/img_processing.py
+++ b/apps/shark_studio/modules/img_processing.py
@@ -1,4 +1,8 @@
-from
+import os
+import sys
+import re
+import json
+import csv
+from datetime import datetime as dt
+from PIL import Image, PngImagePlugin
+from pathlib import Path
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+
 
 # save output images and the inputs corresponding to it.
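 # e.g. (hypothetical time and seed) a run of the default prompt saves
 # 142355_a_photo_taken_o_1234.png in the generated images folder, plus a row
 # in imgs_details.csv and a .json sidecar when --save_metadata_to_json is set.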
def save_output_img(output_img, img_seed, extra_info=None): @@ -10,43 +14,45 @@ def save_output_img(output_img, img_seed, extra_info=None): generated_imgs_path.mkdir(parents=True, exist_ok=True) csv_path = Path(generated_imgs_path, "imgs_details.csv") - prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15]) + prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15]) out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}" - img_model = args.hf_model_id - if args.ckpt_loc: - img_model = Path(os.path.basename(args.ckpt_loc)).stem + img_model = cmd_opts.hf_model_id + if cmd_opts.ckpt_loc: + img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem img_vae = None - if args.custom_vae: - img_vae = Path(os.path.basename(args.custom_vae)).stem + if cmd_opts.custom_vae: + img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem img_lora = None - if args.use_lora: - img_lora = Path(os.path.basename(args.use_lora)).stem + if cmd_opts.use_lora: + img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem - if args.output_img_format == "jpg": + if cmd_opts.output_img_format == "jpg": out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") output_img.save(out_img_path, quality=95, subsampling=0) else: out_img_path = Path(generated_imgs_path, f"{out_img_name}.png") pngInfo = PngImagePlugin.PngInfo() - if args.write_metadata_to_png: + if cmd_opts.write_metadata_to_png: # Using a conditional expression caused problems, so setting a new # variable for now. - if args.use_hiresfix: - png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}" + if cmd_opts.use_hiresfix: + png_size_text = ( + f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}" + ) else: - png_size_text = f"{args.width}x{args.height}" + png_size_text = f"{cmd_opts.width}x{cmd_opts.height}" pngInfo.add_text( "parameters", - f"{args.prompts[0]}" - f"\nNegative prompt: {args.negative_prompts[0]}" - f"\nSteps: {args.steps}," - f"Sampler: {args.scheduler}, " - f"CFG scale: {args.guidance_scale}, " + f"{cmd_opts.prompts[0]}" + f"\nNegative prompt: {cmd_opts.negative_prompts[0]}" + f"\nSteps: {cmd_opts.steps}," + f"Sampler: {cmd_opts.scheduler}, " + f"CFG scale: {cmd_opts.guidance_scale}, " f"Seed: {img_seed}," f"Size: {png_size_text}, " f"Model: {img_model}, " @@ -56,9 +62,9 @@ def save_output_img(output_img, img_seed, extra_info=None): output_img.save(out_img_path, "PNG", pnginfo=pngInfo) - if args.output_img_format not in ["png", "jpg"]: + if cmd_opts.output_img_format not in ["png", "jpg"]: print( - f"[ERROR] Format {args.output_img_format} is not " + f"[ERROR] Format {cmd_opts.output_img_format} is not " f"supported yet. Image saved as png instead." f"Supported formats: png / jpg" ) @@ -68,18 +74,20 @@ def save_output_img(output_img, img_seed, extra_info=None): # importance for each data point. Something to consider. 
new_entry = { "VARIANT": img_model, - "SCHEDULER": args.scheduler, - "PROMPT": args.prompts[0], - "NEG_PROMPT": args.negative_prompts[0], + "SCHEDULER": cmd_opts.scheduler, + "PROMPT": cmd_opts.prompts[0], + "NEG_PROMPT": cmd_opts.negative_prompts[0], "SEED": img_seed, - "CFG_SCALE": args.guidance_scale, - "PRECISION": args.precision, - "STEPS": args.steps, - "HEIGHT": args.height - if not args.use_hiresfix - else args.hiresfix_height, - "WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width, - "MAX_LENGTH": args.max_length, + "CFG_SCALE": cmd_opts.guidance_scale, + "PRECISION": cmd_opts.precision, + "STEPS": cmd_opts.steps, + "HEIGHT": cmd_opts.height + if not cmd_opts.use_hiresfix + else cmd_opts.hiresfix_height, + "WIDTH": cmd_opts.width + if not cmd_opts.use_hiresfix + else cmd_opts.hiresfix_width, + "MAX_LENGTH": cmd_opts.max_length, "OUTPUT": out_img_path, "VAE": img_vae, "LORA": img_lora, @@ -95,37 +103,23 @@ def save_output_img(output_img, img_seed, extra_info=None): dictwriter_obj.writerow(new_entry) csv_obj.close() - if args.save_metadata_to_json: + if cmd_opts.save_metadata_to_json: del new_entry["OUTPUT"] json_path = Path(generated_imgs_path, f"{out_img_name}.json") with open(json_path, "w") as f: json.dump(new_entry, f, indent=4) -def get_generation_text_info(seeds, device): - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, " f"device={device}" - text_output += ( - f"\nsteps={args.steps}, " - f"guidance_scale={args.guidance_scale}, " - f"seed={seeds}" - ) - text_output += ( - f"\nsize={args.height}x{args.width}, " - if not args.use_hiresfix - else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, " - ) - text_output += ( - f"batch_count={args.batch_count}, " - f"batch_size={args.batch_size}, " - f"max_length={args.max_length}" - ) +resamplers = { + "Lanczos": Image.Resampling.LANCZOS, + "Nearest Neighbor": Image.Resampling.NEAREST, + "Bilinear": Image.Resampling.BILINEAR, + "Bicubic": Image.Resampling.BICUBIC, + "Hamming": Image.Resampling.HAMMING, + "Box": Image.Resampling.BOX, +} - return text_output +resampler_list = resamplers.keys() # For stencil, the input image can be of any size, but we need to ensure that @@ -133,7 +127,7 @@ def get_generation_text_info(seeds, device): # Both width and height should be in the range of [128, 768] and multiple of 8. # This utility function performs the transformation on the input image while # also maintaining the aspect ratio before sending it to the stencil pipeline. 
-def resize_stencil(image: Image.Image, width, height):
+def resize_stencil(image: Image.Image, width, height, resampler_type=None):
     aspect_ratio = width / height
     min_size = min(width, height)
     if min_size < 128:
@@ -166,6 +160,9 @@
     n_height = height // 8
     n_width *= 8
     n_height *= 8
-    new_image = image.resize((n_width, n_height))
+    if resampler_type in resamplers:
+        resampler = resamplers[resampler_type]
+    else:
+        resampler = resamplers["Nearest Neighbor"]
+    new_image = image.resize((n_width, n_height), resample=resampler)
     return new_image, n_width, n_height
-
diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py
new file mode 100644
index 0000000000..c087175de4
--- /dev/null
+++ b/apps/shark_studio/modules/pipeline.py
@@ -0,0 +1,71 @@
+import gc
+
+from shark.iree_utils.compile_utils import get_iree_compiled_module
+from apps.shark_studio.api.utils import get_resource_path
+
+
+class SharkPipelineBase:
+    # This class is a lightweight base for managing an
+    # inference API class. It should provide methods for:
+    # - compiling a set (model map) of torch IR modules
+    # - preparing weights for an inference job
+    # - loading weights for an inference job
+    # - utilities like benchmarks, tests
+
+    def __init__(
+        self,
+        model_map: dict,
+        device: str,
+        import_mlir: bool = True,
+    ):
+        self.model_map = model_map
+        self.device = device
+        self.import_mlir = import_mlir
+        self.iree_module_dict = {}
+
+    def import_torch_ir(self, base_model_id):
+        for submodel in self.model_map:
+            meta = self.model_map[submodel]
+            hf_id = (
+                meta["custom_hf_id"]
+                if meta.get("custom_hf_id")
+                else base_model_id
+            )
+            torch_ir = meta["initializer"](
+                hf_id, **meta["init_kwargs"], compile_to="torch"
+            )
+            meta["tempfile_name"] = get_resource_path(
+                f"{submodel}.torch.tempfile"
+            )
+            with open(meta["tempfile_name"], "w+") as f:
+                f.write(torch_ir)
+            del torch_ir
+            gc.collect()
+
+    def load_vmfb(self, submodel):
+        if self.iree_module_dict.get(submodel):
+            print(
+                f".vmfb for {submodel} found at "
+                f"{self.iree_module_dict[submodel]['vmfb']}"
+            )
+        elif self.model_map[submodel].get("tempfile_name"):
+            # TODO: compile the torch IR tempfile to a vmfb and cache it.
+            self.model_map[submodel]["tempfile_name"]
+
+        return self.iree_module_dict[submodel]["vmfb"]
+
+    def merge_custom_map(self, custom_model_map):
+        for submodel in custom_model_map:
+            for key in custom_model_map[submodel]:
+                self.model_map[submodel][key] = custom_model_map[submodel][key]
+        print(self.model_map)
+
+    def get_compiled_map(self, device) -> None:
+        # Each entry of iree_module_dict comes with keys: "vmfb", "config",
+        # and "temp_file_to_unlink".
+        for submodel in self.model_map:
+            if not self.iree_module_dict.get(submodel, {}).get("vmfb"):
+                self.iree_module_dict[submodel] = get_iree_compiled_module(
+                    self.model_map[submodel]["tempfile_name"],
+                    device=self.device,
+                    frontend="torch",
+                )
+            # TODO: delete the temp file
+
+    def run(self, submodel, inputs):
+        return
+
+    @staticmethod
+    def safe_name(name):
+        return name.replace("/", "_").replace("-", "_")
diff --git a/apps/shark_studio/api/schedulers.py b/apps/shark_studio/modules/schedulers.py
similarity index 100%
rename from apps/shark_studio/api/schedulers.py
rename to apps/shark_studio/modules/schedulers.py
diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py
index 88434ff580..dfb166a52e 100644
--- a/apps/shark_studio/modules/shared_cmd_opts.py
+++ b/apps/shark_studio/modules/shared_cmd_opts.py
@@ -2,7 +2,7 @@
 import os
 from pathlib import Path
 
-from apps.stable_diffusion.src.utils.resamplers import resampler_list
+from apps.shark_studio.modules.img_processing import resampler_list
 
 
 def path_expand(s):
@@ -36,7 +36,7 @@ def is_valid_file(arg):
     nargs="+",
     default=[
         "a photo taken of the front of a super-car drifting on a road near "
-        "mountains at high speeds with smokes coming off the tires, front "
+        "mountains at high speeds with smoke coming off the tires, front "
        "angle, front point of view, trees in the mountains of the "
        "background, ((sharp focus))"
    ],
@@ -306,21 +306,6 @@ def is_valid_file(arg):
     "downloads the model from shark_tank.",
 )
 
-p.add_argument(
-    "--load_vmfb",
-    default=True,
-    action=argparse.BooleanOptionalAction,
-    help="Attempts to load the model from a precompiled flat-buffer "
-    "and compiles + saves it if not found.",
-)
-
-p.add_argument(
-    "--save_vmfb",
-    default=False,
-    action=argparse.BooleanOptionalAction,
-    help="Saves the compiled flat-buffer to the local directory.",
-)
-
 p.add_argument(
     "--use_tuned",
     default=False,
@@ -446,7 +431,7 @@ def is_valid_file(arg):
 )
 
 p.add_argument(
-    "--ondemand",
+    "--lowvram",
     default=False,
     action=argparse.BooleanOptionalAction,
     help="Load and unload models for low VRAM.",
 )
@@ -469,10 +454,10 @@ def is_valid_file(arg):
 )
 
 p.add_argument(
-    "--autogen",
-    type=bool,
-    default="False",
-    help="Only used for a gradio workaround.",
+    "--custom_model_map",
+    type=str,
+    default="",
+    help="Path to a custom model map to import.
This should be a .json file", ) ############################################################################## # IREE - Vulkan supported flags @@ -612,6 +597,13 @@ def is_valid_file(arg): # Web UI flags ############################################################################## +p.add_argument( + "--webui", + default=True, + action=argparse.BooleanOptionalAction, + help="controls whether the webui is launched.", +) + p.add_argument( "--progress_bar", default=True, @@ -764,8 +756,8 @@ def is_valid_file(arg): "or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name", ) -args, unknown = p.parse_known_args() -if args.import_debug: +cmd_opts, unknown = p.parse_known_args() +if cmd_opts.import_debug: os.environ["IREE_SAVE_TEMPS"] = os.path.join( - os.getcwd(), args.hf_model_id.replace("/", "_") + os.getcwd(), cmd_opts.hf_model_id.replace("/", "_") ) diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py index c5fafd7ad2..80399505c4 100644 --- a/apps/shark_studio/web/api/compat.py +++ b/apps/shark_studio/web/api/compat.py @@ -15,7 +15,7 @@ from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder -from apps.shark_studio. import sd_samplers, postprocessing, errors, restart +from apps.shark_studio.modules.img_processing import sampler_list from sdapi_v1 import shark_sd_api from api.llm import chat_api @@ -26,15 +26,21 @@ def decode_base64_to_image(encoding): raise HTTPException(status_code=500, detail="Requests not allowed") if opts.api_forbid_local_requests and not verify_url(encoding): - raise HTTPException(status_code=500, detail="Request to local resource not allowed") + raise HTTPException( + status_code=500, detail="Request to local resource not allowed" + ) - headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + headers = ( + {"user-agent": opts.api_useragent} if opts.api_useragent else {} + ) response = requests.get(encoding, timeout=30, headers=headers) try: image = Image.open(BytesIO(response.content)) return image except Exception as e: - raise HTTPException(status_code=500, detail="Invalid image url") from e + raise HTTPException( + status_code=500, detail="Invalid image url" + ) from e if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] @@ -42,32 +48,54 @@ def decode_base64_to_image(encoding): image = Image.open(BytesIO(base64.b64decode(encoding))) return image except Exception as e: - raise HTTPException(status_code=500, detail="Invalid encoded image") from e + raise HTTPException( + status_code=500, detail="Invalid encoded image" + ) from e def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - - if opts.samples_format.lower() == 'png': + if opts.samples_format.lower() == "png": use_metadata = False metadata = PngImagePlugin.PngInfo() for key, value in image.info.items(): if isinstance(key, str) and isinstance(value, str): metadata.add_text(key, value) use_metadata = True - image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + image.save( + output_bytes, + format="PNG", + pnginfo=(metadata if use_metadata else None), + quality=opts.jpeg_quality, + ) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): if image.mode == "RGBA": image = image.convert("RGB") - parameters = image.info.get('parameters', None) - exif_bytes = piexif.dump({ - "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } - }) + parameters 
= image.info.get("parameters", None) + exif_bytes = piexif.dump( + { + "Exif": { + piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump( + parameters or "", encoding="unicode" + ) + } + } + ) if opts.samples_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + image.save( + output_bytes, + format="JPEG", + exif=exif_bytes, + quality=opts.jpeg_quality, + ) else: - image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + image.save( + output_bytes, + format="WEBP", + exif=exif_bytes, + quality=opts.jpeg_quality, + ) else: raise HTTPException(status_code=500, detail="Invalid image format") @@ -80,10 +108,11 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): rich_available = False try: - if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None: import anyio # importing just so it can be placed on silent list import starlette # importing just so it can be placed on silent list from rich.console import Console + console = Console() rich_available = True except Exception: @@ -95,35 +124,49 @@ async def log_and_time(req: Request, call_next): res: Response = await call_next(req) duration = str(round(time.time() - ts, 4)) res.headers["X-Process-Time"] = duration - endpoint = req.scope.get('path', 'err') - if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): - print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code=res.status_code, - ver=req.scope.get('http_version', '0.0'), - cli=req.scope.get('client', ('0:0.0.0', 0))[0], - prot=req.scope.get('scheme', 'err'), - method=req.scope.get('method', 'err'), - endpoint=endpoint, - duration=duration, - )) + endpoint = req.scope.get("path", "err") + if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"): + print( + "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format( + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get("http_version", "0.0"), + cli=req.scope.get("client", ("0:0.0.0", 0))[0], + prot=req.scope.get("scheme", "err"), + method=req.scope.get("method", "err"), + endpoint=endpoint, + duration=duration, + ) + ) return res def handle_exception(request: Request, e: Exception): err = { "error": type(e).__name__, - "detail": vars(e).get('detail', ''), - "body": vars(e).get('body', ''), + "detail": vars(e).get("detail", ""), + "body": vars(e).get("body", ""), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance( + e, HTTPException + ): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) - console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) + console.print_exception( + show_locals=True, + max_frames=2, + extra_lines=1, + suppress=[anyio, starlette], + word_wrap=False, + width=min([console.width, 200]), + ) else: errors.report(message, exc_info=True) - return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) + return JSONResponse( + status_code=vars(e).get("status_code", 500), + content=jsonable_encoder(err), + ) @app.middleware("http") async def exception_handling(request: Request, 
call_next): @@ -143,52 +186,48 @@ async def http_exception_handler(request: Request, e: HTTPException): class ApiCompat: def __init__(self, queue_lock: Lock): - self.router = APIRouter() self.app = FastAPI() self.queue_lock = queue_lock api_middleware(self.app) self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"]) self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"]) - #self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"]) - #self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) - #self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) - #self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) - #self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) - #self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) - #self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) - #self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - #self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) - #self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) - #self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) - #self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) - #self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) - #self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) - #self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) - #self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) - #self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) - #self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) - #self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) - #self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) - #self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) - #self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) - #self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) - #self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) - #self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], 
response_model=models.TrainResponse) - #self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) - #self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) - #self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) - #self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) - #self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - + # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"]) + # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) + # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) + # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) + # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) + # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) + # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) + # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) + # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) + # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) + # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) + # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) + # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) + # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) + # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) + # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) + # 
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) + # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) + # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) + # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) # chat APIs needed for compatibility with multiple extensions using OpenAI API - self.add_api_route( - "/v1/chat/completions", chat_api, methods=["post"] - ) + self.add_api_route("/v1/chat/completions", chat_api, methods=["post"]) self.add_api_route("/v1/completions", chat_api, methods=["post"]) self.add_api_route("/chat/completions", chat_api, methods=["post"]) self.add_api_route("/completions", chat_api, methods=["post"]) @@ -196,16 +235,26 @@ def __init__(self, queue_lock: Lock): "/v1/engines/codegen/completions", chat_api, methods=["post"] ) if studio.cmd_opts.api_server_stop: - self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"]) + self.add_api_route( + "/sdapi/v1/server-kill", self.kill_studio, methods=["POST"] + ) + self.add_api_route( + "/sdapi/v1/server-restart", + self.restart_studio, + methods=["POST"], + ) + self.add_api_route( + "/sdapi/v1/server-stop", self.stop_studio, methods=["POST"] + ) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - def add_api_route(self, path:str, endpoint, **kwargs): + def add_api_route(self, path: str, endpoint, **kwargs): if studio.cmd_opts.api_auth: - return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs + return self.app.add_api_route( + path, endpoint, dependencies=[Depends(self.auth)], **kwargs + ) return self.app.add_api_route(path, endpoint, **kwargs) def refresh_checkpoints(self): @@ -231,7 +280,13 @@ def skip(self): def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path) + uvicorn.run( + self.app, + host=server_name, + port=port, + timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, + root_path=root_path, + ) def kill_studio(self): restart.stop_program() @@ -246,7 +301,7 @@ def preprocess(self, args: dict): studio.state.begin(job="preprocess") preprocess(**args) studio.state.end() - return models.PreprocessResponse(info='preprocess complete') + return models.PreprocessResponse(info="preprocess complete") except: studio.state.end() diff --git a/apps/shark_studio/web/configs/foo.json b/apps/shark_studio/web/configs/foo.json new file mode 100644 index 0000000000..0967ef424b --- /dev/null +++ 
b/apps/shark_studio/web/configs/foo.json
@@ -0,0 +1 @@
+{}
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
index d678d0b647..6ff90b4dbc 100644
--- a/apps/shark_studio/web/index.py
+++ b/apps/shark_studio/web/index.py
@@ -3,12 +3,13 @@
 import time
 import sys
 import logging
+import apps.shark_studio.api.initializers as initialize
 
 from ui.chat import chat_element
 from ui.sd import sd_element
 from ui.outputgallery import outputgallery_element
 
-from modules import timer, initialize
+from apps.shark_studio.modules import timer
 
 startup_timer = timer.startup_timer
 startup_timer.record("launcher")
@@ -72,15 +73,13 @@ def launch_webui(address):
 
 
 def webui():
-    from apps.shark_studio.shared_cmd_options import cmd_opts
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 
     logging.basicConfig(level=logging.DEBUG)
 
     launch_api = cmd_opts.api
     initialize.initialize()
 
-    from modules import shared, ui_tempdir, script_callbacks, ui, progress
-
     # required to do multiprocessing in a pyinstaller freeze
     freeze_support()
@@ -131,16 +130,23 @@ def webui():
     # Setup to use shark_tmp for gradio's temporary image files and clear any
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.
-    from apps.shark_studio.web.initializers import (
-        config_gradio_tmp_imgs_folder,
-        create_custom_models_folders,
+    from apps.shark_studio.web.utils.tmp_configs import (
+        config_tmp,
+        clear_tmp_mlir,
+        clear_tmp_imgs,
+    )
+    from apps.shark_studio.api.utils import (
+        create_checkpoint_folders,
     )
-    config_gradio_tmp_imgs_folder()
 
     import gradio as gr
 
+    config_tmp()
+    clear_tmp_mlir()
+    clear_tmp_imgs()
+
     # Create custom models folders if they don't exist
-    create_custom_models_folders()
+    create_checkpoint_folders()
 
     def resource_path(relative_path):
         """Get absolute path to resource, works for dev and for PyInstaller"""
@@ -151,10 +157,7 @@ def resource_path(relative_path):
 
     dark_theme = resource_path("ui/css/sd_dark_theme.css")
 
-    from apps.shark_studio.web.ui import load_ui_from_script
-
-    # init global sd pipeline and config
-    studio.state._init()
+    # from apps.shark_studio.web.ui import load_ui_from_script
 
     def register_button_click(button, selectedid, inputs, outputs):
         button.click(
@@ -211,9 +214,9 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
 
 if __name__ == "__main__":
-    from apps.shark_studio.shared_cmd_options import cmd_opts
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 
-    if cmd_opts.nowebui:
+    if not cmd_opts.webui:
         api_only()
     else:
         webui()
diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py
new file mode 100644
index 0000000000..37555ed7ee
--- /dev/null
+++ b/apps/shark_studio/web/ui/common_events.py
@@ -0,0 +1,55 @@
+from apps.shark_studio.web.ui.utils import (
+    HSLHue,
+    hsl_color,
+)
+from apps.shark_studio.modules.embeddings import get_lora_metadata
+
+
+# Returns HTML showing the most frequent tags used when a LoRA was trained,
+# taken from the metadata of its .safetensors file.
+def lora_changed(lora_file):
+    # tag frequency percentage that gets the maximum amount of the starting hue
+    TAG_COLOR_THRESHOLD = 0.55
+    # tag frequency percentage above which a tag is displayed
+    TAG_DISPLAY_THRESHOLD = 0.65
+    # template for the html used to display a tag
+    TAG_HTML_TEMPLATE = '<span class="lora-tag" style="background-color: {color}">{tag}</span>'
+
+    if lora_file == "None":
+        return ["<p>No LoRA selected</p>"]
+    elif not lora_file.lower().endswith(".safetensors"):
+        return [
+            "<p>Only metadata queries for .safetensors files are currently supported</p>"
+        ]
+    else:
+        metadata = get_lora_metadata(lora_file)
+        if metadata:
+            frequencies = metadata["frequencies"]
+            return [
+                "".join(
+                    [
+                        f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
+                    ]
+                    + [
+                        TAG_HTML_TEMPLATE.format(
+                            color=hsl_color(
+                                (tag[1] - TAG_COLOR_THRESHOLD)
+                                / (1 - TAG_COLOR_THRESHOLD),
+                                start=HSLHue.RED,
+                                end=HSLHue.GREEN,
+                            ),
+                            tag=tag[0],
+                        )
+                        for tag in frequencies
+                        if tag[1] > TAG_DISPLAY_THRESHOLD
+                    ],
+                )
+            ]
+        elif metadata is None:
+            return [
+                "<p>This LoRA does not publish tag frequency metadata</p>"
+            ]
+        else:
+            return [
+                "<p>This LoRA has empty tag frequency metadata, or we could not parse it</p>
" + ] diff --git a/apps/shark_studio/web/ui/css/sd_dark_theme.css b/apps/shark_studio/web/ui/css/sd_dark_theme.css new file mode 100644 index 0000000000..5686f0868c --- /dev/null +++ b/apps/shark_studio/web/ui/css/sd_dark_theme.css @@ -0,0 +1,324 @@ +/* +Apply Gradio dark theme to the default Gradio theme. +Procedure to upgrade the dark theme: +- Using your browser, visit http://localhost:8080/?__theme=dark +- Open your browser inspector, search for the .dark css class +- Copy .dark class declarations, apply them here into :root +*/ + +:root { + --body-background-fill: var(--background-fill-primary); + --body-text-color: var(--neutral-100); + --color-accent-soft: var(--neutral-700); + --background-fill-primary: var(--neutral-950); + --background-fill-secondary: var(--neutral-900); + --border-color-accent: var(--neutral-600); + --border-color-primary: var(--neutral-700); + --link-text-color-active: var(--secondary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --body-text-color-subdued: var(--neutral-400); + --shadow-spread: 1px; + --block-background-fill: var(--neutral-800); + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --block-title-text-color: var(--neutral-200); + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--neutral-800); + --checkbox-background-color-focus: var(--checkbox-background-color); + --checkbox-background-color-hover: var(--checkbox-background-color); + --checkbox-background-color-selected: var(--secondary-600); + --checkbox-border-color: var(--neutral-700); + --checkbox-border-color-focus: var(--secondary-500); + --checkbox-border-color-hover: var(--neutral-600); + --checkbox-border-color-selected: var(--secondary-600); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error_border_width: None; + --error-text-color: #ef4444; + --input-background-fill: var(--neutral-800); + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--border-color-primary); + --input-border-color-focus: var(--neutral-700); + --input-border-color-hover: var(--input-border-color); + --input_border_width: None; + 
--input-placeholder-color: var(--neutral-500); + --input_shadow: None; + --input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset); + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--neutral-950); + --table-odd-background-fill: var(--neutral-900); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600)); + --button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500)); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700)); + --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600)); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color: white; + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --block-border-width: 1px; + --block-label-border-width: 1px; + --form-gap-width: 1px; + --error-border-width: 1px; + --input-border-width: 1px; +} + +/* SHARK theme */ +body { + background-color: var(--background-fill-primary); +} + +.generating.svelte-zlszon.svelte-zlszon { + border: none; +} + +.generating { + border: none !important; +} + +#chatbot { + height: 100% !important; +} + +/* display in full width for desktop devices */ +@media (min-width: 1536px) +{ + .gradio-container { + max-width: var(--size-full) !important; + } +} + +.gradio-container .contain { + padding: 0 var(--size-4) !important; +} + +#top_logo { + color: transparent; + background-color: transparent; + border-radius: 0 !important; + border: 0; +} + +#ui_title { + padding: var(--size-2) 0 0 var(--size-1); +} + +#demo_title_outer { + border-radius: 0; +} + +#prompt_box_outer div:first-child { + border-radius: 0 !important +} + +#prompt_box textarea, #negative_prompt_box textarea { + background-color: var(--background-fill-primary) !important; +} + +#prompt_examples { + margin: 0 !important; +} + +#prompt_examples svg { + display: none !important; +} + +#ui_body { + padding: var(--size-2) !important; + border-radius: 0.5em !important; +} + +#img_result+div { + display: none !important; +} + +footer { + display: none !important; +} + +#gallery + div { + border-radius: 0 !important; +} + +/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */ +#gallery .thumbnail-item.thumbnail-lg { + aspect-ratio: unset; + max-height: calc(55vh - (2 * var(--spacing-lg))); +} +@media (min-width: 1921px) { + /* Force a 768px_height + 
4px_margin_height + navbar_height for the gallery */
+  #gallery .grid-wrap, #gallery .preview{
+    min-height: calc(768px + 4px + var(--size-14));
+    max-height: calc(768px + 4px + var(--size-14));
+  }
+  /* Limit height to 768px_height + 2px_margin_height for the thumbnails */
+  #gallery .thumbnail-item.thumbnail-lg {
+    max-height: 770px !important;
+  }
+}
+/* Don't upscale when viewing in solo image mode */
+#gallery .preview img {
+  object-fit: scale-down;
+}
+/* Navbar images in cover mode */
+#gallery .preview .thumbnail-item img {
+  object-fit: cover;
+}
+
+/* Limit the stable diffusion text output height */
+#std_output textarea {
+  max-height: 215px;
+}
+
+/* Prevent the progress bar from blocking gallery navigation while images are building (Gradio V3.19.0) */
+#gallery .wrap.default {
+  pointer-events: none;
+}
+
+/* Import PNG info box */
+#txt2img_prompt_image {
+  height: var(--size-32) !important;
+}
+
+/* Hide "remove buttons" from ui dropdowns */
+#custom_model .token-remove.remove-all,
+#lora_weights .token-remove.remove-all,
+#scheduler .token-remove.remove-all,
+#device .token-remove.remove-all,
+#stencil_model .token-remove.remove-all {
+  display: none;
+}
+
+/* Hide selected items from ui dropdowns */
+#custom_model .options .item .inner-item,
+#scheduler .options .item .inner-item,
+#device .options .item .inner-item,
+#stencil_model .options .item .inner-item {
+  display:none;
+}
+
+/* workarounds for container=false not currently working for dropdowns */
+.dropdown_no_container {
+  padding: 0 !important;
+}
+
+#output_subdir_container :first-child {
+  border: none;
+}
+
+/* reduced animation load when generating */
+.generating {
+  animation-play-state: paused !important;
+}
+
+/* better clarity when progress bars are minimal */
+.meta-text {
+  background-color: var(--block-label-background-fill);
+}
+
+/* lora tag pills */
+.lora-tags {
+  border: 1px solid var(--border-color-primary);
+  color: var(--block-info-text-color) !important;
+  padding: var(--block-padding);
+}
+
+.lora-tag {
+  display: inline-block;
+  height: 2em;
+  color: rgb(212 212 212) !important;
+  margin-right: 5pt;
+  margin-bottom: 5pt;
+  padding: 2pt 5pt;
+  border-radius: 5pt;
+  white-space: nowrap;
+}
+
+.lora-model {
+  margin-bottom: var(--spacing-lg);
+  color: var(--block-info-text-color) !important;
+  line-height: var(--line-sm);
+}
+
+/* output gallery tab */
+.output_parameters_dataframe table.table {
+  /* works around a gradio bug that always shows scrollbars */
+  overflow: clip auto;
+}
+
+.output_parameters_dataframe tbody td {
+  font-size: small;
+  line-height: var(--line-xs);
+}
+
+.output_icon_button {
+  max-width: 30px;
+  align-self: end;
+  padding-bottom: 8px;
+}
+
+.outputgallery_sendto {
+  min-width: 7em !important;
+}
+
+/* output gallery should take up most of the viewport height regardless of image size/number */
+#outputgallery_gallery .fixed-height {
+  min-height: 89vh !important;
+}
+
+/* don't stretch non-square images to be square, breaking their aspect ratio */
+#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
+  object-fit: contain !important;
+}
+
+/* centered logo for when there are no images */
+#top_logo.logo_centered {
+  height: 100%;
+  width: 100%;
+}
+
+#top_logo.logo_centered img {
+  object-fit: scale-down;
+  position: absolute;
+  width: 80%;
+  top: 50%;
+  left: 50%;
+  transform: translate(-50%, -50%);
+} diff --git a/apps/shark_studio/web/ui/logos/nod-icon.png b/apps/shark_studio/web/ui/logos/nod-icon.png new file mode 100644 index 0000000000..29f7e32220 Binary files
/dev/null and b/apps/shark_studio/web/ui/logos/nod-icon.png differ diff --git a/apps/shark_studio/web/ui/logos/nod-logo.png b/apps/shark_studio/web/ui/logos/nod-logo.png new file mode 100644 index 0000000000..4727e15a19 Binary files /dev/null and b/apps/shark_studio/web/ui/logos/nod-logo.png differ diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py new file mode 100644 index 0000000000..dd58541aae --- /dev/null +++ b/apps/shark_studio/web/ui/outputgallery.py @@ -0,0 +1,416 @@
+import glob
+import gradio as gr
+import os
+import subprocess
+import sys
+from PIL import Image
+
+from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.shark_studio.api.utils import (
+    get_generated_imgs_path,
+    get_generated_imgs_todays_subdir,
+)
+from apps.shark_studio.web.ui.utils import nodlogo_loc
+from apps.shark_studio.web.utils.metadata import displayable_metadata
+
+# -- Functions for file, directory and image info querying
+
+output_dir = get_generated_imgs_path()
+
+
+def outputgallery_filenames(subdir) -> list[str]:
+    new_dir_path = os.path.join(output_dir, subdir)
+    if os.path.exists(new_dir_path):
+        filenames = [
+            glob.glob(new_dir_path + "/" + ext)
+            for ext in ("*.png", "*.jpg", "*.jpeg")
+        ]
+
+        return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
+    else:
+        return []
+
+
+def output_subdirs() -> list[str]:
+    # Gets a list of subdirectories of output_dir and below, as relative paths.
+    relative_paths = [
+        os.path.relpath(entry[0], output_dir)
+        for entry in os.walk(
+            output_dir, followlinks=cmd_opts.output_gallery_followlinks
+        )
+    ]
+
+    # It is less confusing to always include the subdir that will take any
+    # images generated today, even if it doesn't exist yet
+    if get_generated_imgs_todays_subdir() not in relative_paths:
+        relative_paths.append(get_generated_imgs_todays_subdir())
+
+    # Sort subdirectories so that the date-named ones we probably
+    # created in this or previous sessions come first, with the most
+    # recent first. Other subdirs are listed after.
+    generated_paths = sorted(
+        [path for path in relative_paths if path.isnumeric()], reverse=True
+    )
+    result_paths = generated_paths + sorted(
+        [
+            path
+            for path in relative_paths
+            if (not path.isnumeric()) and path != "."
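+            # the output root itself (".") is excluded; remaining non-date subdirs sort alphabetically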
+ ] + ) + + return result_paths + + +# --- Define UI layout for Gradio + +with gr.Blocks() as outputgallery_element: + nod_logo = Image.open(nodlogo_loc) + + with gr.Row(elem_id="outputgallery_gallery"): + # needed to workaround gradio issue: + # https://github.com/gradio-app/gradio/issues/2907 + dev_null = gr.Textbox("", visible=False) + + gallery_files = gr.State(value=[]) + subdirectory_paths = gr.State(value=[]) + + with gr.Column(scale=6): + logo = gr.Image( + label="Getting subdirectories...", + value=nod_logo, + interactive=False, + visible=True, + show_label=True, + elem_id="top_logo", + elem_classes="logo_centered", + show_download_button=False, + ) + + gallery = gr.Gallery( + label="", + value=gallery_files.value, + visible=False, + show_label=True, + columns=4, + ) + + with gr.Column(scale=4): + with gr.Group(): + with gr.Row(): + with gr.Column( + scale=15, + min_width=160, + elem_id="output_subdir_container", + ): + subdirectories = gr.Dropdown( + label=f"Subdirectories of {output_dir}", + type="value", + choices=subdirectory_paths.value, + value="", + interactive=True, + elem_classes="dropdown_no_container", + allow_custom_value=True, + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + open_subdir = gr.Button( + variant="secondary", + value="\U0001F5C1", # unicode open folder + interactive=False, + size="sm", + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + refresh = gr.Button( + variant="secondary", + value="\u21BB", # unicode clockwise arrow circle + size="sm", + ) + + image_columns = gr.Slider( + label="Columns shown", value=4, minimum=1, maximum=16, step=1 + ) + outputgallery_filename = gr.Textbox( + label="Filename", + value="None", + interactive=False, + show_copy_button=True, + ) + + with gr.Accordion( + label="Parameter Information", open=False + ) as parameters_accordian: + image_parameters = gr.DataFrame( + headers=["Parameter", "Value"], + col_count=2, + wrap=True, + elem_classes="output_parameters_dataframe", + value=[["Status", "No image selected"]], + interactive=True, + ) + + with gr.Accordion(label="Send To", open=True): + with gr.Row(): + outputgallery_sendto_sd = gr.Button( + value="Stable Diffusion", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) + + # --- Event handlers + + def on_clear_gallery(): + return [ + gr.Gallery( + value=[], + visible=False, + ), + gr.Image( + visible=True, + ), + ] + + def on_image_columns_change(columns): + return gr.Gallery(columns=columns) + + def on_select_subdir(subdir) -> list: + # evt.value is the subdirectory name + new_images = outputgallery_filenames(subdir) + new_label = ( + f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" + ) + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_open_subdir(subdir): + subdir_path = os.path.normpath(os.path.join(output_dir, subdir)) + + if os.path.isdir(subdir_path): + if sys.platform == "linux": + subprocess.run(["xdg-open", subdir_path]) + elif sys.platform == "darwin": + subprocess.run(["open", subdir_path]) + elif sys.platform == "win32": + os.startfile(subdir_path) + + def on_refresh(current_subdir: str) -> list: + # get an up-to-date subdirectory list + refreshed_subdirs = output_subdirs() + # get the images using either the current subdirectory or the most + # recent valid one + new_subdir = ( + 
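+        # keep the user's current selection if it still exists, else fall back to the newest subdir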
current_subdir
+        if current_subdir in refreshed_subdirs
+        else refreshed_subdirs[0]
+    )
+    new_images = outputgallery_filenames(new_subdir)
+    new_label = (
+        f"{len(new_images)} images in "
+        f"{os.path.join(output_dir, new_subdir)}"
+    )
+
+    return [
+        gr.Dropdown(
+            choices=refreshed_subdirs,
+            value=new_subdir,
+        ),
+        refreshed_subdirs,
+        new_images,
+        gr.Gallery(
+            value=new_images, label=new_label, visible=len(new_images) > 0
+        ),
+        gr.Image(
+            label=new_label,
+            visible=len(new_images) == 0,
+        ),
+    ]
+
+    def on_new_image(subdir, subdir_paths, status) -> list:
+        # prevent the error triggered when an image generates before the tab
+        # has even been selected
+        subdir_paths = (
+            subdir_paths
+            if len(subdir_paths) > 0
+            else [get_generated_imgs_todays_subdir()]
+        )
+
+        # only update if the current subdir is the most recent one, as
+        # new images only go there
+        if subdir_paths[0] == subdir:
+            new_images = outputgallery_filenames(subdir)
+            new_label = (
+                f"{len(new_images)} images in "
+                f"{os.path.join(output_dir, subdir)} - {status}"
+            )
+
+            return [
+                new_images,
+                gr.Gallery(
+                    value=new_images,
+                    label=new_label,
+                    visible=len(new_images) > 0,
+                ),
+                gr.Image(
+                    label=new_label,
+                    visible=len(new_images) == 0,
+                ),
+            ]
+        else:
+            # otherwise change nothing,
+            # (only untyped gradio gr.update() does this)
+            return [gr.update(), gr.update(), gr.update()]
+
+    def on_select_image(images: list[str], evt: gr.SelectData) -> list:
+        # evt.index is an index into the full list of filenames for
+        # the current subdirectory
+        filename = images[evt.index]
+        params = displayable_metadata(filename)
+
+        if params:
+            if params["source"] == "missing":
+                return [
+                    "Could not find this image file, refresh the gallery and update the images",
+                    [["Status", "File missing"]],
+                ]
+            else:
+                return [
+                    filename,
+                    list(map(list, params["parameters"].items())),
+                ]
+
+        return [
+            filename,
+            [["Status", "No parameters found"]],
+        ]
+
+    def on_outputgallery_filename_change(filename: str) -> list:
+        exists = filename != "None" and os.path.exists(filename)
+        return [
+            # disable or enable each of the sendto buttons based on whether
+            # an image is selected
+            gr.Button(interactive=exists),
+        ]
+
+    # The first time our tab is selected we need to do an initial refresh
+    # to populate the subdirectory select box and the images from the most
+    # recent subdirectory.
+    #
+    # We do it at this point rather than setting this up in the controls'
+    # definitions as when you refresh the browser you always get what was
+    # *initially* set, which won't include any new subdirectories or images
+    # that might have been created since the application was started. Doing
+    # it this way means a browser refresh/reload always gets the most
+    # up-to-date data.
+    def on_select_tab(subdir_paths, request: gr.Request):
+        local_client = request.headers["host"].startswith(
+            "127.0.0.1:"
+        ) or request.headers["host"].startswith("localhost:")
+
+        if len(subdir_paths) == 0:
+            return on_refresh("") + [gr.update(interactive=local_client)]
+        else:
+            return (
+                # Change nothing, (only untyped gr.update() does this)
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+
+    # clearing images when we need to completely change what's in the
+    # gallery avoids the current images being replaced piecemeal, and
+    # prevents weirdness and errors if the user selects an image during
+    # the replacement.
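+    #
+    # These handlers rely on a Gradio convention worth noting: returning a
+    # typed component (e.g. gr.Gallery(value=[])) reconfigures that output,
+    # while returning an untyped gr.update() leaves it untouched. A minimal
+    # sketch of the pattern (hypothetical helper, not part of this change):
+    #
+    #     def maybe_update(text):
+    #         # replace the value only when there is something new to show
+    #         return gr.Textbox(value=text) if text else gr.update()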
+ clear_gallery = dict( + fn=on_clear_gallery, + inputs=None, + outputs=[gallery, logo], + queue=False, + ) + + subdirectories.select(**clear_gallery).then( + on_select_subdir, + [subdirectories], + [gallery_files, gallery, logo], + queue=False, + ) + + open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False) + + refresh.click(**clear_gallery).then( + on_refresh, + [subdirectories], + [subdirectories, subdirectory_paths, gallery_files, gallery, logo], + queue=False, + ) + + image_columns.change( + fn=on_image_columns_change, + inputs=[image_columns], + outputs=[gallery], + queue=False, + ) + + gallery.select( + on_select_image, + [gallery_files], + [outputgallery_filename, image_parameters], + queue=False, + ) + + outputgallery_filename.change( + on_outputgallery_filename_change, + [outputgallery_filename], + [ + outputgallery_sendto_sd, + ], + queue=False, + ) + + # We should have been given the .select function for our tab, so set it up + def outputgallery_tab_select(select): + select( + fn=on_select_tab, + inputs=[subdirectory_paths], + outputs=[ + subdirectories, + subdirectory_paths, + gallery_files, + gallery, + logo, + open_subdir, + ], + queue=False, + ) + + # We should have been passed a list of components on other tabs that update + # when a new image has generated on that tab, so set things up so the user + # will see that new image if they are looking at today's subdirectory + def outputgallery_watch(components: gr.Textbox): + for component in components: + component.change( + on_new_image, + inputs=[subdirectories, subdirectory_paths, component], + outputs=[gallery_files, gallery, logo], + queue=False, + ) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 1be5dc89fe..f26c7967e3 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -24,130 +24,31 @@ ) from apps.shark_studio.api.sd import ( sd_model_map, - StableDiffusion, -) -from apps.shark_studio.api.schedulers import ( - scheduler_model_map, + shark_sd_fn, + cancel_sd, ) from apps.shark_studio.api.controlnet import ( preprocessor_model_map, - control_adapter_model_map, PreprocessorModel, + cnet_preview, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, ) from apps.shark_studio.modules.img_processing import ( resampler_list, resize_stencil, ) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from apps.shark_studio.web.ui.utils import ( - get_generation_text_info, nodlogo_loc, ) +from apps.shark_studio.web.utils.state import ( + get_generation_text_info, + status_label, +) from apps.shark_studio.web.ui.common_events import lora_changed -sd_pipe = None - - -# NOTE: Each `hf_model_id` should have its own starting configuration. 
- -# model_vmfb_key = "" - -def shark_sd_fn( - prompt: str, - negative_prompt: str, - image_dict, - height: int, - width: int, - steps: int, - strength: float, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - base_model_id: str, - custom_checkpoints: str, - custom_vae: str, - precision: str, - device: str, - lora_weights: str | list, - lora_hf_ids: str | list, - ondemand: bool, - repeatable_seeds: bool, - resample_type: str, - control_mode: str, - stencils: list, - images: list, - preprocessed_hints: list, - progress=gr.Progress(), -): - - # Handling gradio ImageEditor datatypes so we have unified inputs to the SD API - for i, stencil in enumerate(stencils): - if images[i] is None and stencil is not None: - continue - elif stencil is None and any(img is not None for img in [images[i], preprocessed_hints[i]]): - images[i] = None - preprocessed_hints[i] = None - elif images[i] is not None: - if isinstance(images[i], dict): - images[i] = images[i]["composite"] - images[i] = images[i].convert("RGB") - - if isinstance(image_dict, PIL.Image.Image): - image = image_dict.convert("RGB") - elif image_dict: - image = image_dict["image"].convert("RGB") - else: - image = None - if image: - image, _, _, = resize_stencil(image, width, height) - - device_id = None - - from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - - submit_pipe_kwargs = { - base_model_id: base_model_id, - height: height, - width: width, - precision: precision, - device: device, - extra_model_ids: extra_model_ids, - embeddings: lora_hf_ids, - import_ir: cmd_opts.import_ir, - } - submit_prep_kwargs = { - - - - global sd_pipe - global sd_pipe_kwargs - - for key in - - if sd_pipe is None: - history[-1][-1] = "Getting the pipeline ready..." - yield history, "" - - # Initializes the pipeline and retrieves IR based on all - # parameters that are static in the turbine output format, - # which is currently MLIR in the torch dialect. 
- - sd_pipe = SharkStableDiffusionPipeline( - **submit_pipe_kwargs - ) - sd_pipe.queue_compile() - - for prompt, msg, exec_time in progress.tqdm( - sd_pipe.generate_images( - prompt, - negative_prompt, - ), - desc="Generating Image...", - ): - - return history, "" - def view_json_file(file_obj): content = "" @@ -155,17 +56,33 @@ def view_json_file(file_obj): content = fopen.read() return content -sd_fn_sig = signature(shark_sd_fn) -max_controlnets = 5 + +max_controlnets = 3 max_loras = 5 + def show_loras(k): k = int(k) - return [gr.Dropdown(visible=True)]*k + [gr.Dropdown(visible=False, value="None")]*(max_textboxes-k) + return gr.State( + [gr.Dropdown(visible=True)] * k + + [gr.Dropdown(visible=False, value="None")] * (max_loras - k) + ) + def show_controlnets(k): k = int(k) - return [gr.Row(visible=True)]*k + [gr.Row(visible=False)]*(max_textboxes-k) + return [ + gr.State( + [ + [gr.Row(visible=True, render=True)] * k + + [gr.Row(visible=False)] * (max_controlnets - k) + ] + ), + gr.State([None] * k), + gr.State([None] * k), + gr.State([None] * k), + ] + def create_canvas(width, height): data = Image.fromarray( @@ -182,10 +99,9 @@ def create_canvas(width, height): } return EditorValue(img_dict) + def import_original(original_img, width, height): - resized_img, _, _ = resize_stencil( - original_img, width, height - ) + resized_img, _, _ = resize_stencil(original_img, width, height) img_dict = { "background": resized_img, "layers": [resized_img], @@ -196,6 +112,7 @@ def import_original(original_img, width, height): crop_size=(width, height), ) + def update_cn_input( model, width, @@ -203,7 +120,6 @@ def update_cn_input( stencils, images, preprocessed_hints, - index, ): if model == None: stencils[index] = None @@ -271,80 +187,99 @@ def update_cn_input( images, preprocessed_hints, ] + + +sd_fn_inputs = [] +sd_fn_sig = signature(shark_sd_fn).replace() +for i in sd_fn_sig.parameters: + sd_fn_inputs.append(i) + with gr.Blocks(title="Stable Diffusion") as sd_element: # Get a list of arguments needed for the API call, then # initialize an empty list that will manage the corresponding # gradio values. 
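+# For reference: signature() (used for sd_fn_inputs above) returns the
+# parameters of shark_sd_fn in declaration order, which is what allows UI
+# components to be matched to API arguments positionally. A hedged sketch
+# of the same idea as a standalone helper (hypothetical, not part of this
+# change):
+#
+#     from inspect import signature
+#
+#     def api_param_names(fn) -> list[str]:
+#         # parameter names come back in the order they appear in the def
+#         return list(signature(fn).parameters)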
- inputs_list = gr.State(signature(shark_sd_fn)) - inputs_args = gr.State([None] * len(inputs_list)) with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - elem_id="top_logo", - width=150, - height=50, - show_download_button=False, - ) - save_sd_config = gr.Button(label="Save Config", scale=1) - load_sd_config = gr.FileExplorer("Load Config", scale=1) - clear_sd_config = gr.ClearButton("Clear Config", scale=1) - with gr.Column(elem_if="ui_body"): + nod_logo = Image.open(nodlogo_loc) + with gr.Row(variant="compact", equal_height=True): + with gr.Column( + scale=1, + elem_id="demo_title_outer", + ): + gr.Image( + value=nod_logo, + show_label=False, + interactive=False, + elem_id="top_logo", + width=150, + height=50, + show_download_button=False, + ) + with gr.Column(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): - with gr.Group() - sd_model_info = ( - f"Checkpoint Path: {str(get_checkpoint_path())}" - ) - sd_base = gr.Dropdown( - label="Base Model", - info="Select or enter HF model ID", - elem_id="custom_model", - value="stabilityai/stable-diffusion-2.1-base", - choices=get_base_models(), - ) # base_model_id - sd_checkpoint = gr.Dropdown( - label="Checkpoints (optional)", - info="Select or enter HF model ID", - elem_id="custom_model", - value="None", - choices=get_checkpoints(sd_base), - ) # - sd_vae_info = (str(get_checkpoints_path("vae"))).replace( - "\\", "\n\\" - ) - sd_vae_info = f"VAE Path: {sd_vae_info}" - sd_custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=sd_vae_info, - elem_id="custom_model", - value=os.path.basename(cmd_opts.custom_vae) - if cmd_opts.custom_vae - else "None", - choices=["None"] + get_checkpoints("vae"), - allow_custom_value=True, - scale=1, - ) - + with gr.Row(equal_height=True): + with gr.Column(scale=3): + sd_model_info = ( + f"Checkpoint Path: {str(get_checkpoints_path())}" + ) + sd_base = gr.Dropdown( + label="Base Model", + info="Select or enter HF model ID", + elem_id="custom_model", + value="stabilityai/stable-diffusion-2-1-base", + choices=sd_model_map.keys(), + ) # base_model_id + sd_custom_weights = gr.Dropdown( + label="Weights (Optional)", + info="Select or enter HF model ID", + elem_id="custom_model", + value="None", + allow_custom_value=True, + choices=get_checkpoints(sd_base), + ) # + with gr.Column(scale=2): + sd_vae_info = ( + str(get_checkpoints_path("vae")) + ).replace("\\", "\n\\") + sd_vae_info = f"VAE Path: {sd_vae_info}" + sd_custom_vae = gr.Dropdown( + label=f"Custom VAE Models", + info=sd_vae_info, + elem_id="custom_model", + value=os.path.basename(cmd_opts.custom_vae) + if cmd_opts.custom_vae + else "None", + choices=["None"] + get_checkpoints("vae"), + allow_custom_value=True, + scale=1, + ) + with gr.Column(scale=1): + save_sd_config = gr.Button( + value="Save Config", size="sm" + ) + clear_sd_config = gr.ClearButton( + value="Clear Config", size="sm" + ) + load_sd_config = gr.FileExplorer( + label="Load Config", + root=os.path.basename("./configs"), + ) + with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( label="Prompt", - value=args.prompts[0], + value=cmd_opts.prompts[0], lines=2, elem_id="prompt_box", ) negative_prompt = gr.Textbox( label="Negative Prompt", - value=args.negative_prompts[0], + value=cmd_opts.negative_prompts[0], lines=2, elem_id="negative_prompt_box", ) - - with gr.Accordion(label = "Input Image", 
open=False): + + with gr.Accordion(label="Input Image", open=False): # TODO: make this import image prompt info if it exists sd_init_image = gr.Image( label="Input Image", @@ -352,41 +287,94 @@ def update_cn_input( height=300, interactive=True, ) - with gr.Accordion(label="Embeddings options", open=False): + with gr.Accordion( + label="Embeddings options", open=False, render=True + ): sd_lora_info = ( str(get_checkpoints_path("loras")) ).replace("\\", "\n\\") - num_loras = gr.Slider(1, max_loras, value=1, step=1, label="LoRA Count") - loras = [] + num_loras = gr.Slider( + 1, max_loras, value=1, step=1, label="LoRA Count" + ) + loras = gr.State([]) for i in range(max_loras): - lora_opt = gr.Dropdown( - allow_custom_value=False, - label=f"Standalone LoRA Weights", - info=sd_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), + with gr.Row(): + lora_opt = gr.Dropdown( + allow_custom_value=True, + label=f"Standalone LoRA Weights", + info=sd_lora_info, + elem_id="lora_weights", + value="None", + choices=["None"] + get_checkpoints("lora"), + ) + with gr.Row(): + lora_tags = gr.HTML( + value="
No LoRA selected
", + elem_classes="lora-tags", + ) + gr.on( + triggers=[lora_opt.change], + fn=lora_changed, + inputs=[lora_opt], + outputs=[lora_tags], + queue=True, ) + loras.value.append(lora_opt) + + num_loras.change(show_loras, [num_loras], [loras]) with gr.Accordion(label="Advanced Options", open=True): with gr.Row(): scheduler = gr.Dropdown( elem_id="scheduler", label="Scheduler", value="EulerDiscrete", - choices=scheduler_list, + choices=scheduler_model_map.keys(), allow_custom_value=False, ) with gr.Row(): height = gr.Slider( - 384, 768, value=cmd_opts.height, step=8, label="Height" + 384, + 768, + value=cmd_opts.height, + step=8, + label="Height", ) width = gr.Slider( - 384, 768, value=cmd_opts.width, step=8, label="Width" + 384, + 768, + value=cmd_opts.width, + step=8, + label="Width", ) with gr.Row(): with gr.Column(scale=3): steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" + 1, + 100, + value=cmd_opts.steps, + step=1, + label="Steps", + ) + batch_count = gr.Slider( + 1, + 100, + value=cmd_opts.batch_count, + step=1, + label="Batch Count", + interactive=True, + ) + batch_size = gr.Slider( + 1, + 4, + value=cmd_opts.batch_size, + step=1, + label="Batch Size", + interactive=True, + visible=True, + ) + repeatable_seeds = gr.Checkbox( + cmd_opts.repeatable_seeds, + label="Repeatable Seeds", ) with gr.Column(scale=3): strength = gr.Slider( @@ -402,6 +390,13 @@ def update_cn_input( label="Resample Type", allow_custom_value=True, ) + guidance_scale = gr.Slider( + 0, + 50, + value=cmd_opts.guidance_scale, + step=0.1, + label="CFG Scale", + ) ondemand = gr.Checkbox( value=cmd_opts.lowvram, label="Low VRAM", @@ -416,38 +411,6 @@ def update_cn_input( ], visible=True, ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=cmd_opts.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=cmd_opts.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - cmd_opts.repeatable_seeds, - label="Repeatable Seeds", - ) - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=cmd_opts.batch_size, - step=1, - label="Batch Size", - interactive=True, - visible=True, - ) with gr.Row(): seed = gr.Textbox( value=cmd_opts.seed, @@ -457,40 +420,53 @@ def update_cn_input( device = gr.Dropdown( elem_id="device", label="Device", - value=get_available_devices[0], - choices=get_available_devices, + value=get_available_devices()[0], + choices=get_available_devices(), allow_custom_value=False, ) - with gr.Accordion(label="Controlnet Options", open=False): + with gr.Accordion( + label="Controlnet Options", open=False, render=False + ): sd_cnet_info = ( str(get_checkpoints_path("controlnet")) ).replace("\\", "\n\\") - num_cnets = gr.Slider(1, max_controlnets, value=1, step=1, label="Controlnet Count") + num_cnets = gr.Slider( + 0, + max_controlnets, + value=0, + step=1, + label="Controlnet Count", + ) cnet_rows = [] - stencils = [] - images = [] - preprocessed_hints = [] + stencils = gr.State([]) + images = gr.State([]) + preprocessed_hints = gr.State([]) + control_mode = gr.Radio( + choices=["Prompt", "Balanced", "Controlnet"], + value="Balanced", + label="Control Mode", + ) + for i in range(max_controlnets): - with gr.Row as cnet_row: + with gr.Row(visible=False) as cnet_row: with gr.Column(): cnet_gen = gr.Button( value="Preprocess controlnet input", ) - cnet_processor = gr.Dropdown( - allow_custom_value=True, - label=f"Controlnet Preprocessor", - 
info=sd_cnet_info, - elem_id="lora_weights", - value="None", - choices=["None"] + controlnet_list + get_custom_model_files("controlnet"), - ) - cnet_adapter = gr.Dropdown( + cnet_model = gr.Dropdown( allow_custom_value=True, - label=f"Controlnet Adapter", + label=f"Controlnet Model", info=sd_cnet_info, elem_id="lora_weights", value="None", - choices=["None"] + controlnet_list + get_custom_model_files("controlnet"), + choices=[ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] + + get_checkpoints("controlnet"), ) canvas_width = gr.Slider( label="Canvas Width", @@ -529,14 +505,13 @@ def update_cn_input( visible=True, label="Preprocessed Hint", interactive=True, - show_label=True + show_label=True, ) use_input_img.click( import_original, [sd_init_image, canvas_width, canvas_height], - [cnet_image], + [cnet_input], ) - cnet_model.change( fn=update_cn_input, inputs=[ @@ -563,7 +538,7 @@ def update_cn_input( create_canvas, [canvas_width, canvas_height], [ - cnet_image, + cnet_input, ], ) gr.on( @@ -583,12 +558,16 @@ def update_cn_input( preprocessed_hints, ], ) - cnet_rows.append(cnet_row) + cnet_rows.value.append(cnet_row) - num_cnets.change(show_controlnets, num_cnets, cnet_rows) + num_cnets.change( + show_controlnets, + [num_cnets], + [cnet_rows, stencils, images, preprocessed_hints], + ) with gr.Column(scale=1, min_width=600): with gr.Group(): - img2img_gallery = gr.Gallery( + sd_gallery = gr.Gallery( label="Generated images", show_label=False, elem_id="gallery", @@ -596,14 +575,14 @@ def update_cn_input( object_fit="contain", ) std_output = gr.Textbox( - value=f"{i2i_model_info}\n" + value=f"{sd_model_info}\n" f"Images will be saved at " f"{get_generated_imgs_path()}", lines=2, elem_id="std_output", show_label=False, ) - img2img_status = gr.Textbox(visible=False) + sd_status = gr.Textbox(visible=False) with gr.Row(): stable_diffusion = gr.Button("Generate Image(s)") random_seed = gr.Button("Randomize Seed") @@ -631,12 +610,11 @@ def update_cn_input( batch_size, scheduler, sd_base, - sd_checkpoint, + sd_custom_weights, sd_custom_vae, precision, device, - lora_weights, - lora_hf_id, + loras, ondemand, repeatable_seeds, resample_type, @@ -652,13 +630,13 @@ def update_cn_input( stencils, images, ], - show_progress="minimal" if cmd_opts.progress_bar else "none", + show_progress="minimal", ) status_kwargs = dict( - fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs), + fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs), inputs=[batch_count, batch_size], - outputs=img2img_status, + outputs=sd_status, ) prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) @@ -670,10 +648,3 @@ def update_cn_input( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py index 9b588f858a..ba62e5adc0 100644 --- a/apps/shark_studio/web/ui/utils.py +++ b/apps/shark_studio/web/ui/utils.py @@ -1,10 +1,33 @@ -def nodlogo_loc(): - return "foo" +from enum import IntEnum +import math +import sys +import os -def get_checkpoints_path(model_type: str = None): - return "foo" +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) -def get_checkpoints(): - return "foo" +nodlogo_loc = 
resource_path("logos/nod-logo.png") +nodicon_loc = resource_path("logos/nod-icon.png") + + +class HSLHue(IntEnum): + RED = 0 + YELLOW = 60 + GREEN = 120 + CYAN = 180 + BLUE = 240 + MAGENTA = 300 + + +def hsl_color(alpha: float, start, end): + b = (end - start) * (alpha if alpha > 0 else 0) + result = b + start + + # Return a CSS HSL string + return f"hsl({math.floor(result)}, 80%, 35%)" diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py new file mode 100644 index 0000000000..0b5f54636a --- /dev/null +++ b/apps/shark_studio/web/utils/globals.py @@ -0,0 +1,74 @@ +import gc + +""" +The global objects include SD pipeline and config. +Maintaining the global objects would avoid creating extra pipeline objects when switching modes. +Also we could avoid memory leak when switching models by clearing the cache. +""" + + +def _init(): + global _sd_obj + global _config_obj + global _schedulers + _sd_obj = None + _config_obj = None + _schedulers = None + + +def set_sd_obj(value): + global _sd_obj + _sd_obj = value + + +def set_sd_scheduler(key): + global _sd_obj + _sd_obj.scheduler = _schedulers[key] + + +def set_sd_status(value): + global _sd_obj + _sd_obj.status = value + + +def set_cfg_obj(value): + global _config_obj + _config_obj = value + + +def set_schedulers(value): + global _schedulers + _schedulers = value + + +def get_sd_obj(): + global _sd_obj + return _sd_obj + + +def get_sd_status(): + global _sd_obj + return _sd_obj.status + + +def get_cfg_obj(): + global _config_obj + return _config_obj + + +def get_scheduler(key): + global _schedulers + return _schedulers[key] + + +def clear_cache(): + global _sd_obj + global _config_obj + global _schedulers + del _sd_obj + del _config_obj + del _schedulers + gc.collect() + _sd_obj = None + _config_obj = None + _schedulers = None diff --git a/apps/shark_studio/web/utils/metadata/__init__.py b/apps/shark_studio/web/utils/metadata/__init__.py new file mode 100644 index 0000000000..bcbcf746ca --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/__init__.py @@ -0,0 +1,6 @@ +from .png_metadata import ( + import_png_metadata, +) +from .display import ( + displayable_metadata, +) diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py new file mode 100644 index 0000000000..d617e802bf --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py @@ -0,0 +1,45 @@ +import csv +import os +from .format import humanize, humanizable + + +def csv_path(image_filename: str): + return os.path.join(os.path.dirname(image_filename), "imgs_details.csv") + + +def has_csv(image_filename: str) -> bool: + return os.path.exists(csv_path(image_filename)) + + +def matching_filename(image_filename: str, row): + # we assume the final column of the csv has the original filename with full path and match that + # against the image_filename if we are given a list. Otherwise we assume a dict and and take + # the value of the OUTPUT key + return os.path.basename(image_filename) in ( + row[-1] if isinstance(row, list) else row["OUTPUT"] + ) + + +def parse_csv(image_filename: str): + csv_filename = csv_path(image_filename) + + with open(csv_filename, "r", newline="") as csv_file: + # We use a reader or DictReader here for images_details.csv depending on whether we think it + # has headers or not. Having headers means less guessing of the format. 
has_header = csv.Sniffer().has_header(csv_file.read(2048))
+        csv_file.seek(0)
+
+        reader = (
+            csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
+        )
+
+        matches = [
+            # we rely on humanize and humanizable to work out the parsing of the individual .csv rows
+            humanize(row)
+            for row in reader
+            if row
+            and (has_header or humanizable(row))
+            and matching_filename(image_filename, row)
+        ]
+
+        return matches[0] if matches else {} diff --git a/apps/shark_studio/web/utils/metadata/display.py b/apps/shark_studio/web/utils/metadata/display.py new file mode 100644 index 0000000000..26234aab5c --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/display.py @@ -0,0 +1,53 @@
+import json
+import os
+from PIL import Image
+from .png_metadata import parse_generation_parameters
+from .exif_metadata import has_exif, parse_exif
+from .csv_metadata import has_csv, parse_csv
+from .format import compact, humanize
+
+
+def displayable_metadata(image_filename: str) -> dict | None:
+    if not os.path.isfile(image_filename):
+        return {"source": "missing", "parameters": {}}
+
+    pil_image = Image.open(image_filename)
+
+    # we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads,
+    # we go via that for SendTo, and it is directly tied to the image)
+    if "parameters" in pil_image.info:
+        return {
+            "source": "png",
+            "parameters": compact(
+                parse_generation_parameters(pil_image.info["parameters"])
+            ),
+        }
+
+    # we have a matching json file (next most likely to be accurate when it's there)
+    json_path = os.path.splitext(image_filename)[0] + ".json"
+    if os.path.isfile(json_path):
+        with open(json_path) as params_file:
+            return {
+                "source": "json",
+                "parameters": compact(
+                    humanize(json.load(params_file), includes_filename=False)
+                ),
+            }
+
+    # we have a CSV file so try that (it can be different shapes, and it usually has no
+    # headers/param names, so of the things we *know* have parameters it's the
+    # last resort)
+    if has_csv(image_filename):
+        params = parse_csv(image_filename)
+        if params:  # we might not have found the filename in the csv
+            return {
+                "source": "csv",
+                "parameters": compact(params),  # already humanized
+            }
+
+    # EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something*
+    if has_exif(image_filename):
+        return {"source": "exif", "parameters": parse_exif(pil_image)}
+
+    # we've got nothing
+    return None diff --git a/apps/shark_studio/web/utils/metadata/exif_metadata.py b/apps/shark_studio/web/utils/metadata/exif_metadata.py new file mode 100644 index 0000000000..c72da8a935 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/exif_metadata.py @@ -0,0 +1,52 @@
+from PIL import Image
+from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS
+
+
+def has_exif(image_filename: str) -> bool:
+    return True if Image.open(image_filename).getexif() else False
+
+
+def parse_exif(pil_image: Image) -> dict:
+    img_exif = pil_image.getexif()
+
+    # See this stackoverflow answer for where most of this comes from: https://stackoverflow.com/a/75357594
+    # I did try to use the exif library but it broke just as much as my initial attempt at this (albeit
+    # I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a
+    # dependency
+    exif_tags = {
+        TAGS.get(key, key): str(val)
+        for (key, val) in img_exif.items()
+        if key in TAGS
+        and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo)
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    def try_get_ifd(ifd_id):
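+        # get_ifd can raise KeyError when an IFD pointer is absent from the
+        # file's EXIF block, so a missing IFD is treated as having no tags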
try:
+            return img_exif.get_ifd(ifd_id).items()
+        except KeyError:
+            return {}
+
+    ifd_tags = {
+        TAGS.get(key, key): str(val)
+        for ifd_id in IFD
+        for (key, val) in try_get_ifd(ifd_id)
+        if ifd_id != IFD.GPSInfo
+        and key in TAGS
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    gps_tags = {
+        GPSTAGS.get(key, key): str(val)
+        for (key, val) in try_get_ifd(IFD.GPSInfo)
+        if key in GPSTAGS
+        and val
+        and (not isinstance(val, bytes))
+        and (not str(val).isspace())
+    }
+
+    return {**exif_tags, **ifd_tags, **gps_tags} diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py new file mode 100644 index 0000000000..f097dab54f --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/format.py @@ -0,0 +1,143 @@
+# As SHARK has evolved, more columns have been added to images_details.csv. However, since
+# no version of the CSV has any headers (yet), we don't actually have anything within the
+# file that tells us which parameter each column is for. So this is a list of known patterns,
+# indexed by length, which is what we're going to have to use to guess which columns are the
+# right ones for the file we're looking at.
+
+# The same ordering is used for JSON, but these do have key names; however they are not very
+# human friendly, nor do they match up with what is written to the .png headers.
+
+# So these are functions to try and get something consistent out of the raw input from all
+# these sources.
+
+PARAMS_FORMATS = {
+    9: {
+        "VARIANT": "Model",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "OUTPUT": "Filename",
+    },
+    10: {
+        "MODEL": "Model",
+        "VARIANT": "Variant",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "OUTPUT": "Filename",
+    },
+    12: {
+        "VARIANT": "Model",
+        "SCHEDULER": "Sampler",
+        "PROMPT": "Prompt",
+        "NEG_PROMPT": "Negative prompt",
+        "SEED": "Seed",
+        "CFG_SCALE": "CFG scale",
+        "PRECISION": "Precision",
+        "STEPS": "Steps",
+        "HEIGHT": "Height",
+        "WIDTH": "Width",
+        "MAX_LENGTH": "Max Length",
+        "OUTPUT": "Filename",
+    },
+}
+
+PARAMS_FORMAT_CURRENT = {
+    "VARIANT": "Model",
+    "VAE": "VAE",
+    "LORA": "LoRA",
+    "SCHEDULER": "Sampler",
+    "PROMPT": "Prompt",
+    "NEG_PROMPT": "Negative prompt",
+    "SEED": "Seed",
+    "CFG_SCALE": "CFG scale",
+    "PRECISION": "Precision",
+    "STEPS": "Steps",
+    "HEIGHT": "Height",
+    "WIDTH": "Width",
+    "MAX_LENGTH": "Max Length",
+    "OUTPUT": "Filename",
+}
+
+
+def compact(metadata: dict) -> dict:
+    # we don't want to alter the original dictionary
+    result = dict(metadata)
+
+    # discard the filename because we should already have it
+    if result.keys() & {"Filename"}:
+        result.pop("Filename")
+
+    # make showing the sizes more compact by using only one line each
+    if result.keys() & {"Size-1", "Size-2"}:
+        result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}"
+    elif result.keys() & {"Height", "Width"}:
+        result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}"
+
+    if result.keys() & {"Hires resize-1", "Hires resize-2"}:
+        hires_y = result.pop("Hires resize-1")
+        hires_x = result.pop("Hires resize-2")
+
+        if hires_x == 0 and hires_y == 0:
+            result["Hires resize"] = "None"
+        else:
+            result["Hires resize"] = f"{hires_y}x{hires_x}"
+
+    # remove VAE if it exists and is empty
+    if
(result.keys() & {"VAE"}) and ( + not result["VAE"] or result["VAE"] == "None" + ): + result.pop("VAE") + + # remove LoRA if it exists and is empty + if (result.keys() & {"LoRA"}) and ( + not result["LoRA"] or result["LoRA"] == "None" + ): + result.pop("LoRA") + + return result + + +def humanizable(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + return lookup_key in PARAMS_FORMATS.keys() + + +def humanize(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + + # For lists we can only work based on the length, we have no other information + if isinstance(metadata, list): + if humanizable(metadata, includes_filename): + return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata)) + else: + raise KeyError( + f"Humanize could not find the format for a parameter list of length {len(metadata)}" + ) + + # For dictionaries we try to use the matching length parameter format if + # available, otherwise we just use the current format which is assumed to + # have everything currently known about. Then we swap keys in the metadata + # that match keys in the format for the friendlier name that we have set + # in the format value + if isinstance(metadata, dict): + if humanizable(metadata, includes_filename): + format = PARAMS_FORMATS[lookup_key] + else: + format = PARAMS_FORMAT_CURRENT + + return { + format[key]: metadata[key] + for key in format.keys() + if key in metadata.keys() and metadata[key] + } + + raise TypeError("Can only humanize parameter lists or dictionaries") diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py new file mode 100644 index 0000000000..cffc385ab7 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -0,0 +1,222 @@ +import re +from pathlib import Path +from apps.shark_studio.api.utils import ( + get_checkpoint_pathfile, +) +from apps.shark_studio.api.sd import ( + sd_model_map, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, +) + +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") + + +def parse_generation_parameters(x: str): + res = {} + prompt = "" + negative_prompt = "" + done_with_prompt = False + + *lines, lastline = x.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = "" + + for i, line in enumerate(lines): + line = line.strip() + if line.startswith("Negative prompt:"): + done_with_prompt = True + line = line[16:].strip() + + if done_with_prompt: + negative_prompt += ("" if negative_prompt == "" else "\n") + line + else: + prompt += ("" if prompt == "" else "\n") + line + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + + for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v + m = re_imagesize.match(v) + if m is not None: + res[k + "-1"] = m.group(1) + res[k + "-2"] = m.group(2) + else: + res[k] = v + + # Missing CLIP skip means it was set to 1 (the default) + if "Clip skip" not in res: + res["Clip skip"] = "1" + + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res[ + "Prompt" + ] += f"""""" + + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + + return res + + +def try_find_model_base_from_png_metadata( + file: str, folder: str = 
"models" +) -> str: + custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file(): + custom = file + ".safetensors" + + return custom + + +def find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in sd_model_map: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" + % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> str: + vae_custom = "" + + if key in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + +def import_png_metadata( + pil_data, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, +): + try: + png_info = pil_data.info["parameters"] + metadata = parse_generation_parameters(png_info) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) + + negative_prompt = metadata["Negative prompt"] + steps = int(metadata["Steps"]) + cfg_scale = float(metadata["CFG scale"]) + seed = int(metadata["Seed"]) + width = float(metadata["Size-1"]) + height = float(metadata["Size-2"]) + + if "Model" in metadata and png_custom_model: + custom_model = png_custom_model + elif "Model" in metadata and png_hf_model_id: + custom_model = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + + if "Prompt" in metadata: + prompt = metadata["Prompt"] + if "Sampler" in metadata: + if metadata["Sampler"] in scheduler_model_map: + sampler = metadata["Sampler"] + else: + print( + "Import PNG info: Unable to find a scheduler for %s" + % metadata["Sampler"] + ) + + except Exception as ex: + if pil_data 
and pil_data.info.get("parameters"):
+            print("import_png_metadata failed with %s" % ex)
+
+    return (
+        None,
+        prompt,
+        negative_prompt,
+        steps,
+        sampler,
+        cfg_scale,
+        seed,
+        width,
+        height,
+        custom_model,
+        custom_lora,
+        hf_lora_id,
+        custom_vae,
+    ) diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py new file mode 100644 index 0000000000..626d4ce53f --- /dev/null +++ b/apps/shark_studio/web/utils/state.py @@ -0,0 +1,41 @@
+import apps.shark_studio.web.utils.globals as global_obj
+import gc
+
+
+def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
+    print(f"Getting status label for {tab_name}")
+    if batch_index < batch_count:
+        bs = f"x{batch_size}" if batch_size > 1 else ""
+        return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"
+    else:
+        return f"{tab_name} complete"
+
+
+def get_generation_text_info(seeds, device):
+    # copy the config values so building the dump can't mutate the live config
+    cfg_dump = {}
+    for cfg, value in global_obj.get_config_dict().items():
+        cfg_dump[cfg] = value
+    text_output = f"prompt={cfg_dump['prompts']}"
+    text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
+    text_output += (
+        f"\nmodel_id={cfg_dump['hf_model_id']}, "
+        f"ckpt_loc={cfg_dump['ckpt_loc']}"
+    )
+    text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
+    text_output += (
+        f"\nsteps={cfg_dump['steps']}, "
+        f"guidance_scale={cfg_dump['guidance_scale']}, "
+        f"seed={seeds}"
+    )
+    text_output += (
+        f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, "
+        if not cfg_dump["use_hiresfix"]
+        else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, "
+    )
+    text_output += (
+        f"batch_count={cfg_dump['batch_count']}, "
+        f"batch_size={cfg_dump['batch_size']}, "
+        f"max_length={cfg_dump['max_length']}"
+    )
+
+    return text_output diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py new file mode 100644 index 0000000000..3e6ba46bfe --- /dev/null +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -0,0 +1,77 @@
+import os
+import shutil
+from time import time
+
+shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
+
+
+def clear_tmp_mlir():
+    cleanup_start = time()
+    print(
+        "Clearing .mlir temporary files from a prior run. This may take some time..."
+    )
+    mlir_files = [
+        filename
+        for filename in os.listdir(shark_tmp)
+        if os.path.isfile(os.path.join(shark_tmp, filename))
+        and filename.endswith(".mlir")
+    ]
+    for filename in mlir_files:
+        os.remove(shark_tmp + filename)
+    print(
+        f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
+    )
+
+
+def clear_tmp_imgs():
+    # tell gradio to use a directory under shark_tmp for its temporary
+    # image files unless somewhere else has been set
+    if "GRADIO_TEMP_DIR" not in os.environ:
+        os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio")
+
+    print(
+        f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. "
+        + "You may change this by setting the GRADIO_TEMP_DIR environment variable."
+    )
+
+    # Clear all gradio tmp images from the last session
+    if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
+        cleanup_start = time()
+        print(
+            "Clearing gradio UI temporary image files from a prior run. This may take some time..."
+        )
+        shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True)
+        print(
+            f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds."
+        )
+
+    # older SHARK versions had to work around gradio bugs and stored things differently
+    else:
+        image_files = [
+            filename
+            for filename in os.listdir(shark_tmp)
+            if os.path.isfile(os.path.join(shark_tmp, filename))
+            and filename.startswith("tmp")
+            and filename.endswith(".png")
+        ]
+        if len(image_files) > 0:
+            print(
+                "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..."
+            )
+            cleanup_start = time()
+            for filename in image_files:
+                os.remove(shark_tmp + filename)
+            print(
+                f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds."
+            )
+        else:
+            print("No temporary image files to clear.")
+
+
+def config_tmp():
+    # create shark_tmp if it does not exist
+    if not os.path.exists(shark_tmp):
+        os.mkdir(shark_tmp)
+
+    clear_tmp_mlir()
+    clear_tmp_imgs()
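Taken together, config_tmp() is intended to run once at startup, before Gradio creates any temporary files, alongside the globals initialisation done in initializers.py. A minimal sketch of that startup ordering (the module paths are the ones added in this diff; everything after the imports is hypothetical glue, not part of this change):

    from apps.shark_studio.web.utils.tmp_configs import config_tmp
    import apps.shark_studio.web.utils.globals as global_obj

    config_tmp()        # create shark_tmp/ and clear stale .mlir files and gradio image caches
    global_obj._init()  # reset the cached SD pipeline, config and scheduler slots
    # ...build the Gradio UI and launch it afterwards, so GRADIO_TEMP_DIR is already set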