Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 27, 2024
1 parent f7d0b40 commit a070a43
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
41 changes: 31 additions & 10 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def __init__(
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision))
os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
Expand Down Expand Up @@ -155,29 +157,48 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
weights = copy.deepcopy(self.model_map)

if custom_weights:
custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights)
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
custom_weights,
)
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
for key in weights:
if key in ["scheduled_unet", "unet"]:
unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors")
unet_weights_path = os.path.join(
diffusers_weights_path,
"unet",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(unet_weights_path, "unet.")

elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors")
sd1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors")
clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors")
clip_1_path = os.path.join(
diffusers_weights_path, "text_encoder", "model.safetensors"
)
clip_2_path = os.path.join(
diffusers_weights_path,
"text_encoder_2",
"model.safetensors",
)
weights[key] = [
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2.")
save_irpa(clip_2_path, "text_encoder_model_2."),
]

elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors")
vae_weights_path = os.path.join(
diffusers_weights_path,
"vae",
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")


vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
Expand Down
1 change: 1 addition & 0 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def parse_device(device_str):
get_iree_target_triple,
iree_target_map,
)

rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
Expand Down
3 changes: 2 additions & 1 deletion apps/shark_studio/modules/ckpt_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
return path_to_diffusers


def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False):
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
Expand Down Expand Up @@ -63,6 +63,7 @@ def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False):
print("Loading complete")
return path_to_diffusers


def save_irpa(weights_path, prepend_str):
weights = safetensors.torch.load_file(weights_path)
archive = ParameterArchiveBuilder()
Expand Down

0 comments on commit a070a43

Please sign in to comment.