From 5941c8f583f2e37a373dcd77835c8a3a12137cdd Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 31 May 2024 14:26:17 -0500 Subject: [PATCH] Fix default configs, config loading, and add warnings/early returns for bad configs. --- apps/shark_studio/api/sd.py | 23 +++++ .../web/configs/default_sd_config.json | 28 ------ apps/shark_studio/web/ui/sd.py | 6 +- .../shark_studio/web/utils/default_configs.py | 95 +++++++++++++++++++ apps/shark_studio/web/utils/file_utils.py | 41 ++------ 5 files changed, 129 insertions(+), 64 deletions(-) delete mode 100644 apps/shark_studio/web/configs/default_sd_config.json create mode 100644 apps/shark_studio/web/utils/default_configs.py diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index e0534db5de..d74e9d0a92 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -1,5 +1,6 @@ import gc import torch +import gradio as gr import time import os import json @@ -285,6 +286,28 @@ def shark_sd_fn_dict_input( if key == "seed": sd_kwargs[key] = int(sd_kwargs[key]) + #TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function. + if sd_kwargs["device"] == "": + gr.Warning("No device specified. Please specify a device.") + return None, "" + if sd_kwargs["height"] not in [512, 1024]: + gr.Warning("Height must be 512 or 1024. This is a temporary limitation.") + return None, "" + if sd_kwargs["height"] != sd_kwargs["width"]: + gr.Warning("Height and width must be the same. This is a temporary limitation.") + return None, "" + if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo": + if sd_kwargs["steps"] > 10: + gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.") + return None, "" + if sd_kwargs["guidance_scale"] > 3: + gr.Warning("sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise.") + return None, "" + if sd_kwargs["target_triple"] == "": + if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "": + gr.Warning("Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx.") + return None, "" + generated_imgs = yield from shark_sd_fn(**sd_kwargs) return generated_imgs diff --git a/apps/shark_studio/web/configs/default_sd_config.json b/apps/shark_studio/web/configs/default_sd_config.json deleted file mode 100644 index 323a6a329c..0000000000 --- a/apps/shark_studio/web/configs/default_sd_config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "prompt": [ - "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" - ], - "negative_prompt": [ - "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" - ], - "sd_init_image": [null], - "height": 512, - "width": 512, - "steps": 50, - "strength": 0.8, - "guidance_scale": 7.5, - "seed": "-1", - "batch_count": 1, - "batch_size": 1, - "scheduler": "EulerDiscrete", - "base_model_id": "stabilityai/stable-diffusion-2-1-base", - "custom_weights": null, - "custom_vae": null, - "precision": "fp16", - "device": "AMD Radeon RX 7900 XTX => vulkan://0", - "ondemand": false, - "repeatable_seeds": false, - "resample_type": "Nearest Neighbor", - "controlnets": {}, - "embeddings": {} -} \ No newline at end of file diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 13daa83aa8..c658dab0e6 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -14,7 +14,7 @@ get_checkpoints_path, get_checkpoints, get_configs_path, - write_default_sd_config, + write_default_sd_configs, ) from apps.shark_studio.api.sd import ( shark_sd_fn_dict_input, @@ -257,7 +257,7 @@ def base_model_changed(base_model_id): allow_custom_value=False, ) target_triple = gr.Textbox( - elem_id="triple", + elem_id="target_triple", label="Architecture", value="", ) @@ -629,7 +629,7 @@ def base_model_changed(base_model_id): get_configs_path(), "default_sd_config.json", ) - write_default_sd_config(default_config_file) + write_default_sd_configs(get_configs_path()) sd_json = gr.JSON( elem_classes=["fill"], value=view_json_file(default_config_file), diff --git a/apps/shark_studio/web/utils/default_configs.py b/apps/shark_studio/web/utils/default_configs.py new file mode 100644 index 0000000000..b52ba63fec --- /dev/null +++ b/apps/shark_studio/web/utils/default_configs.py @@ -0,0 +1,95 @@ +default_sd_config = r"""{ + "prompt": [ + "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" + ], + "negative_prompt": [ + "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" + ], + "sd_init_image": [null], + "height": 512, + "width": 512, + "steps": 50, + "strength": 0.8, + "guidance_scale": 7.5, + "seed": "-1", + "batch_count": 1, + "batch_size": 1, + "scheduler": "EulerDiscrete", + "base_model_id": "stabilityai/stable-diffusion-2-1-base", + "custom_weights": null, + "custom_vae": null, + "precision": "fp16", + "device": "", + "target_triple": "", + "ondemand": false, + "compiled_pipeline": false, + "resample_type": "Nearest Neighbor", + "controlnets": {}, + "embeddings": {} +}""" + +sdxl_30steps = r"""{ + "prompt": [ + "a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "negative_prompt": [ + "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" + ], + "sd_init_image": [null], + "height": 1024, + "width": 1024, + "steps": 30, + "strength": 0.8, + "guidance_scale": 7.5, + "seed": "-1", + "batch_count": 1, + "batch_size": 1, + "scheduler": "EulerDiscrete", + "base_model_id": "stabilityai/stable-diffusion-xl-base-1.0", + "custom_weights": null, + "custom_vae": null, + "precision": "fp16", + "device": "", + "target_triple": "", + "ondemand": false, + "compiled_pipeline": true, + "resample_type": "Nearest Neighbor", + "controlnets": {}, + "embeddings": {} +}""" + +sdxl_turbo = r"""{ + "prompt": [ + "A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard." + ], + "negative_prompt": [ + "" + ], + "sd_init_image": [null], + "height": 512, + "width": 512, + "steps": 2, + "strength": 0.8, + "guidance_scale": 0, + "seed": "-1", + "batch_count": 1, + "batch_size": 1, + "scheduler": "EulerAncestralDiscrete", + "base_model_id": "stabilityai/sdxl-turbo", + "custom_weights": null, + "custom_vae": null, + "precision": "fp16", + "device": "", + "target_triple": "", + "ondemand": false, + "compiled_pipeline": true, + "resample_type": "Nearest Neighbor", + "controlnets": {}, + "embeddings": {} +}""" + +default_sd_configs = { + "default_sd_config.json": default_sd_config, + "sdxl-30steps.json": sdxl_30steps, + "sdxl-turbo.json": sdxl_turbo, +} \ No newline at end of file diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index 3619055676..84ba6a6c06 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -11,39 +11,14 @@ "*.safetensors", ) -default_sd_config = r"""{ - "prompt": [ - "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" - ], - "negative_prompt": [ - "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" - ], - "sd_init_image": [null], - "height": 512, - "width": 512, - "steps": 50, - "strength": 0.8, - "guidance_scale": 7.5, - "seed": "-1", - "batch_count": 1, - "batch_size": 1, - "scheduler": "EulerDiscrete", - "base_model_id": "stabilityai/stable-diffusion-2-1-base", - "custom_weights": null, - "custom_vae": null, - "precision": "fp16", - "device": "AMD Radeon RX 7900 XTX => vulkan://0", - "ondemand": false, - "repeatable_seeds": false, - "resample_type": "Nearest Neighbor", - "controlnets": {}, - "embeddings": {} -}""" - - -def write_default_sd_config(path): - with open(path, "w") as f: - f.write(default_sd_config) +from apps.shark_studio.web.utils.default_configs import default_sd_configs + + +def write_default_sd_configs(path): + for key in default_sd_configs.keys(): + config_fpath = os.path.join(path, key) + with open(config_fpath, "w") as f: + f.write(default_sd_configs[key]) def safe_name(name):