Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters-amd committed Apr 30, 2024
1 parent f88f6f5 commit c264c60
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 46 deletions.
26 changes: 0 additions & 26 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,29 +318,3 @@ def parse_seed_input(seed_input: str | list | int):
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)


config = {}


def load_config(path):
# Create if file doesn't exist
# Read file contents
pass


def save_config(path):
pass


def get_config_value(key):
load_config(cmd_opts.conf)
# Return None if key doesn't exist
# Otherwise return value
pass


def set_config_value(key, value):
load_config(cmd_opts.conf)
config[key] = value
save_config(cmd_opts.conf)
25 changes: 8 additions & 17 deletions apps/shark_studio/modules/ckpt_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,14 @@ def process_custom_pipe_weights(custom_weights):
# act as if we were given the local file as custom_weights originally
custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path)
custom_weights_params = weights_path
return custom_weights_params, custom_weights_tgt

elif not custom_weights.lower().endswith((".ckpt", ".safetensors")):
# TODO: HF downloader
# Load pretrained model
model = StableDiffusionPipeline.from_pretrained(custom_weights)
custom_weights = custom_weights + ".ckpt"
# Save pretrained to disk
model.save_pretrained(custom_weights)
# Delete pretrained model
del model

assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights

else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights

return custom_weights_params, custom_weights_tgt


Expand Down
3 changes: 0 additions & 3 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,6 @@ def base_model_changed(base_model_id):
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row():
save_prompt = gr.Checkbox(label="Save Prompt")
save_negative_prompt = gr.Checkbox(label="Save Negative Prompt")
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
Expand Down

0 comments on commit c264c60

Please sign in to comment.