From c2f717d6e95555fa390f38ea4c687866197778fb Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 27 May 2024 00:02:18 -0500 Subject: [PATCH] Formatting --- apps/shark_studio/api/sd.py | 41 +++++++++++++++----- apps/shark_studio/api/utils.py | 1 + apps/shark_studio/modules/ckpt_processing.py | 3 +- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index a3c6f294e6..83574d294d 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -109,7 +109,9 @@ def __init__( self.pipe_id = "_".join(pipe_id_list) self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) self.weights_path = Path( - os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)) + os.path.join( + get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision) + ) ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) @@ -155,29 +157,48 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): weights = copy.deepcopy(self.model_map) if custom_weights: - custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights) + custom_weights = os.path.join( + get_checkpoints_path("checkpoints"), + safe_name(self.base_model_id.split("/")[-1]), + custom_weights, + ) diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) for key in weights: if key in ["scheduled_unet", "unet"]: - unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors") + unet_weights_path = os.path.join( + diffusers_weights_path, + "unet", + "diffusion_pytorch_model.safetensors", + ) weights[key] = save_irpa(unet_weights_path, "unet.") - + elif key in ["clip", "prompt_encoder"]: if not self.is_sdxl: - sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + sd1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) weights[key] = save_irpa(sd1_path, "text_encoder_model.") else: - clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") - clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors") + clip_1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + clip_2_path = os.path.join( + diffusers_weights_path, + "text_encoder_2", + "model.safetensors", + ) weights[key] = [ save_irpa(clip_1_path, "text_encoder_model_1."), - save_irpa(clip_2_path, "text_encoder_model_2.") + save_irpa(clip_2_path, "text_encoder_model_2."), ] elif key in ["vae_decode"] and weights[key] is None: - vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors") + vae_weights_path = os.path.join( + diffusers_weights_path, + "vae", + "diffusion_pytorch_model.safetensors", + ) weights[key] = save_irpa(vae_weights_path, "vae.") - vmfbs, weights = self.sd_pipe.check_prepared( mlirs, vmfbs, weights, interactive=False diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index d63da5fc1a..0516255d2b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -135,6 +135,7 @@ def parse_device(device_str): get_iree_target_triple, iree_target_map, ) + rt_driver, device_id = clean_device_info(device_str) target_backend = iree_target_map(rt_driver) if device_id: diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index 523ca08b57..433df13654 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -30,7 +30,7 @@ def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): return path_to_diffusers -def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): +def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False): path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) @@ -63,6 +63,7 @@ def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): print("Loading complete") return path_to_diffusers + def save_irpa(weights_path, prepend_str): weights = safetensors.torch.load_file(weights_path) archive = ParameterArchiveBuilder()