Skip to content

Commit

Permalink
Shark2/SD/UI: Respect ckpt_dir, share and server_port args (#2070)
Browse files Browse the repository at this point in the history
* Takes whether to generate a gradio live link from the existing --share command
line parameter, rather than hardcoding as True.
* Takes server port from existing --server_port command line parameter, rather than
hardcoding as 11911.
* Default --ckpt_dir parameter to '../models'
* Use --ckpt_dir rather than hardcoding ../models as the base directory for
checkpoints, vae, and lora, etc
* Add a 'checkpoints' directory below --ckpt_dir to match ComfyUI folder structure.
Read custom_weights choices from there, and/or subfolders below there matching
the selected base model.
* Fix --ckpt_dir possibly not working correctly when an absolute rather than relative path
is specified.
* Relabel "Custom Weights" to "Custom Weights Checkpoint" in the UI
  • Loading branch information
one-lithe-rune committed Feb 7, 2024
1 parent 019ba70 commit 01575a8
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 23 deletions.
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/shared_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def is_valid_file(arg):
p.add_argument(
"--ckpt_dir",
type=str,
default="",
default="../models",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)
Expand Down
4 changes: 2 additions & 2 deletions apps/shark_studio/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
# )
# t.start()
studio_web.launch(
share=True,
share=cmd_opts.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
server_port=cmd_opts.server_port,
)


Expand Down
23 changes: 20 additions & 3 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,17 @@ def import_original(original_img, width, height):
return EditorValue(img_dict)


def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")

return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
)


with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
Expand Down Expand Up @@ -259,13 +270,19 @@ def import_original(original_img, width, height):
choices=sd_default_models,
) # base_model_id
custom_weights = gr.Dropdown(
label="Custom Weights",
label="Custom Weights Checkpoint",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=["None"] + get_checkpoints(base_model_id),
) #
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
with gr.Column(scale=2):
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
Expand Down
38 changes: 21 additions & 17 deletions apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ def get_path_stem(path):
return path.stem


def get_resource_path(relative_path):
def get_resource_path(path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, relative_path)).resolve(strict=False)
return result
if os.path.isabs(path):
return path
else:
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, path)).resolve(strict=False)
return result


def get_configs_path() -> Path:
Expand All @@ -48,36 +51,37 @@ def get_generated_imgs_todays_subdir() -> str:


def create_checkpoint_folders():
dir = ["vae", "lora", "../vmfb"]
if not cmd_opts.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(cmd_opts.ckpt_dir):
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.ckpt_dir):
try:
os.makedirs(cmd_opts.ckpt_dir)
except OSError:
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{cmd_opts.ckpt_dir} folder does not exists."
f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created."
)

for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)


def get_checkpoints_path(model=""):
return get_resource_path(f"../models/{model}")
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type))


def get_checkpoints(model="models"):
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model == "lora":
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)


def get_checkpoint_pathfile(checkpoint_name, model="models"):
return os.path.join(get_checkpoints_path(model), checkpoint_name)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)

0 comments on commit 01575a8

Please sign in to comment.