diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index 0bbfc440a4..fe77022068 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -23,7 +23,6 @@ def is_valid_file(arg): ############################################################################## # Stable Diffusion Params ############################################################################## - p.add_argument( "-a", "--app", @@ -595,9 +594,10 @@ def is_valid_file(arg): # Web UI flags ############################################################################## p.add_argument( - "--default_config", + "--defaults", default="sdxl-turbo.json", type=str, + help="Path to the default API request .json file. Works for CLI and webui." ) p.add_argument( diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index ddf8276a92..a7b5e2068a 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -14,6 +14,7 @@ get_checkpoints_path, get_checkpoints, get_configs_path, + get_configs, write_default_sd_configs, ) from apps.shark_studio.api.sd import ( @@ -148,7 +149,14 @@ def pull_sd_configs( def load_sd_cfg(sd_json: dict, load_sd_config: str): - new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config))) + if os.path.exists(load_sd_config): + config = load_sd_config + elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)): + config = os.path.join(get_configs_path(), load_sd_config) + else: + print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.") + config = sd_json + new_sd_config = none_to_str_none(json.loads(view_json_file(config))) if sd_json: for key in new_sd_config: sd_json[key] = new_sd_config[key] @@ -241,17 +249,17 @@ def base_model_changed(base_model_id): ) + get_checkpoints(model_type="checkpoints") if "turbo" in base_model_id: new_steps = gr.Dropdown( - value=cmd_opts.steps, + value=2, choices=[1, 2], label="\U0001F3C3\U0000FE0F Steps", - allow_custom_value=False, + allow_custom_value=True, ) if "stable-diffusion-xl-base-1.0" in base_model_id: new_steps = gr.Dropdown( value=40, choices=[20, 25, 30, 35, 40, 45, 50], label="\U0001F3C3\U0000FE0F Steps", - allow_custom_value=False, + allow_custom_value=True, ) elif ".py" in base_model_id: new_steps = gr.Dropdown( @@ -262,7 +270,7 @@ def base_model_changed(base_model_id): ) else: new_steps = gr.Dropdown( - value=cmd_opts.steps, + value=20, choices=[10, 20, 30, 40, 50], label="\U0001F3C3\U0000FE0F Steps", allow_custom_value=True, @@ -276,70 +284,197 @@ def base_model_changed(base_model_id): new_steps, ] +init_config = global_obj.get_init_config() +init_config = none_to_str_none(json.loads(view_json_file(init_config))) with gr.Blocks(title="Stable Diffusion") as sd_element: with gr.Column(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=2, min_width=600): + with gr.Group(elem_id="prompt_box_outer"): + prompt = gr.Textbox( + label="\U00002795\U0000FE0F Prompt", + value=init_config["prompt"][0], + lines=4, + elem_id="prompt_box", + show_copy_button=True, + ) + negative_prompt = gr.Textbox( + label="\U00002796\U0000FE0F Negative Prompt", + value=init_config["negative_prompt"][0], + lines=4, + elem_id="negative_prompt_box", + show_copy_button=True, + ) with gr.Accordion( - label="\U0001F4D0\U0000FE0F Device Settings", open=False + label="\U0001F4D0\U0000FE0F Advanced Settings", open=True ): - device = gr.Dropdown( - elem_id="device", - label="Device", - value=global_obj.get_device_list()[0], - choices=global_obj.get_device_list(), - allow_custom_value=False, - ) - with gr.Row(): - ondemand = gr.Checkbox( - value=cmd_opts.lowvram, - label="Low VRAM", - interactive=True, - visible=False, + with gr.Accordion( + label="Device Settings", open=False + ): + device = gr.Dropdown( + elem_id="device", + label="Device", + value=init_config["device"] if init_config["device"] else "rocm", + choices=global_obj.get_device_list(), + allow_custom_value=True, ) target_triple = gr.Textbox( elem_id="target_triple", label="Architecture", - value="", + value=init_config["target_triple"], ) - precision = gr.Radio( - label="Precision", - value=cmd_opts.precision, - choices=[ - "fp16", - "fp32", - ], + with gr.Row(): + ondemand = gr.Checkbox( + value=init_config["ondemand"], + label="Low VRAM", + interactive=True, + visible=False, + ) + precision = gr.Radio( + label="Precision", + value=init_config["precision"], + choices=[ + "fp16", + "fp32", + ], + visible=False, + ) + with gr.Row(): + height = gr.Slider( + 512, + 1024, + value=512, + step=512, + label="\U00002195\U0000FE0F Height", + interactive=False, # DEMO + visible=False, # DEMO + ) + width = gr.Slider( + 512, + 1024, + value=512, + step=512, + label="\U00002194\U0000FE0F Width", + interactive=False, # DEMO + visible=False, # DEMO + ) + + with gr.Accordion( + label="\U0001F9EA\U0000FE0F Input Image Processing", + open=False, + visible=False, + ): + strength = gr.Slider( + 0, + 1, + value=init_config["strength"], + step=0.01, + label="Denoising Strength", + ) + resample_type = gr.Dropdown( + value=init_config["resample_type"], + choices=resampler_list, + label="Resample Type", + allow_custom_value=True, + ) + with gr.Row(): + sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}" + base_model_id = gr.Dropdown( + label="\U000026F0\U0000FE0F Base Model", + info="Select or enter HF model ID", + elem_id="custom_model", + value=init_config["base_model_id"], + choices=sd_default_models, + allow_custom_value=True, + ) # base_model_id + with gr.Row(equal_height=True): + seed = gr.Textbox( + value=init_config["seed"], + label="\U0001F331\U0000FE0F Seed", + info="An integer, -1 for random", + show_copy_button=True, + ) + scheduler = gr.Dropdown( + elem_id="scheduler", + label="\U0001F4C5\U0000FE0F Scheduler", + info="\U000E0020", # forces same height as seed + value=init_config["scheduler"], + choices=scheduler_model_map.keys(), + allow_custom_value=False, + ) + with gr.Row(): + steps = gr.Dropdown( + value=20, + choices=[10, 15, 20], + label="\U0001F3C3\U0000FE0F Steps", + allow_custom_value=True, + ) + guidance_scale = gr.Slider( + 0, + 5, #DEMO + value=4, + step=0.1, + label="\U0001F5C3\U0000FE0F CFG Scale", + ) + with gr.Row(): + batch_count = gr.Slider( + 1, + 100, + value=init_config["batch_count"], + step=1, + label="Batch Count", + interactive=True, visible=True, ) - sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}" - base_model_id = gr.Dropdown( - label="\U000026F0\U0000FE0F Base Model", - info="Select or enter HF model ID", - elem_id="custom_model", - value="stabilityai/sdxl-turbo", - choices=sd_default_models, - allow_custom_value=True, - ) # base_model_id - with gr.Row(): - height = gr.Slider( - 512, - 1024, - value=512, - step=512, - label="\U00002195\U0000FE0F Height", - interactive=False, # DEMO - visible=False, # DEMO - ) - width = gr.Slider( - 512, - 1024, - value=512, - step=512, - label="\U00002194\U0000FE0F Width", - interactive=False, # DEMO - visible=False, # DEMO - ) + batch_size = gr.Slider( + 1, + 4, + value=init_config["batch_size"], + step=1, + label="Batch Size", + interactive=False, # DEMO + visible=True, + ) + compiled_pipeline = gr.Checkbox( + value=init_config["compiled_pipeline"], + label="Faster txt2img (SDXL only)", + visible=False, # DEMO + ) + with gr.Row(elem_classes=["fill"], visible=False): + Path(get_configs_path()).mkdir( + parents=True, exist_ok=True + ) + write_default_sd_configs(get_configs_path()) + default_config_file = global_obj.get_init_config() + sd_json = gr.JSON( + elem_classes=["fill"], + value=view_json_file(default_config_file), + ) + with gr.Row(): + with gr.Row(): + load_sd_config = gr.Dropdown( + label="Load Config", + value=cmd_opts.defaults, + choices=get_configs(), + allow_custom_value=True, + ) + with gr.Row(): + save_sd_config = gr.Button( + value="Save Config", size="sm" + ) + clear_sd_config = gr.ClearButton( + value="Clear Config", + size="sm", + components=sd_json, + ) + # with gr.Row(): + sd_config_name = gr.Textbox( + value="Config Name", + info="Name of the file this config will be saved to.", + interactive=True, + show_label=False, + ) with gr.Accordion( label="\U00002696\U0000FE0F Model Weights", open=False, @@ -350,7 +485,7 @@ def base_model_changed(base_model_id): label="Checkpoint Weights", info="Select or enter HF model ID", elem_id="custom_model", - value="None", + value=init_config["custom_weights"], allow_custom_value=True, choices=["None"] + get_checkpoints(os.path.basename(str(base_model_id))), @@ -363,11 +498,7 @@ def base_model_changed(base_model_id): label=f"VAE Model", info=sd_vae_info, elem_id="custom_model", - value=( - os.path.basename(cmd_opts.custom_vae) - if cmd_opts.custom_vae - else "None" - ), + value=init_config["custom_vae"], choices=["None"] + get_checkpoints("vae"), allow_custom_value=True, scale=1, @@ -380,7 +511,7 @@ def base_model_changed(base_model_id): label=f"Standalone LoRA Weights", info=sd_lora_info, elem_id="lora_weights", - value=None, + value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None", multiselect=True, choices=[] + get_checkpoints("lora"), scale=2, @@ -405,68 +536,6 @@ def base_model_changed(base_model_id): outputs=[embeddings_config], show_progress=False, ) - with gr.Accordion( - label="\U0001F9EA\U0000FE0F Input Image Processing", - open=False, - visible=False, - ): - strength = gr.Slider( - 0, - 1, - value=cmd_opts.strength, - step=0.01, - label="Denoising Strength", - ) - resample_type = gr.Dropdown( - value=cmd_opts.resample_type, - choices=resampler_list, - label="Resample Type", - allow_custom_value=True, - ) - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="\U00002795\U0000FE0F Prompt", - value=cmd_opts.prompt[0], - lines=2, - elem_id="prompt_box", - show_copy_button=True, - ) - negative_prompt = gr.Textbox( - label="\U00002796\U0000FE0F Negative Prompt", - value=cmd_opts.negative_prompt[0], - lines=2, - elem_id="negative_prompt_box", - show_copy_button=True, - ) - with gr.Row(equal_height=True): - seed = gr.Textbox( - value=cmd_opts.seed, - label="\U0001F331\U0000FE0F Seed", - info="An integer, -1 for random", - show_copy_button=True, - ) - scheduler = gr.Dropdown( - elem_id="scheduler", - label="\U0001F4C5\U0000FE0F Scheduler", - info="\U000E0020", # forces same height as seed - value="EulerAncestralDiscrete", - choices=scheduler_model_map.keys(), - allow_custom_value=False, - ) - with gr.Row(): - steps = gr.Dropdown( - value=cmd_opts.steps, - choices=[1, 2], - label="\U0001F3C3\U0000FE0F Steps", - allow_custom_value=True, - ) - guidance_scale = gr.Slider( - 0, - 5, #DEMO - value=cmd_opts.guidance_scale, - step=0.1, - label="\U0001F5C3\U0000FE0F CFG Scale", - ) with gr.Accordion( label="Controlnet Options", open=False, @@ -628,30 +697,6 @@ def base_model_changed(base_model_id): object_fit="fit", preview=True, ) - with gr.Row(): - batch_count = gr.Slider( - 1, - 100, - value=cmd_opts.batch_count, - step=1, - label="Batch Count", - interactive=True, - visible=True, - ) - batch_size = gr.Slider( - 1, - 4, - value=cmd_opts.batch_size, - step=1, - label="Batch Size", - interactive=True, - visible=False, # DEMO - ) - compiled_pipeline = gr.Checkbox( - True, - label="Faster txt2img (SDXL only)", - visible=False, # DEMO - ) with gr.Row(): stable_diffusion = gr.Button("Start") unload = gr.Button("Unload Models") @@ -661,90 +706,43 @@ def base_model_changed(base_model_id): show_progress=False, ) stop_batch = gr.Button("Stop", visible=False) - with gr.Tab(label="Config", id=102) as sd_tab_config: - with gr.Column(elem_classes=["sd-right-panel"]): - with gr.Row(elem_classes=["fill"]): - Path(get_configs_path()).mkdir( - parents=True, exist_ok=True - ) - write_default_sd_configs(get_configs_path()) - default_config_file = os.path.join( - get_configs_path(), - "sdxl-turbo.json", - ) - sd_json = gr.JSON( - elem_classes=["fill"], - value=view_json_file(default_config_file), - ) - with gr.Row(): - with gr.Column(scale=3): - load_sd_config = gr.FileExplorer( - label="Load Config", - file_count="single", - root_dir=( - cmd_opts.configs_path - if cmd_opts.configs_path - else get_configs_path() - ), - height=200, - ) - with gr.Column(scale=1): - save_sd_config = gr.Button( - value="Save Config", size="sm" - ) - clear_sd_config = gr.ClearButton( - value="Clear Config", - size="sm", - components=sd_json, - ) - # with gr.Row(): - sd_config_name = gr.Textbox( - value="Config Name", - info="Name of the file this config will be saved to.", - interactive=True, - show_label=False, - ) - load_sd_config.change( - fn=load_sd_cfg, - inputs=[sd_json, load_sd_config], - outputs=[ - prompt, - negative_prompt, - sd_init_image, - height, - width, - steps, - strength, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - base_model_id, - custom_weights, - custom_vae, - precision, - device, - target_triple, - ondemand, - compiled_pipeline, - resample_type, - cnet_config, - embeddings_config, - sd_json, - ], - ) - save_sd_config.click( - fn=save_sd_cfg, - inputs=[sd_json, sd_config_name], - outputs=[sd_config_name], - ) - save_sd_config.click( - fn=save_sd_cfg, - inputs=[sd_json, sd_config_name], - outputs=[sd_config_name], - ) - with gr.Tab(label="Log", id=103) as sd_tab_log: + # with gr.Tab(label="Config", id=102) as sd_tab_config: + # with gr.Group():#elem_classes=["sd-right-panel"]): + # with gr.Row(elem_classes=["fill"], visible=False): + # Path(get_configs_path()).mkdir( + # parents=True, exist_ok=True + # ) + # write_default_sd_configs(get_configs_path()) + # default_config_file = global_obj.get_init_config() + # sd_json = gr.JSON( + # elem_classes=["fill"], + # value=view_json_file(default_config_file), + # ) + # with gr.Row(): + # with gr.Row(): + # load_sd_config = gr.Dropdown( + # label="Load Config", + # value=cmd_opts.defaults, + # choices=get_configs(), + # allow_custom_value=True, + # ) + # with gr.Row(): + # save_sd_config = gr.Button( + # value="Save Config", size="sm" + # ) + # clear_sd_config = gr.ClearButton( + # value="Clear Config", + # size="sm", + # components=sd_json, + # ) + # # with gr.Row(): + # sd_config_name = gr.Textbox( + # value="Config Name", + # info="Name of the file this config will be saved to.", + # interactive=True, + # show_label=False, + # ) + with gr.Tab(label="Log", id=103, visible=False) as sd_tab_log: with gr.Row(): std_output = gr.Textbox( value=f"{sd_model_info}\n" @@ -765,7 +763,41 @@ def base_model_changed(base_model_id): inputs=[base_model_id], outputs=[custom_weights, steps], ) - + load_sd_config.change( + fn=load_sd_cfg, + inputs=[sd_json, load_sd_config], + outputs=[ + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + target_triple, + ondemand, + compiled_pipeline, + resample_type, + cnet_config, + embeddings_config, + sd_json, + ], + ) + save_sd_config.click( + fn=save_sd_cfg, + inputs=[sd_json, sd_config_name], + outputs=[sd_config_name], + ) pull_kwargs = dict( fn=pull_sd_configs, inputs=[ diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index f27c1dca62..b83b989ec4 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -100,6 +100,15 @@ def get_checkpoints(model_type="checkpoints"): ckpt_files.extend(files) return sorted(ckpt_files, key=str.casefold) +def get_configs(): + return sorted( + [ + os.path.basename(x) + for x in glob.glob(os.path.join(get_configs_path(), "*.json")) + ], + key=str.casefold, + ) + def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"): return os.path.join(get_checkpoints_path(model_type), checkpoint_name) diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py index 27910e74ef..963cef3d5f 100644 --- a/apps/shark_studio/web/utils/globals.py +++ b/apps/shark_studio/web/utils/globals.py @@ -1,12 +1,18 @@ import gc from ...api.utils import get_available_devices - +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +import os +from apps.shark_studio.web.utils.file_utils import get_configs_path """ The global objects include SD pipeline and config. Maintaining the global objects would avoid creating extra pipeline objects when switching modes. Also we could avoid memory leak when switching models by clearing the cache. """ - +def view_json_file(file_path): + content = "" + with open(file_path, "r") as fopen: + content = fopen.read() + return content def _init(): global _sd_obj @@ -89,6 +95,16 @@ def get_device_list(): global _devices return _devices +def get_init_config(): + global _init_config + if os.path.exists(cmd_opts.defaults): + _init_config = cmd_opts.defaults + elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)): + _init_config = os.path.join(get_configs_path(), cmd_opts.defaults) + else: + print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.") + _init_config = os.path.join(get_configs_path(), "sdxl-turbo.json") + return _init_config def get_sd_status(): global _sd_obj