ABM corrections #387

Draft: wants to merge 8 commits into main
86 changes: 28 additions & 58 deletions mergekit/scripts/ABM/activations_based_merge.py
@@ -15,11 +15,12 @@
 from mergekit.options import MergeOptions, add_merge_options


-@click.command("mergekit-activation-based-merge")
-@click.argument("model_path", type=str)
+@click.command("mergekit-activation-based-align")
 @click.argument("secondary_model_path", type=str)
 @click.argument("merge_unmerge_directory", type=str)
-@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
+@click.option(
+    "--out-path", "-o", required=True, type=str, help="Path to save the aligned model"
+)
 @click.option(
     "--dtype",
     type=str,
@@ -35,15 +36,13 @@
 )
 @add_merge_options
 def main(
-    model_path: str,
     secondary_model_path,
     merge_unmerge_directory: str,
     out_path: str,
     dtype: Optional[str],
     device: Optional[str],
     merge_options: MergeOptions,
 ):
-    model = ModelReference.model_validate(model_path)
     secondary_model = ModelReference.model_validate(secondary_model_path)

     dtype = dtype_from_name(dtype) if dtype else None
@@ -52,117 +51,88 @@ def main(
     cache.lazy_unpickle = merge_options.lazy_unpickle
     cache.hf_cache_dir = merge_options.transformers_cache

-    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
-        cache.get(m)
+    cache.get(secondary_model)

     writer = TensorWriter(
         out_path=out_path,
         max_shard_size=merge_options.out_shard_size,
         safe_serialization=merge_options.safe_serialization,
     )

-    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
+    model_config = secondary_model.config(
+        trust_remote_code=merge_options.trust_remote_code
+    )
     model_arch_info = get_architecture_info(
-        model.config(trust_remote_code=merge_options.trust_remote_code)
+        secondary_model.config(trust_remote_code=merge_options.trust_remote_code)
     )

-    loader_1 = cache.get(model)
-    loader_2 = cache.get(secondary_model)
+    loader = cache.get(secondary_model)

     os.makedirs(out_path, exist_ok=True)

     merge_unmerge_dictionary = {}
     # load files from merge_unmerge_directory

     spaces = [
         f.split("_unmerge")[0]
         for f in os.listdir(merge_unmerge_directory)
         if "_unmerge" in f
     ]
     for i in spaces:
-        logging.info(f"Loading merge/unmerge tensors for {i}")
+        logging.info(f"Loading merge tensors for {i}")
         m = safetensors.torch.load_file(
             os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
             device=device,
         )
-        u = safetensors.torch.load_file(
-            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
-            device=device,
-        )
-        merge_unmerge_dictionary[i] = (
-            m[i].to(device, dtype=dtype),
-            u[i].to(device, dtype=dtype),
-        )
+        merge_unmerge_dictionary[i] = m[i].to(device, dtype=dtype)

     for weight_info in model_arch_info.all_weights(config=model_config):
-        merge_matrix, unmerge_matrix = None, None
+        merge_matrix = None

         if weight_info.input_space in merge_unmerge_dictionary:
-            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
-            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)
+            unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space].t()

         if weight_info.output_space in merge_unmerge_dictionary:
-            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
-            merge_matrix = merge_matrix.chunk(2, dim=1)
+            merge_matrix = merge_unmerge_dictionary[weight_info.output_space]

-        original_w = loader_1.get_tensor(weight_info.name, device=device)
-        original_w2 = loader_2.get_tensor(weight_info.name, device=device)
+        original_w = loader.get_tensor(weight_info.name, device=device)

         if dtype is not None:
             original_w = original_w.to(dtype=dtype)
-            original_w2 = original_w2.to(dtype=dtype)

         w = torch.clone(original_w)
-        w2 = torch.clone(original_w2)

-        if not merge_matrix and not unmerge_matrix:
+        if merge_matrix is None and unmerge_matrix is None:
             logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
+                f"❌ Weight {weight_info.name} for model has no merge or unmerge matrix !!"
             )

         if merge_matrix is not None:
             if weight_info.is_embed:
-                w = (merge_matrix[0] @ w.T).T
-                w2 = (merge_matrix[1] @ w2.T).T
+                w = w @ merge_matrix.T
             else:
-                w = merge_matrix[0] @ w
-                w2 = merge_matrix[1] @ w2
+                w = merge_matrix @ w

         if unmerge_matrix is not None:
-            w = w @ unmerge_matrix[0]
-            w2 = w2 @ unmerge_matrix[1]
+            w = w @ unmerge_matrix

         # check if weights have not mutated, if yes then shoot warning
         if torch.allclose(original_w, w):
             logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
+                f"❌ Weight {weight_info.name} for input model has NOT mutated during merge"
             )
         else:
             logging.warning(
-                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
+                f"✅ Weight {weight_info.name} for input model has mutated during merge"
             )

-        if torch.allclose(original_w2, w2):
-            logging.warning(
-                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
-            )
-        else:
-            logging.warning(
-                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
-            )
-
-        # average weights and save them
-        if merge_matrix:
-            w = w + w2
-        else:
-            w = (w + w2) / 2
         writer.save_tensor(weight_info.name, w)
     writer.finalize()

-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(secondary_model_path)
     tokenizer.save_pretrained(out_path, safe_serialization=True)

     # write config
-    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
+    model_out_config = secondary_model.config(
+        trust_remote_code=merge_options.trust_remote_code
+    )
     if dtype:
         model_out_config.torch_dtype = dtype
     model_out_config.save_pretrained(out_path)
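
Note: the net effect of this file's changes is that the script no longer merges two models and averages their weights; it aligns a single secondary model by applying precomputed permutation matrices. Because the stored merge matrix is a permutation (orthogonal), its inverse, the unmerge matrix, is simply its transpose, which is why the rewritten loop can derive `unmerge_matrix` from the same tensor via `.t()`. A minimal sketch of the transform the loop applies; `align_weight` is a hypothetical helper for illustration, not part of mergekit:

```python
import torch

def align_weight(
    w: torch.Tensor,
    merge: torch.Tensor | None,
    unmerge: torch.Tensor | None,
    is_embed: bool = False,
) -> torch.Tensor:
    """Apply an output-space permutation and an input-space inverse permutation."""
    if merge is not None:
        # Embeddings are (vocab, hidden), so permute the hidden (output) axis.
        w = w @ merge.T if is_embed else merge @ w
    if unmerge is not None:
        w = w @ unmerge
    return w

# Toy example: permute a 4x4 weight's output space, undoing it on the input side.
perm = torch.eye(4)[[2, 0, 3, 1]]
w = torch.arange(16.0).reshape(4, 4)
aligned = align_weight(w, merge=perm, unmerge=perm.t())
```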
4 changes: 2 additions & 2 deletions mergekit/scripts/ABM/extract_activations.py
@@ -221,7 +221,7 @@ def main(
         logging.info("Using chat template for inference")
         tokenize_function = lambda x: tokenizer.apply_chat_template(
             x,
-            padding="longest",
+            padding="max_length",
             max_length=max_length,
             truncation=True,
             return_dict=True,
@@ -230,7 +230,7 @@
         logging.info("Using default tokenizer (no chat template) for inference")
         tokenize_function = lambda x: tokenizer(
             x,
-            padding="longest",
+            padding="max_length",
             max_length=max_length,
             truncation=True,
         )
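
Note: switching from padding="longest" to padding="max_length" pads every batch to the same fixed length rather than to the longest sequence in that particular batch, presumably so that activations captured across batches share one sequence dimension and can be stacked. A quick illustration (the model name is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default

batch = tok(
    ["a short line", "a somewhat longer line of text"],
    padding="max_length",
    max_length=16,
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 16]) for every batch;
# padding="longest" would instead vary the length from batch to batch.
```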
37 changes: 7 additions & 30 deletions mergekit/scripts/ABM/extract_permutation_matrices.py
@@ -29,24 +29,14 @@ def match_tensors_permute(
     Om = correlation_matrix.shape[0] // 2
     device = correlation_matrix.device

-    mats = [torch.eye(Om, device=device)]
-
     corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy()
     if absval:
         corr_submatrix = np.absolute(corr_submatrix)
     _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True)

-    new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-
-    unmerge = torch.cat(unmerge_mats, dim=0)
+    merge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]

-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
-
-    return merge.T, unmerge
+    return merge


 def match_tensors_permute_MHA(
@@ -63,7 +53,6 @@ def match_tensors_permute_MHA(
     device = correlation_matrix.device
     query_size = Om // n_heads

-    mats = [torch.eye(Om, device=device)]
     head_perms = []

     costs = np.ones((n_heads, n_heads)) * -sys.maxsize
@@ -106,18 +95,11 @@
         head_perm = col_inds_storage[head_1][head_2]
         head_perms.append(torch.tensor(head_perm + query_size * head_2))

-    new_mat = torch.eye(Om, device=device)[
+    merge = torch.eye(Om, device=device)[
         torch.cat(head_perms).clone().detach().long().to(device)
     ]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-
-    unmerge = torch.cat(unmerge_mats, dim=0)
-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)

-    return merge.T, unmerge
+    return merge


 @click.command("mergekit-abm-extract-permutations")
@@ -197,14 +179,14 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
         correlation_matrix = calc_correlation_matrix(concatenated_feature)

         if feature_space in (kq_spaces + v_spaces):
-            merge, unmerge = match_tensors_permute_MHA(
+            merge = match_tensors_permute_MHA(
                 correlation_matrix=correlation_matrix,
                 n_heads=model_config.num_attention_heads,
                 absval=absval,
             )

         else:
-            merge, unmerge = match_tensors_permute(
+            merge = match_tensors_permute(
                 correlation_matrix=correlation_matrix,
                 absval=absval,
             )
@@ -214,12 +196,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
             f"{out_path}/{feature_space}_merge.safetensor",
         )

-        safetensors.torch.save_file(
-            {feature_space: unmerge.contiguous()},
-            f"{out_path}/{feature_space}_unmerge.safetensor",
-        )
-
-        del merge, unmerge, correlation_matrix, concatenated_feature
+        del merge, correlation_matrix, concatenated_feature


 if __name__ == "__main__":
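
Note: both matching functions now return a single permutation matrix built from a linear sum assignment over the cross-model correlation matrix, and the separate `*_unmerge.safetensor` files are no longer written; since a permutation matrix is orthogonal, its unmerge counterpart is just its transpose. A standalone sketch of the core matching step under that assumption (not mergekit's exact code):

```python
import numpy as np
import scipy.optimize
import torch

def permutation_from_correlation(corr: torch.Tensor, absval: bool = False) -> torch.Tensor:
    """Permutation matrix maximizing total correlation between two feature sets."""
    c = corr.cpu().numpy()
    if absval:
        c = np.absolute(c)
    _, col_ind = scipy.optimize.linear_sum_assignment(c, maximize=True)
    # Row-select the identity to build the permutation matrix.
    return torch.eye(c.shape[0])[torch.as_tensor(col_ind, dtype=torch.long)]

corr = torch.rand(8, 8)  # toy correlation block between two models' features
P = permutation_from_correlation(corr)
assert torch.allclose(P @ P.T, torch.eye(8))  # orthogonal, so unmerge = P.T
```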