Skip to content

Commit

Permalink
Fixes to UI config defaults, config loading, and warnings. (#2153)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 31, 2024
1 parent d2c3752 commit 26f80cc
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 64 deletions.
27 changes: 27 additions & 0 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import torch
import gradio as gr
import time
import os
import json
Expand Down Expand Up @@ -285,6 +286,32 @@ def shark_sd_fn_dict_input(
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])

# TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
return None, ""
if sd_kwargs["guidance_scale"] > 3:
gr.Warning(
"sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
)
return None, ""
if sd_kwargs["target_triple"] == "":
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
gr.Warning(
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
)
return None, ""

generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs

Expand Down
28 changes: 0 additions & 28 deletions apps/shark_studio/web/configs/default_sd_config.json

This file was deleted.

6 changes: 3 additions & 3 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_config,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
Expand Down Expand Up @@ -257,7 +257,7 @@ def base_model_changed(base_model_id):
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="triple",
elem_id="target_triple",
label="Architecture",
value="",
)
Expand Down Expand Up @@ -629,7 +629,7 @@ def base_model_changed(base_model_id):
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
write_default_sd_configs(get_configs_path())
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
Expand Down
95 changes: 95 additions & 0 deletions apps/shark_studio/web/utils/default_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""

sdxl_30steps = r"""{
"prompt": [
"a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 1024,
"width": 1024,
"steps": 30,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-xl-base-1.0",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""

sdxl_turbo = r"""{
"prompt": [
"A cat wearing a hat that says 'TURBO' on it. The cat is sitting on a skateboard."
],
"negative_prompt": [
""
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 2,
"strength": 0.8,
"guidance_scale": 0,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerAncestralDiscrete",
"base_model_id": "stabilityai/sdxl-turbo",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "",
"target_triple": "",
"ondemand": false,
"compiled_pipeline": true,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""

default_sd_configs = {
"default_sd_config.json": default_sd_config,
"sdxl-30steps.json": sdxl_30steps,
"sdxl-turbo.json": sdxl_turbo,
}
41 changes: 8 additions & 33 deletions apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,14 @@
"*.safetensors",
)

default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""


def write_default_sd_config(path):
with open(path, "w") as f:
f.write(default_sd_config)
from apps.shark_studio.web.utils.default_configs import default_sd_configs


def write_default_sd_configs(path):
for key in default_sd_configs.keys():
config_fpath = os.path.join(path, key)
with open(config_fpath, "w") as f:
f.write(default_sd_configs[key])


def safe_name(name):
Expand Down

0 comments on commit 26f80cc

Please sign in to comment.