diff --git a/.gitignore b/.gitignore index 395a677ba6..eeb217e2b6 100644 --- a/.gitignore +++ b/.gitignore @@ -159,7 +159,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # vscode related .vscode diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 6d11f96d08..1fcc03db09 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -757,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}): if args.ckpt_loc: img_model = Path(os.path.basename(args.ckpt_loc)).stem + img_vae = None + if args.custom_vae: + img_vae = Path(os.path.basename(args.custom_vae)).stem + + img_lora = None + if args.use_lora: + img_lora = Path(os.path.basename(args.use_lora)).stem + if args.output_img_format == "jpg": out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") output_img.save(out_img_path, quality=95, subsampling=0) @@ -767,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}): if args.write_metadata_to_png: pngInfo.add_text( "parameters", - f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}", + f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps}, " + f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, " + f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}", ) output_img.save(out_img_path, "PNG", pnginfo=pngInfo) @@ -778,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}): "Image saved as png instead. 
Supported formats: png / jpg" ) + # To be as low-impact as possible to the existing CSV format, we append + # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of + # importance for each data point. Something to consider. new_entry = { "VARIANT": img_model, "SCHEDULER": args.scheduler, @@ -791,6 +804,8 @@ def save_output_img(output_img, img_seed, extra_info={}): "WIDTH": args.width, "MAX_LENGTH": args.max_length, "OUTPUT": out_img_path, + "VAE": img_vae, + "LORA": img_lora, } new_entry.update(extra_info) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 44e41f1d4c..00585046cf 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -553,6 +553,9 @@ def txt2img_api( height, txt2img_custom_model, txt2img_hf_model_id, + lora_weights, + lora_hf_id, + custom_vae, ], outputs=[ txt2img_png_info_img, @@ -566,5 +569,8 @@ def txt2img_api( height, txt2img_custom_model, txt2img_hf_model_id, + lora_weights, + lora_hf_id, + custom_vae, ], ) diff --git a/apps/stable_diffusion/web/utils/metadata/png_metadata.py b/apps/stable_diffusion/web/utils/metadata/png_metadata.py index a9128ee108..51a92f07d6 100644 --- a/apps/stable_diffusion/web/utils/metadata/png_metadata.py +++ b/apps/stable_diffusion/web/utils/metadata/png_metadata.py @@ -62,6 +62,82 @@ def parse_generation_parameters(x: str): return res +def try_find_model_base_from_png_metadata( + file: str, folder: str = "models" +) -> str: + custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path( + get_custom_model_pathfile(file + ".safetensors", folder) + ).is_file(): + custom = file + ".safetensors" + + return custom + + +def 
find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in predefined_models: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" + % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> str: + vae_custom = "" + + if key in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + def import_png_metadata( pil_data, prompt, @@ -74,40 +150,21 @@ def import_png_metadata( height, custom_model, hf_model_id, + custom_lora, + hf_lora_id, + custom_vae, ): try: png_info = pil_data.info["parameters"] metadata = parse_generation_parameters(png_info) - png_hf_model_id = "" - png_custom_model = "" - - if "Model" in metadata: - # Remove extension from model info - if metadata["Model"].endswith(".safetensors") or metadata[ - 
"Model" - ].endswith(".ckpt"): - metadata["Model"] = Path(metadata["Model"]).stem - # Check for the model name match with one of the local ckpt or safetensors files - if Path( - get_custom_model_pathfile(metadata["Model"] + ".ckpt") - ).is_file(): - png_custom_model = metadata["Model"] + ".ckpt" - if Path( - get_custom_model_pathfile(metadata["Model"] + ".safetensors") - ).is_file(): - png_custom_model = metadata["Model"] + ".safetensors" - # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") - if metadata["Model"] in predefined_models: - png_custom_model = metadata["Model"] - # If nothing had matched, check vendor/hf_model_id - if not png_custom_model and metadata["Model"].count("/"): - png_hf_model_id = metadata["Model"] - # No matching model was found - if not png_custom_model and not png_hf_model_id: - print( - "Import PNG info: Unable to find a matching model for %s" - % metadata["Model"] - ) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) negative_prompt = metadata["Negative prompt"] steps = int(metadata["Steps"]) @@ -115,12 +172,24 @@ def import_png_metadata( seed = int(metadata["Seed"]) width = float(metadata["Size-1"]) height = float(metadata["Size-2"]) + if "Model" in metadata and png_custom_model: custom_model = png_custom_model hf_model_id = "" if "Model" in metadata and png_hf_model_id: custom_model = "None" hf_model_id = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + if "Prompt" in metadata: prompt = metadata["Prompt"] if "Sampler" in metadata: @@ -149,4 +218,7 
@@ def import_png_metadata( height, custom_model, hf_model_id, + custom_lora, + hf_lora_id, + custom_vae, )