From c264c6027ea2e79b7d79dbf9a3df99a253746a01 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Tue, 30 Apr 2024 06:00:58 -0400 Subject: [PATCH] Address comments --- apps/shark_studio/api/utils.py | 26 -------------------- apps/shark_studio/modules/ckpt_processing.py | 25 ++++++------------- apps/shark_studio/web/ui/sd.py | 3 --- 3 files changed, 8 insertions(+), 46 deletions(-) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 2531029af5..e9268aa83b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -318,29 +318,3 @@ def parse_seed_input(seed_input: str | list | int): raise TypeError( "Seed input must be an integer or an array of integers in JSON format" ) - - -config = {} - - -def load_config(path): - # Create if file doesn't exist - # Read file contents - pass - - -def save_config(path): - pass - - -def get_config_value(key): - load_config(cmd_opts.conf) - # Return None if key doesn't exist - # Otherwise return value - pass - - -def set_config_value(key, value): - load_config(cmd_opts.conf) - config[key] = value - save_config(cmd_opts.conf) diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index f6cf896b3f..fc0bd3b7b8 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -81,23 +81,14 @@ def process_custom_pipe_weights(custom_weights): # act as if we were given the local file as custom_weights originally custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path) custom_weights_params = weights_path - return custom_weights_params, custom_weights_tgt - - elif not custom_weights.lower().endswith((".ckpt", ".safetensors")): - # TODO: HF downloader - # Load pretrained model - model = StableDiffusionPipeline.from_pretrained(custom_weights) - custom_weights = custom_weights + ".ckpt" - # Save pretrained to disk - model.save_pretrained(custom_weights) - # Delete pretrained model - del model - - assert custom_weights.lower().endswith( - (".ckpt", ".safetensors") - ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" - custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) - custom_weights_params = custom_weights + + else: + assert custom_weights.lower().endswith( + (".ckpt", ".safetensors") + ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" + custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) + custom_weights_params = custom_weights + return custom_weights_params, custom_weights_tgt diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 2bc82b3f3d..fa2c1836fd 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -397,9 +397,6 @@ def base_model_changed(base_model_id): elem_id="negative_prompt_box", show_copy_button=True, ) - with gr.Row(): - save_prompt = gr.Checkbox(label="Save Prompt") - save_negative_prompt = gr.Checkbox(label="Save Negative Prompt") with gr.Row(equal_height=True): seed = gr.Textbox( value=cmd_opts.seed,