Studio2/SD: Use more correct LoRA alpha calculation (#2034)
* Updates processLoRA to use both the LoRA's embedded alpha and an optional
  lora_strength parameter (default 1.0) when applying LoRA weights.
* Updates processLoRA to cover more dim cases.
* This brings processLoRA into line with PR #2015 against Studio1.
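For reference, a minimal sketch of the scaling change described above (names here are illustrative, not part of the diff): previously every LoRA weight pair was merged with a hard-coded scale of 0.75; with this commit, the alpha stored in the LoRA file is divided by the rank (the up weight's second dimension) and multiplied by the new lora_strength knob.

```python
# Illustrative sketch only (not part of the commit): the effective per-layer scale.
def effective_scale(alpha: float, rank: int, lora_strength: float = 1.0) -> float:
    # new behaviour: embedded alpha, normalised by rank, times the strength knob
    return (alpha / rank) * lora_strength

OLD_FIXED_SCALE = 0.75        # old behaviour: the same scale for every layer
print(effective_scale(8, 16))  # a rank-16 LoRA trained with alpha=8 -> 0.5
```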
one-lithe-rune authored and monorimet committed Dec 15, 2023
1 parent b3d5add commit 782485e
Showing 1 changed file with 69 additions and 50 deletions.
119 changes: 69 additions & 50 deletions apps/shark_studio/modules/embeddings.py
@@ -3,34 +3,56 @@
import torch
import json
import safetensors
+from dataclasses import dataclass
from safetensors.torch import load_file
from apps.shark_studio.api.utils import get_checkpoint_pathfile


-def processLoRA(model, use_lora, splitting_prefix):
+@dataclass
+class LoRAweight:
+    up: torch.tensor
+    down: torch.tensor
+    mid: torch.tensor
+    alpha: torch.float32 = 1.0


+def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
    state_dict = ""
    if ".safetensors" in use_lora:
        state_dict = load_file(use_lora)
    else:
        state_dict = torch.load(use_lora)
-    alpha = 0.75
-    visited = []

-    # directly update weight in model
-    process_unet = "te" not in splitting_prefix
+    # gather the weights from the LoRA in a more convenient form, assumes
+    # everything will have an up.weight.
+    weight_dict: dict[str, LoRAweight] = {}
    for key in state_dict:
-        if ".alpha" in key or key in visited:
-            continue

+        if key.startswith(splitting_prefix) and key.endswith("up.weight"):
+            stem = key.split("up.weight")[0]
+            weight_key = stem.removesuffix(".lora_")
+            weight_key = weight_key.removesuffix("_lora_")
+            weight_key = weight_key.removesuffix(".lora_linear_layer.")

+            if weight_key not in weight_dict:
+                weight_dict[weight_key] = LoRAweight(
+                    state_dict[f"{stem}up.weight"],
+                    state_dict[f"{stem}down.weight"],
+                    state_dict.get(f"{stem}mid.weight", None),
+                    state_dict[f"{weight_key}.alpha"]
+                    / state_dict[f"{stem}up.weight"].shape[1]
+                    if f"{weight_key}.alpha" in state_dict
+                    else 1.0,
+                )

+    # Directly update weight in model

+    # Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
+    # and similar code in https://github.com/huggingface/diffusers/issues/3064

+    # TODO: handle mid weights (how do they even work?)
+    for key, lora_weight in weight_dict.items():
        curr_layer = model
-        if ("text" not in key and process_unet) or (
-            "text" in key and not process_unet
-        ):
-            layer_infos = (
-                key.split(".")[0].split(splitting_prefix)[-1].split("_")
-            )
-        else:
-            continue
+        layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")

        # find the target layer
        temp_name = layer_infos.pop(0)
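As an aside on the weight_dict construction above: a hedged sketch of how a kohya-style key might be reduced to the stem, weight_key, and layer_infos used later. The key below is hypothetical; real LoRA files can differ in naming.

```python
# Hypothetical kohya-style key; the exact layout depends on the LoRA file.
key = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight"
stem = key.split("up.weight")[0]           # "lora_unet_down_blocks_0_attentions_0_proj_in.lora_"
weight_key = stem.removesuffix(".lora_")   # "lora_unet_down_blocks_0_attentions_0_proj_in"
layer_infos = weight_key.split(".")[0].split("lora_unet_")[-1].split("_")
# ['down', 'blocks', '0', 'attentions', '0', 'proj', 'in'] -> walked to find the target layer
```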
@@ -47,42 +69,39 @@ def processLoRA(model, use_lora, splitting_prefix):
                else:
                    temp_name = layer_infos.pop(0)

-        pair_keys = []
-        if "lora_down" in key:
-            pair_keys.append(key.replace("lora_down", "lora_up"))
-            pair_keys.append(key)
-        else:
-            pair_keys.append(key)
-            pair_keys.append(key.replace("lora_up", "lora_down"))

-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = (
-                state_dict[pair_keys[0]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
-            )
+        weight = curr_layer.weight.data
+        scale = lora_weight.alpha * lora_strength
+        if len(weight.size()) == 2:
+            if len(lora_weight.up.shape) == 4:
+                weight_up = (
+                    lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                weight_down = (
+                    lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                change = (
+                    torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                )
+            else:
+                change = torch.mm(lora_weight.up, lora_weight.down)
+        elif lora_weight.down.size()[2:4] == (1, 1):
+            weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
-                state_dict[pair_keys[1]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
+                lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
            )
-            curr_layer.weight.data += alpha * torch.mm(
-                weight_up, weight_down
-            ).unsqueeze(2).unsqueeze(3)
+            change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
+            change = torch.nn.functional.conv2d(
+                lora_weight.down.permute(1, 0, 2, 3),
+                lora_weight.up,
+            ).permute(1, 0, 2, 3)

+        curr_layer.weight.data += change * scale

    return model
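For readers following the new branches: a small, self-contained sketch (shapes chosen arbitrarily) of the general-conv case, where the low-rank down and up kernels are composed into one delta kernel with conv2d before being scaled and added to the layer weight.

```python
import torch
import torch.nn.functional as F

rank, in_ch, out_ch = 4, 8, 16
down = torch.randn(rank, in_ch, 3, 3)  # lora_down for a 3x3 conv
up = torch.randn(out_ch, rank, 1, 1)   # lora_up projecting rank -> out channels

# Treat in_ch as the batch dim, convolve over the rank dim, then permute
# back to the usual (out_ch, in_ch, kH, kW) layout -- the same trick as above.
change = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
assert change.shape == (out_ch, in_ch, 3, 3)
```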


-def update_lora_weight_for_unet(unet, use_lora):
+def update_lora_weight_for_unet(unet, use_lora, lora_strength):
    extensions = [".bin", ".safetensors", ".pt"]
    if not any([extension in use_lora for extension in extensions]):
        # We assume if it is a HF ID with standalone LoRA weights.
@@ -104,14 +123,14 @@ def update_lora_weight_for_unet(unet, use_lora):
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
        return unet
    except:
-        return processLoRA(unet, use_lora, "lora_unet_")
+        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)


-def update_lora_weight(model, use_lora, model_name):
+def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
    if "unet" in model_name:
-        return update_lora_weight_for_unet(model, use_lora)
+        return update_lora_weight_for_unet(model, use_lora, lora_strength)
    try:
-        return processLoRA(model, use_lora, "lora_te_")
+        return processLoRA(model, use_lora, "lora_te_", lora_strength)
    except:
        return None


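A hedged usage sketch of the updated entry point: callers can now pass lora_strength through update_lora_weight, which defaults to 1.0. The wrapper function, path, and strength value below are placeholders, not part of the commit.

```python
from apps.shark_studio.modules.embeddings import update_lora_weight

def apply_style_lora(unet, lora_path="loras/my_style.safetensors", strength=0.8):
    # "unet" selects the UNet branch, which forwards lora_strength to
    # processLoRA if loading attention processors directly fails.
    return update_lora_weight(unet, lora_path, "unet", lora_strength=strength)
```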