Skip to content

Commit

Permalink
Filesystem cleanup and custom model fixes (#2127)
Browse files Browse the repository at this point in the history
* Initial filesystem cleanup

* More filesystem cleanup

* Fix some formatting issues

* Address comments
  • Loading branch information
gpetters-amd authored and monorimet committed May 23, 2024
1 parent 1c30d49 commit 1b9bb92
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 57 deletions.
4 changes: 2 additions & 2 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def initialize():
clear_tmp_imgs()

from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
create_model_folders,
)

# Create custom models folders if they don't exist
create_checkpoint_folders()
create_model_folders()

import gradio as gr

Expand Down
4 changes: 3 additions & 1 deletion apps/shark_studio/modules/ckpt_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
Expand Down Expand Up @@ -87,6 +88,7 @@ def process_custom_pipe_weights(custom_weights):
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights

return custom_weights_params, custom_weights_tgt


Expand All @@ -98,7 +100,7 @@ def get_civitai_checkpoint(url: str):
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename

# we don't have this model downloaded yet
if not destination_path.is_file():
Expand Down
6 changes: 2 additions & 4 deletions apps/shark_studio/modules/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self.device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
Expand All @@ -55,9 +55,7 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = self.safe_name(pipe_id)
self.pipe_vmfb_path = Path(
os.path.join(get_checkpoints_path(".."), self.pipe_id)
)
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
Expand Down
21 changes: 18 additions & 3 deletions apps/shark_studio/modules/shared_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def is_valid_file(arg):
p.add_argument(
"--output_dir",
type=str,
default=None,
default=os.path.join(os.getcwd(), "generated_imgs"),
help="Directory path to save the output images and json.",
)

Expand Down Expand Up @@ -613,12 +613,27 @@ def is_valid_file(arg):
)

p.add_argument(
"--ckpt_dir",
"--tmp_dir",
type=str,
default=os.path.join(os.getcwd(), "shark_tmp"),
help="Path to tmp directory",
)

p.add_argument(
"--config_dir",
type=str,
default="../models",
default=os.path.join(os.getcwd(), "configs"),
help="Path to config directory",
)

p.add_argument(
"--model_dir",
type=str,
default=os.path.join(os.getcwd(), "models"),
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)

# TODO: replace API flag when these can be run together
p.add_argument(
"--ui",
Expand Down
70 changes: 39 additions & 31 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,14 @@ def import_original(original_img, width, height):


def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
ckpt_path = Path(
os.path.join(
cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id))
)
)
ckpt_path.mkdir(parents=True, exist_ok=True)

new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints")

return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
Expand Down Expand Up @@ -580,21 +585,6 @@ def base_model_changed(base_model_id):
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
Expand Down Expand Up @@ -630,19 +620,18 @@ def base_model_changed(base_model_id):
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
label="SD Config",
elem_classes=["fill"],
value=view_json_file(default_config_file),
render=False,
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
Expand Down Expand Up @@ -705,11 +694,30 @@ def base_model_changed(base_model_id):
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Row(elem_classes=["fill"]):
sd_json.render()
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Tab(label="Automation", id=104) as sd_tab_automation:
pass

pull_kwargs = dict(
fn=pull_sd_configs,
Expand Down
32 changes: 19 additions & 13 deletions apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,47 @@ def get_resource_path(path):


def get_configs_path() -> Path:
    """Return the config directory, creating it if it does not exist.

    The location comes from the ``--config_dir`` command-line option
    (see shared_cmd_opts); defaults to ``<cwd>/configs``.
    """
    configs = get_resource_path(cmd_opts.config_dir)
    if not os.path.exists(configs):
        # Single-level create is enough: the default lives directly under cwd.
        os.mkdir(configs)
    return Path(configs)


def get_generated_imgs_path() -> Path:
    """Return the output-image directory, creating it if it does not exist.

    The location comes from the ``--output_dir`` command-line option,
    which now defaults to ``<cwd>/generated_imgs`` (it is never None).
    """
    outputs = get_resource_path(cmd_opts.output_dir)
    if not os.path.exists(outputs):
        os.mkdir(outputs)
    return Path(outputs)


def get_tmp_path() -> Path:
    """Return the temporary-files directory, creating it if it does not exist.

    BUG FIX: this previously read ``cmd_opts.model_dir``, pointing temp
    files at the models directory. It now uses ``cmd_opts.tmp_dir``
    (``--tmp_dir``, default ``<cwd>/shark_tmp``), matching the option this
    accessor is named for and the usage in pipeline.py / tmp_configs.py.
    """
    tmpdir = get_resource_path(cmd_opts.tmp_dir)
    if not os.path.exists(tmpdir):
        os.mkdir(tmpdir)
    return Path(tmpdir)


def get_generated_imgs_todays_subdir() -> str:
    """Name of today's output subfolder, e.g. ``"20240523"`` (YYYYMMDD)."""
    return f"{dt.now():%Y%m%d}"


def create_model_folders():
    """Create the model folder tree under ``--model_dir``.

    Ensures ``model_dir`` itself exists (exiting with a clear message if it
    cannot be created), then creates one subfolder per model category.
    """
    # `subfolders` instead of `dir`: don't shadow the builtin.
    subfolders = ["checkpoints", "vae", "lora", "vmfb"]
    if not os.path.isdir(cmd_opts.model_dir):
        try:
            os.makedirs(cmd_opts.model_dir)
        except OSError:
            # A bad --model_dir is unrecoverable for the UI; fail fast.
            sys.exit(
                f"Invalid --model_dir argument, "
                f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
            )

    for root in subfolders:
        Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)


def get_checkpoints_path(model_type=""):
    """Return the path for a model category under ``--model_dir``.

    :param model_type: subfolder name ("checkpoints", "vae", "lora", "vmfb");
        empty string returns the model_dir root itself.
    """
    return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))


def get_checkpoints(model_type="checkpoints"):
Expand Down
4 changes: 3 additions & 1 deletion apps/shark_studio/web/utils/tmp_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import shutil
from time import time

shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")


def clear_tmp_mlir():
Expand Down
5 changes: 3 additions & 2 deletions shark/shark_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import hashlib

from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

def create_hash(file_name):
with open(file_name, "rb") as f:
Expand Down Expand Up @@ -120,7 +121,7 @@ def import_mlir(
is_dynamic=False,
tracing_required=False,
func_name="forward",
save_dir="./shark_tmp/",
save_dir=cmd_opts.tmp_dir, #"./shark_tmp/",
mlir_type="linalg",
):
if self.frontend in ["torch", "pytorch"]:
Expand Down Expand Up @@ -806,7 +807,7 @@ def save_mlir(
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = os.path.join(".", "shark_tmp")
dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):
Expand Down

0 comments on commit 1b9bb92

Please sign in to comment.