From f7d0b40d436e6e3b220f94b7da69933d9a0f365b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 27 May 2024 00:01:57 -0500 Subject: [PATCH] Fix custom weights. --- apps/shark_studio/api/sd.py | 33 ++++++++++++++++---- apps/shark_studio/api/utils.py | 3 +- apps/shark_studio/modules/ckpt_processing.py | 30 +++++++++++++++--- apps/shark_studio/modules/schedulers.py | 3 +- 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 150a72d28e..a3c6f294e6 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -29,7 +29,8 @@ ) from apps.shark_studio.modules.ckpt_processing import ( - process_custom_pipe_weights, + preprocessCKPT, + save_irpa, ) EMPTY_SD_MAP = { @@ -77,6 +78,7 @@ def __init__( import_ir: bool = True, is_controlled: bool = False, ): + self.precision = precision self.compiled_pipeline = False self.base_model_id = base_model_id self.custom_vae = custom_vae @@ -107,7 +109,7 @@ def __init__( self.pipe_id = "_".join(pipe_id_list) self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) self.weights_path = Path( - os.path.join(get_checkpoints_path(), safe_name(self.base_model_id)) + os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)) ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) @@ -153,10 +155,29 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): weights = copy.deepcopy(self.model_map) if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) + custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights) + diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) for key in weights: - if key not in ["vae_decode", "pipeline", "full_pipeline"]: - weights[key] = custom_weights_params + if key in ["scheduled_unet", "unet"]: + unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors") + weights[key] = save_irpa(unet_weights_path, "unet.") + + elif key in ["clip", "prompt_encoder"]: + if not self.is_sdxl: + sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + weights[key] = save_irpa(sd1_path, "text_encoder_model.") + else: + clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors") + weights[key] = [ + save_irpa(clip_1_path, "text_encoder_model_1."), + save_irpa(clip_2_path, "text_encoder_model_2.") + ] + + elif key in ["vae_decode"] and weights[key] is None: + vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors") + weights[key] = save_irpa(vae_weights_path, "vae.") + vmfbs, weights = self.sd_pipe.check_prepared( mlirs, vmfbs, weights, interactive=False @@ -369,7 +390,7 @@ def view_json_file(file_path): def safe_name(name): - return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_") + return name.replace("/", "_").replace("\\", "_").replace(".", "_") if __name__ == "__main__": diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 0e53bd4a5a..d63da5fc1a 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -135,7 +135,6 @@ def parse_device(device_str): get_iree_target_triple, iree_target_map, ) - rt_driver, device_id = clean_device_info(device_str) target_backend = iree_target_map(rt_driver) if device_id: @@ -150,7 +149,7 @@ def parse_device(device_str): case "rocm": triple = get_rocm_target_chip(device_str) return target_backend, rt_device, triple - case "cpu": + case "llvm-cpu": return "llvm-cpu", "local-task", "x86_64-linux-gnu" diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index fc0bd3b7b8..523ca08b57 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -2,6 +2,11 @@ import json import re import requests +import torch +import safetensors +from shark_turbine.aot.params import ( + ParameterArchiveBuilder, +) from io import BytesIO from pathlib import Path from tqdm import tqdm @@ -15,21 +20,21 @@ ) -def get_path_to_diffusers_checkpoint(custom_weights): +def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): path = Path(custom_weights) diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) + diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}") complete_path_to_diffusers = diffusers_path / diffusers_directory_name complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) path_to_diffusers = complete_path_to_diffusers.as_posix() return path_to_diffusers -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) +def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) - return + return path_to_diffusers else: print( "Diffusers' checkpoint will be identified here : ", @@ -51,8 +56,23 @@ def preprocessCKPT(custom_weights, is_inpaint=False): from_safetensors=from_safetensors, num_in_channels=num_in_channels, ) + if precision == "fp16": + pipe.to(dtype=torch.float16) pipe.save_pretrained(path_to_diffusers) + del pipe print("Loading complete") + return path_to_diffusers + +def save_irpa(weights_path, prepend_str): + weights = safetensors.torch.load_file(weights_path) + archive = ParameterArchiveBuilder() + for key in weights.keys(): + new_key = prepend_str + key + archive.add_tensor(new_key, weights[key]) + + irpa_file = weights_path.replace(".safetensors", ".irpa") + archive.save(irpa_file) + return irpa_file def convert_original_vae(vae_checkpoint): diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 8c2413c638..1f5f99bc6d 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -101,11 +101,12 @@ def export_scheduler_model(model): scheduler_model_map = { + "PNDM": export_scheduler_model("PNDMScheduler"), + "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"), "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), - "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),