Skip to content

Commit

Permalink
SD/UI Include LoRA strength in args, png data + fixes
Browse files Browse the repository at this point in the history
* Add a `--lora_strength` command line argument
* Include lora strength when reading and writing png metadata
* Allow lora_strength to be set above 1.0 in the UI
* Remove LoRA analysis json files mistakenly commited
* Fix deletion of StableDiffusionPipeline.fromPretrained return by previous commit
  • Loading branch information
one-lithe-rune committed Dec 15, 2023
1 parent 7f5f65f commit e2cf8cd
Show file tree
Hide file tree
Showing 15 changed files with 153 additions and 1,628 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ def from_pretrained(
is_fp32_vae,
)

return cls(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)

# #####################################################
# Implements text embeddings with weights from prompts
# https://huggingface.co/AlanB/lpw_stable_diffusion_mod
Expand Down
7 changes: 7 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,13 @@ def is_valid_file(arg):
"file (~3 MB).",
)

p.add_argument(
"--lora_strength",
type=float,
default=1.0,
help="Strength (alpha) scaling factor to use when applying LoRA weights",
)

p.add_argument(
"--use_quantize",
type=str,
Expand Down
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):

img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"

if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
Expand Down
10 changes: 5 additions & 5 deletions apps/stable_diffusion/web/api/sdapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def txt2img_api(InputData: Txt2ImgInputData):
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=frozen_args.use_lora,
lora_strength=1.0,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
Expand Down Expand Up @@ -306,7 +306,7 @@ def img2img_api(
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=frozen_args.use_lora,
lora_strength=1.0,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
Expand Down Expand Up @@ -390,7 +390,7 @@ def inpaint_api(
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=frozen_args.use_lora,
lora_strength=1.0,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
Expand Down Expand Up @@ -480,7 +480,7 @@ def outpaint_api(
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=frozen_args.use_lora,
lora_strength=1.0,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
Expand Down Expand Up @@ -558,7 +558,7 @@ def upscaler_api(
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=frozen_args.use_lora,
lora_strength=1.0,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
Expand Down
7 changes: 7 additions & 0 deletions apps/stable_diffusion/web/ui/common_ui_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ def lora_changed(lora_file):
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]


def lora_strength_changed(strength):
if strength > 1.0:
return gr.Number(elem_classes="value-out-of-range")
else:
return gr.Number(elem_classes="")
5 changes: 5 additions & 0 deletions apps/stable_diffusion/web/ui/css/sd_dark_theme.css
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ footer {
padding-right: 8px;
}

/* number input value is out of range */
.value-out-of-range input[type="number"] {
color: red !important;
}

/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
Expand Down
21 changes: 17 additions & 4 deletions apps/stable_diffusion/web/ui/img2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
Expand Down Expand Up @@ -821,9 +824,11 @@ def update_cn_input(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=1.0,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -1051,3 +1056,11 @@ def update_cn_input(
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
21 changes: 17 additions & 4 deletions apps/stable_diffusion/web/ui/inpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
Expand Down Expand Up @@ -364,9 +367,11 @@ def inpaint_inf(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=1.0,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -618,3 +623,11 @@ def inpaint_inf(
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
21 changes: 17 additions & 4 deletions apps/stable_diffusion/web/ui/outpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import gradio as gr
from PIL import Image

from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
Expand Down Expand Up @@ -310,9 +313,11 @@ def outpaint_inf(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=1.0,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -552,3 +557,11 @@ def outpaint_inf(
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
21 changes: 17 additions & 4 deletions apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
Expand Down Expand Up @@ -330,9 +333,11 @@ def txt2img_sdxl_inf(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=1.0,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -645,3 +650,11 @@ def check_last_input(prompt):
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
29 changes: 22 additions & 7 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
Expand Down Expand Up @@ -387,7 +390,7 @@ def load_settings():
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_strength", 1.0),
loaded_settings.get("lora_strength", args.lora_strength),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
Expand Down Expand Up @@ -504,14 +507,17 @@ def onload_load_settings():
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=default_settings.get("lora_strength")
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=default_settings.get("lora_strength"),
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -732,7 +738,7 @@ def onload_load_settings():
prompt,
negative_prompt,
lora_weights,
lora_f,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
Expand Down Expand Up @@ -765,7 +771,7 @@ def onload_load_settings():
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
Expand Down Expand Up @@ -866,6 +872,7 @@ def onload_load_settings():
height,
txt2img_custom_model,
lora_weights,
lora_strength,
custom_vae,
],
)
Expand Down Expand Up @@ -896,3 +903,11 @@ def set_compatible_schedulers(hires_fix_selected):
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
21 changes: 17 additions & 4 deletions apps/stable_diffusion/web/ui/upscaler_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Expand Down Expand Up @@ -332,9 +335,11 @@ def upscaler_inf(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
minimum=0.1,
maximum=1.0,
value=1.0,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
Expand Down Expand Up @@ -548,3 +553,11 @@ def upscaler_inf(
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
Loading

0 comments on commit e2cf8cd

Please sign in to comment.