diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py
index d8cf544f81..131c9006e5 100644
--- a/apps/shark_studio/modules/embeddings.py
+++ b/apps/shark_studio/modules/embeddings.py
@@ -3,34 +3,56 @@ import torch
 import json
 import safetensors
+from dataclasses import dataclass
 from safetensors.torch import load_file
 from apps.shark_studio.api.utils import get_checkpoint_pathfile
 
 
-def processLoRA(model, use_lora, splitting_prefix):
+@dataclass
+class LoRAweight:
+    up: torch.Tensor
+    down: torch.Tensor
+    mid: torch.Tensor
+    alpha: float = 1.0
+
+
+def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
     state_dict = ""
     if ".safetensors" in use_lora:
         state_dict = load_file(use_lora)
     else:
         state_dict = torch.load(use_lora)
-    alpha = 0.75
-    visited = []
 
-    # directly update weight in model
-    process_unet = "te" not in splitting_prefix
+    # Gather the weights from the LoRA into a more convenient form; assumes
+    # every LoRA module has an up.weight.
+    weight_dict: dict[str, LoRAweight] = {}
     for key in state_dict:
-        if ".alpha" in key or key in visited:
-            continue
-
+        if key.startswith(splitting_prefix) and key.endswith("up.weight"):
+            stem = key.split("up.weight")[0]
+            weight_key = stem.removesuffix(".lora_")
+            weight_key = weight_key.removesuffix("_lora_")
+            weight_key = weight_key.removesuffix(".lora_linear_layer.")
+
+            if weight_key not in weight_dict:
+                weight_dict[weight_key] = LoRAweight(
+                    state_dict[f"{stem}up.weight"],
+                    state_dict[f"{stem}down.weight"],
+                    state_dict.get(f"{stem}mid.weight", None),
+                    state_dict[f"{weight_key}.alpha"]
+                    / state_dict[f"{stem}up.weight"].shape[1]
+                    if f"{weight_key}.alpha" in state_dict
+                    else 1.0,
+                )
+
+    # Directly update the weights in the model.
+
+    # Mostly adaptations of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
+    # and similar code in https://github.com/huggingface/diffusers/issues/3064
+
+    # TODO: handle mid weights (how do they even work?)
+    for key, lora_weight in weight_dict.items():
         curr_layer = model
-        if ("text" not in key and process_unet) or (
-            "text" in key and not process_unet
-        ):
-            layer_infos = (
-                key.split(".")[0].split(splitting_prefix)[-1].split("_")
-            )
-        else:
-            continue
+        layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
 
         # find the target layer
         temp_name = layer_infos.pop(0)
@@ -47,42 +69,39 @@ def processLoRA(model, use_lora, splitting_prefix):
             else:
                 temp_name = layer_infos.pop(0)
 
-        pair_keys = []
-        if "lora_down" in key:
-            pair_keys.append(key.replace("lora_down", "lora_up"))
-            pair_keys.append(key)
-        else:
-            pair_keys.append(key)
-            pair_keys.append(key.replace("lora_up", "lora_down"))
-
-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = (
-                state_dict[pair_keys[0]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
-            )
+        weight = curr_layer.weight.data
+        scale = lora_weight.alpha * lora_strength
+        if len(weight.size()) == 2:
+            if len(lora_weight.up.shape) == 4:
+                weight_up = (
+                    lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                weight_down = (
+                    lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                change = (
+                    torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                )
+            else:
+                change = torch.mm(lora_weight.up, lora_weight.down)
+        elif lora_weight.down.size()[2:4] == (1, 1):
+            weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
             weight_down = (
-                state_dict[pair_keys[1]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
+                lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
             )
-            curr_layer.weight.data += alpha * torch.mm(
-                weight_up, weight_down
-            ).unsqueeze(2).unsqueeze(3)
+            change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
         else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
+            change = torch.nn.functional.conv2d(
+                lora_weight.down.permute(1, 0, 2, 3),
+                lora_weight.up,
+            ).permute(1, 0, 2, 3)
+
+        curr_layer.weight.data += change * scale
+
     return model
 
 
-def update_lora_weight_for_unet(unet, use_lora):
+def update_lora_weight_for_unet(unet, use_lora, lora_strength):
     extensions = [".bin", ".safetensors", ".pt"]
     if not any([extension in use_lora for extension in extensions]):
         # We assume it is a HF ID with standalone LoRA weights.
@@ -104,14 +123,14 @@ def update_lora_weight_for_unet(unet, use_lora):
         unet.load_attn_procs(dir_name, weight_name=main_file_name)
         return unet
     except:
-        return processLoRA(unet, use_lora, "lora_unet_")
+        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
 
 
-def update_lora_weight(model, use_lora, model_name):
+def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
     if "unet" in model_name:
-        return update_lora_weight_for_unet(model, use_lora)
+        return update_lora_weight_for_unet(model, use_lora, lora_strength)
     try:
-        return processLoRA(model, use_lora, "lora_te_")
+        return processLoRA(model, use_lora, "lora_te_", lora_strength)
     except:
         return None
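
For reviewers, a quick illustration of what the new key normalization in `processLoRA` does. This is a minimal sketch: the key follows the kohya-ss naming scheme (`lora_unet_<module path>.lora_up.weight`), but the concrete module path below is made up for the example and is not taken from a real checkpoint.

```python
# Sketch of the new key grouping; the module path is illustrative only.
key = (
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0"
    "_attn1_to_q.lora_up.weight"
)

stem = key.split("up.weight")[0]
print(stem)
# lora_unet_down_blocks_0_..._attn1_to_q.lora_

weight_key = stem.removesuffix(".lora_")
print(weight_key)
# lora_unet_down_blocks_0_..._attn1_to_q

# The paired tensors then live at f"{stem}down.weight",
# optionally f"{stem}mid.weight", and f"{weight_key}.alpha".
```

The `_lora_` and `.lora_linear_layer.` suffixes handle the analogous diffusers-style layouts, so all three naming schemes collapse to the same `weight_dict` key.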
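The merge rule itself follows `merge_lora.py`: for a 2-D (linear) weight the update is `W += lora_strength * (alpha / rank) * (up @ down)`. A standalone sketch of that path, with made-up tensor sizes, assuming the `.alpha` entry stores the raw kohya alpha that the patch divides by the rank (`up.weight.shape[1]`):

```python
import torch

torch.manual_seed(0)

out_features, in_features, rank = 8, 8, 4
layer = torch.nn.Linear(in_features, out_features, bias=False)

up = torch.randn(out_features, rank)   # "...up.weight"
down = torch.randn(rank, in_features)  # "...down.weight"
alpha_value = 2.0                      # "...alpha" entry, if present
lora_strength = 0.75

# The patch pre-divides alpha by the rank when building the LoRAweight,
# then multiplies by the user-facing strength at merge time.
scale = (alpha_value / up.shape[1]) * lora_strength

change = torch.mm(up, down)            # low-rank delta, shape (out, in)
layer.weight.data += change * scale    # merged in place, as in processLoRA
```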
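Call sites can now thread the strength through the public entry point. A hypothetical usage, where the model objects and the file path are placeholders rather than part of this patch:

```python
# lora_strength=0.6 scales the merged delta to 60% of the trained effect.
unet = update_lora_weight(
    unet, "loras/example.safetensors", "unet", lora_strength=0.6
)
text_encoder = update_lora_weight(
    text_encoder, "loras/example.safetensors", "text_encoder", lora_strength=0.6
)
```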