Skip to content

Commit

Permalink
Fix custom weights.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 27, 2024
1 parent 353b930 commit f7d0b40
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 14 deletions.
33 changes: 27 additions & 6 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
)

from apps.shark_studio.modules.ckpt_processing import (
process_custom_pipe_weights,
preprocessCKPT,
save_irpa,
)

EMPTY_SD_MAP = {
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
import_ir: bool = True,
is_controlled: bool = False,
):
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
self.custom_vae = custom_vae
Expand Down Expand Up @@ -107,7 +109,7 @@ def __init__(
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(get_checkpoints_path(), safe_name(self.base_model_id))
os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision))
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
Expand Down Expand Up @@ -153,10 +155,29 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
weights = copy.deepcopy(self.model_map)

if custom_weights:
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights)
diffusers_weights_path = preprocessCKPT(custom_weights, self.precision)
for key in weights:
if key not in ["vae_decode", "pipeline", "full_pipeline"]:
weights[key] = custom_weights_params
if key in ["scheduled_unet", "unet"]:
unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors")
weights[key] = save_irpa(unet_weights_path, "unet.")

elif key in ["clip", "prompt_encoder"]:
if not self.is_sdxl:
sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors")
weights[key] = save_irpa(sd1_path, "text_encoder_model.")
else:
clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors")
clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors")
weights[key] = [
save_irpa(clip_1_path, "text_encoder_model_1."),
save_irpa(clip_2_path, "text_encoder_model_2.")
]

elif key in ["vae_decode"] and weights[key] is None:
vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors")
weights[key] = save_irpa(vae_weights_path, "vae.")


vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
Expand Down Expand Up @@ -369,7 +390,7 @@ def view_json_file(file_path):


def safe_name(name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
return name.replace("/", "_").replace("\\", "_").replace(".", "_")


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def parse_device(device_str):
get_iree_target_triple,
iree_target_map,
)

rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
Expand All @@ -150,7 +149,7 @@ def parse_device(device_str):
case "rocm":
triple = get_rocm_target_chip(device_str)
return target_backend, rt_device, triple
case "cpu":
case "llvm-cpu":
return "llvm-cpu", "local-task", "x86_64-linux-gnu"


Expand Down
30 changes: 25 additions & 5 deletions apps/shark_studio/modules/ckpt_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import json
import re
import requests
import torch
import safetensors
from shark_turbine.aot.params import (
ParameterArchiveBuilder,
)
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
Expand All @@ -15,21 +20,21 @@
)


def get_path_to_diffusers_checkpoint(custom_weights):
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = os.path.join("diffusers", path.stem)
diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}")
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers


def preprocessCKPT(custom_weights, is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
return path_to_diffusers
else:
print(
"Diffusers' checkpoint will be identified here : ",
Expand All @@ -51,8 +56,23 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
)
if precision == "fp16":
pipe.to(dtype=torch.float16)
pipe.save_pretrained(path_to_diffusers)
del pipe
print("Loading complete")
return path_to_diffusers

def save_irpa(weights_path, prepend_str):
weights = safetensors.torch.load_file(weights_path)
archive = ParameterArchiveBuilder()
for key in weights.keys():
new_key = prepend_str + key
archive.add_tensor(new_key, weights[key])

irpa_file = weights_path.replace(".safetensors", ".irpa")
archive.save(irpa_file)
return irpa_file


def convert_original_vae(vae_checkpoint):
Expand Down
3 changes: 2 additions & 1 deletion apps/shark_studio/modules/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ def export_scheduler_model(model):


scheduler_model_map = {
"PNDM": export_scheduler_model("PNDMScheduler"),
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
"LCM": export_scheduler_model("LCMScheduler"),
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
"PNDM": export_scheduler_model("PNDMScheduler"),
"DDPM": export_scheduler_model("DDPMScheduler"),
"DDIM": export_scheduler_model("DDIMScheduler"),
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
Expand Down

0 comments on commit f7d0b40

Please sign in to comment.