Skip to content

Commit

Permalink
More UI fixes and txt2img_sdxl presets.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 3, 2023
1 parent 7975356 commit 9bfa20b
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 20 deletions.
38 changes: 28 additions & 10 deletions apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def __init__(
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
Expand All @@ -457,12 +458,19 @@ def __init__(
]
)
print("Using hub to get custom vae")
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with target {custom_vae}")
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
Expand Down Expand Up @@ -938,11 +946,19 @@ def __init__(
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
Expand Down Expand Up @@ -1084,13 +1100,15 @@ def __init__(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ def generate_images(
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
# imgs = self.decode_latents_sdxl(None)
# all_imgs.extend(imgs)
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
Expand Down Expand Up @@ -52,6 +55,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
Expand Down
17 changes: 14 additions & 3 deletions apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
Expand Down Expand Up @@ -271,7 +272,7 @@ def txt2img_sdxl_inf(
elem_id="custom_model",
value="None",
choices=[
"None",
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files("vae"),
Expand Down Expand Up @@ -339,6 +340,8 @@ def txt2img_sdxl_inf(
"DDIM",
"SharkEulerAncestralDiscrete",
"SharkEulerDiscrete",
"EulerAncestralDiscrete",
"EulerDiscrete",
],
allow_custom_value=False,
visible=True,
Expand Down Expand Up @@ -402,7 +405,7 @@ def txt2img_sdxl_inf(
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
label="Guidance Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
Expand Down Expand Up @@ -562,12 +565,14 @@ def txt2img_sdxl_inf(
custom_vae,
],
)
txt2img_sdxl_custom_model.select(
txt2img_sdxl_custom_model.change(
fn=set_model_default_configs,
inputs=[
txt2img_sdxl_custom_model,
],
outputs=[
prompt,
negative_prompt,
steps,
scheduler,
guidance_scale,
Expand All @@ -576,3 +581,9 @@ def txt2img_sdxl_inf(
custom_vae,
],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
5 changes: 5 additions & 0 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ def resource_path(relative_path):
lines=2,
elem_id="prompt_box",
)
# TODO: coming soon
autogen = gr.Checkbox(
label="Continuous Generation",
visible=False,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
Expand Down
27 changes: 23 additions & 4 deletions apps/stable_diffusion/web/ui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import json
import safetensors
import gradio as gr

from pathlib import Path
from apps.stable_diffusion.src import args
Expand Down Expand Up @@ -272,6 +273,8 @@ def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
else:
# We don't have default metadata to setup a good config. Do not change configs.
return [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
Expand All @@ -285,6 +288,8 @@ def get_config_from_json(model_ckpt_or_id, jsonconfig):
# TODO: make this work properly. It is currently not user-exposed.
cfgdata = json.load(jsonconfig)
return [
cfgdata["prompt_box_behavior"],
cfgdata["neg_prompt_box_behavior"],
cfgdata["steps"],
cfgdata["scheduler"],
cfgdata["guidance_scale"],
Expand All @@ -305,13 +310,27 @@ def default_config_exists(model_ckpt_or_id):


default_configs = {
"stabilityai/sdxl-turbo": [1, "DDIM", 0, 512, 512, ""],
"stabilityai/sdxl-turbo": [
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="A shark lady watching her friend build a snowman, deep orange sky, color block, high resolution, ((8k uhd, excellent artwork))",
),
gr.Slider(0, 5, value=2),
gr.Dropdown(value="DDIM"),
gr.Slider(0, value=0),
512,
512,
"madebyollin/sdxl-vae-fp16-fix",
],
"stabilityai/stable-diffusion-xl-base-1.0": [
50,
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.Textbox(label="Negative Prompt", interactive=True),
40,
"DDIM",
7.5,
512,
512,
gr.Slider(value=1024, interactive=False),
gr.Slider(value=1024, interactive=False),
"madebyollin/sdxl-vae-fp16-fix",
],
}
Expand Down

0 comments on commit 9bfa20b

Please sign in to comment.