From 8d689e409195efcba51eb73eeecd400821e78f6e Mon Sep 17 00:00:00 2001 From: "Charles O. Goddard" Date: Fri, 3 May 2024 23:02:31 -0700 Subject: [PATCH] Fix BERT merging (#295) Various changes to better support merging BERT based models. --- mergekit/_data/architectures/bert.json | 121 ++++++++++++++---- mergekit/architecture.py | 2 +- mergekit/common.py | 34 +---- .../generalized_task_arithmetic.py | 12 +- mergekit/merge_methods/linear.py | 11 +- mergekit/merge_methods/model_stock.py | 16 ++- mergekit/merge_methods/rectify_embed.py | 47 +++++++ mergekit/merge_methods/slerp.py | 9 +- 8 files changed, 177 insertions(+), 75 deletions(-) create mode 100644 mergekit/merge_methods/rectify_embed.py diff --git a/mergekit/_data/architectures/bert.json b/mergekit/_data/architectures/bert.json index 60384d0f..5de5f5b0 100644 --- a/mergekit/_data/architectures/bert.json +++ b/mergekit/_data/architectures/bert.json @@ -5,101 +5,168 @@ ], "pre_weights": [ { - "name": "bert.embeddings.position_embeddings.weight" + "name": "embeddings.position_embeddings.weight", + "aliases": [ + "bert.embeddings.position_embeddings.weight" + ] }, { - "name": "bert.embeddings.token_type_embeddings.weight" + "name": "embeddings.token_type_embeddings.weight", + "aliases": [ + "bert.embeddings.token_type_embeddings.weight" + ] }, { - "name": "bert.embeddings.word_embeddings.weight", - "is_embed": true + "name": "embeddings.word_embeddings.weight", + "is_embed": true, + "aliases": [ + "bert.embeddings.word_embeddings.weight" + ] }, { - "name": "bert.embeddings.LayerNorm.bias", + "name": "embeddings.LayerNorm.bias", "aliases": [ + "embeddings.LayerNorm.beta", + "bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.beta" ] }, { - "name": "bert.embeddings.LayerNorm.weight", + "name": "embeddings.LayerNorm.weight", "aliases": [ - "bert.embeddings.LayerNorm.gamma" + "embeddings.LayerNorm.gamma", + "bert.embeddings.LayerNorm.weight", + "bert.embeddings.LayerNorm.gamma", + "bert.embeddings.LayerNorm.weight" ] }, { - "name": "bert.embeddings.position_ids", + "name": "embeddings.position_ids", "optional": true, - "force_dtype": "int64" + "force_dtype": "int64", + "aliases": [ + "bert.embeddings.position_ids" + ] } ], "post_weights": [ { - "name": "pooler.dense.weight" + "name": "pooler.dense.weight", + "aliases": [ + "bert.pooler.dense.weight" + ] }, { - "name": "pooler.dense.bias" + "name": "pooler.dense.bias", + "aliases": [ + "bert.pooler.dense.bias" + ] } ], "num_layers_config_key": "num_hidden_layers", "layer_templates": { "weights": [ { - "name": "bert.encoder.layer.${layer_index}.attention.self.query.weight" + "name": "encoder.layer.${layer_index}.attention.self.query.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.query.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.self.query.bias" + "name": "encoder.layer.${layer_index}.attention.self.query.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.query.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.self.key.weight" + "name": "encoder.layer.${layer_index}.attention.self.key.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.key.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.self.key.bias" + "name": "encoder.layer.${layer_index}.attention.self.key.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.key.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.self.value.weight" + "name": 
"encoder.layer.${layer_index}.attention.self.value.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.value.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.self.value.bias" + "name": "encoder.layer.${layer_index}.attention.self.value.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.self.value.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.output.dense.weight" + "name": "encoder.layer.${layer_index}.attention.output.dense.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.output.dense.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.output.dense.bias" + "name": "encoder.layer.${layer_index}.attention.output.dense.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.attention.output.dense.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias", + "name": "encoder.layer.${layer_index}.attention.output.LayerNorm.bias", "aliases": [ + "encoder.layer.${layer_index}.attention.output.LayerNorm.beta", + "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias", "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.beta" ] }, { - "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight", + "name": "encoder.layer.${layer_index}.attention.output.LayerNorm.weight", "aliases": [ + "encoder.layer.${layer_index}.attention.output.LayerNorm.gamma", + "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight", "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.gamma" ] }, { - "name": "bert.encoder.layer.${layer_index}.intermediate.dense.weight" + "name": "encoder.layer.${layer_index}.intermediate.dense.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.intermediate.dense.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.intermediate.dense.bias" + "name": "encoder.layer.${layer_index}.intermediate.dense.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.intermediate.dense.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.output.dense.weight" + "name": "encoder.layer.${layer_index}.output.dense.weight", + "aliases": [ + "bert.encoder.layer.${layer_index}.output.dense.weight" + ] }, { - "name": "bert.encoder.layer.${layer_index}.output.dense.bias" + "name": "encoder.layer.${layer_index}.output.dense.bias", + "aliases": [ + "bert.encoder.layer.${layer_index}.output.dense.bias" + ] }, { - "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.bias", + "name": "encoder.layer.${layer_index}.output.LayerNorm.bias", "aliases": [ + "encoder.layer.${layer_index}.output.LayerNorm.beta", + "bert.encoder.layer.${layer_index}.output.LayerNorm.bias", "bert.encoder.layer.${layer_index}.output.LayerNorm.beta" ] }, { - "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.weight", + "name": "encoder.layer.${layer_index}.output.LayerNorm.weight", "aliases": [ + "encoder.layer.${layer_index}.output.LayerNorm.gamma", + "bert.encoder.layer.${layer_index}.output.LayerNorm.weight", "bert.encoder.layer.${layer_index}.output.LayerNorm.gamma" ] } diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 16acbbab..653f1ac3 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -50,7 +50,7 @@ class WeightInfo(BaseModel, frozen=True): input_space: Optional[str] = None output_space: Optional[str] = None optional: bool = False - aliases: Optional[List[str]] = None + aliases: Optional[Tuple[str, ...]] = None force_dtype: 
Optional[str] = None diff --git a/mergekit/common.py b/mergekit/common.py index f652bc13..d7d4eac5 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -23,7 +23,6 @@ Dict, Generic, Iterator, - List, Mapping, Optional, Tuple, @@ -81,6 +80,7 @@ class ModelReference(BaseModel, frozen=True): model: ModelPath lora: Optional[ModelPath] = None + override_architecture: Optional[str] = None def merged( self, cache_dir: Optional[str] = None, trust_remote_code: bool = False @@ -122,11 +122,14 @@ def merged( return ModelReference(model=out_path) def config(self, trust_remote_code: bool = False) -> PretrainedConfig: - return AutoConfig.from_pretrained( + res = AutoConfig.from_pretrained( self.model.path, revision=self.model.revision, trust_remote_code=trust_remote_code, ) + if self.override_architecture: + res.architectures = [self.override_architecture] + return res def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex: assert self.lora is None @@ -209,33 +212,6 @@ def dtype_from_name(name: Optional[str]) -> Optional[torch.dtype]: raise RuntimeError(f'Unimplemented dtype "{name}"') -def rectify_embed_sizes(param_name: str, tensors: List[torch.Tensor]): - # TODO: use arch_info.embed_weights() instead - if ("lm_head" in param_name or "embed_tokens" in param_name) and all( - len(t.shape) == 2 for t in tensors - ): - # special case - if lm_head.weight or embed_tokens.weight have a size - # mismatch, take the largest common submatrix of all of them - if take_common_submatrix(tensors): - logging.warning( - f"Using common submatrix of size {tensors[0].shape} for {param_name}" - ) - - -def take_common_submatrix(tensors: List[torch.Tensor]) -> bool: - min_size = [None, None] - for t in tensors: - for idx in range(2): - if min_size[idx] is None or t.shape[idx] < min_size[idx]: - min_size[idx] = t.shape[idx] - - if not all(t.shape == torch.Size(min_size) for t in tensors): - for idx in range(len(tensors)): - tensors[idx] = tensors[idx][: min_size[0], : min_size[1]] - return True - return False - - def parse_kmb(value: Union[str, int]) -> int: if isinstance(value, int): return value diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index b4000f1e..02c1277f 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -81,7 +81,7 @@ def make_task( int8_mask=parameters["int8_mask"], normalize=parameters["normalize"], rescale=parameters["rescale"], - out_tensor_name=output_weight.name, + weight_info=output_weight, ) @@ -89,7 +89,7 @@ class GTATask(Task[torch.Tensor]): method: GeneralizedTaskArithmeticMerge tensors: GatherTensors base_model: ModelReference - out_tensor_name: str + weight_info: WeightInfo tensor_parameters: ImmutableMap[ModelReference, Any] int8_mask: bool normalize: bool @@ -108,7 +108,7 @@ def execute( ) -> torch.Tensor: # collect task vectors tvs, base = get_task_vectors( - self.out_tensor_name, + self.weight_info, self.base_model, tensors, tensor_parameters=self.tensor_parameters.data, @@ -166,7 +166,7 @@ def group_label(self) -> Optional[str]: def get_task_vectors( - parameter_name: str, + weight_info: WeightInfo, base_model: ModelReference, tensors: ImmutableMap[ModelReference, torch.Tensor], tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], @@ -174,6 +174,8 @@ def get_task_vectors( keys = list(tensors.keys()) base = tensors[base_model] + parameter_name = weight_info.name + res = [] for model in keys: 
if model == base_model: @@ -181,7 +183,7 @@ def get_task_vectors( x = tensors[model].to(base.dtype) if x.shape != base.shape: - if "lm_head" in parameter_name or "embed_tokens" in parameter_name: + if weight_info.is_embed: x = x[: base.shape[0], : base.shape[1]] logging.warning(f"Using submatrix of {model}:{parameter_name}") else: diff --git a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py index 5e637588..81826a97 100644 --- a/mergekit/merge_methods/linear.py +++ b/mergekit/merge_methods/linear.py @@ -18,17 +18,18 @@ import torch from mergekit.architecture import WeightInfo -from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes +from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class LinearMergeTask(Task[torch.Tensor]): gather_tensors: GatherTensors tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] normalize: bool - parameter_name: str + weight_info: WeightInfo def uses_accelerator(self) -> bool: return True @@ -44,12 +45,12 @@ def execute( tensors = [tensors[key] for key in keys] weights = [self.tensor_parameters[key]["weight"] for key in keys] - rectify_embed_sizes(self.parameter_name, tensors) + rectify_embed_sizes(self.weight_info, tensors) unique_shapes = set(t.shape for t in tensors) if len(unique_shapes) != 1: raise RuntimeError( - f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}" + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" ) tensors = torch.stack(tensors, dim=0) @@ -89,5 +90,5 @@ def make_task( gather_tensors=tensors, tensor_parameters=tensor_parameters, normalize=parameters["normalize"], - parameter_name=output_weight.name, + weight_info=output_weight, ) diff --git a/mergekit/merge_methods/model_stock.py b/mergekit/merge_methods/model_stock.py index 4e55b347..5130f3ea 100644 --- a/mergekit/merge_methods/model_stock.py +++ b/mergekit/merge_methods/model_stock.py @@ -13,21 +13,23 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. 
+import logging from typing import Any, Dict, List, Optional import torch from mergekit.architecture import WeightInfo -from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes +from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class ModelStockMergeTask(Task[torch.Tensor]): gather_tensors: GatherTensors base_model: ModelReference - parameter_name: str + weight_info: WeightInfo filter_wise: bool = False def uses_accelerator(self) -> bool: @@ -40,6 +42,12 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: if len(tensors) == 1 and self.base_model in tensors: return tensors[self.base_model] if len(tensors) < 3: + if self.weight_info.optional: + logging.warning( + f"Optional weight {self.weight_info.name} not present in enough models, discarding" + ) + return None + raise ValueError( "ModelStockMerge requires at least 3 models (base plus two+ others)" ) @@ -93,7 +101,7 @@ def get_rectified_weights(self, tensors: Dict[ModelReference, torch.Tensor]): all_weights = [tensors[self.base_model]] + [ tensors[k] for k in tensors if k != self.base_model ] - rectify_embed_sizes(self.parameter_name, all_weights) + rectify_embed_sizes(self.weight_info, all_weights) w_0 = all_weights[0] ws = all_weights[1:] return w_0, ws @@ -120,6 +128,6 @@ def make_task( return ModelStockMergeTask( gather_tensors=tensors, base_model=base_model, - parameter_name=output_weight.name, + weight_info=output_weight, filter_wise=parameters["filter_wise"], ) diff --git a/mergekit/merge_methods/rectify_embed.py b/mergekit/merge_methods/rectify_embed.py new file mode 100644 index 00000000..0d116b4f --- /dev/null +++ b/mergekit/merge_methods/rectify_embed.py @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ + +import logging +from typing import List + +import torch + +from mergekit.architecture import WeightInfo + + +def rectify_embed_sizes(weight_info: WeightInfo, tensors: List[torch.Tensor]): + # TODO: use arch_info.embed_weights() instead + if weight_info.is_embed and all(len(t.shape) == 2 for t in tensors): + # special case - if lm_head.weight or embed_tokens.weight have a size + # mismatch, take the largest common submatrix of all of them + if take_common_submatrix(tensors): + logging.warning( + f"Using common submatrix of size {tensors[0].shape} for {weight_info.name}" + ) + + +def take_common_submatrix(tensors: List[torch.Tensor]) -> bool: + min_size = [None, None] + for t in tensors: + for idx in range(2): + if min_size[idx] is None or t.shape[idx] < min_size[idx]: + min_size[idx] = t.shape[idx] + + if not all(t.shape == torch.Size(min_size) for t in tensors): + for idx in range(len(tensors)): + tensors[idx] = tensors[idx][: min_size[0], : min_size[1]] + return True + return False diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index 81a189fb..dd89d09e 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -19,17 +19,18 @@ import torch from mergekit.architecture import WeightInfo -from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes +from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class SlerpTask(Task[torch.Tensor]): gather_tensors: GatherTensors base_model: ModelReference t: float - parameter_name: str + weight_info: WeightInfo def uses_accelerator(self) -> bool: return True @@ -50,7 +51,7 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: [a, b] = [b, a] prepped_tensors = [a[1], b[1]] - rectify_embed_sizes(self.parameter_name, prepped_tensors) + rectify_embed_sizes(self.weight_info, prepped_tensors) return ( slerp( @@ -82,7 +83,7 @@ def make_task( return SlerpTask( gather_tensors=tensors, base_model=base_model, - parameter_name=output_weight.name, + weight_info=output_weight, t=parameters["t"], )
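
Usage note (not part of the patch): the sketch below illustrates how the relocated rectify_embed_sizes helper is expected to behave now that merge methods receive the full WeightInfo instead of a bare parameter name, so embedding weights are detected via is_embed rather than by matching "lm_head"/"embed_tokens" substrings. The tensor shapes and the specific WeightInfo fields set here are illustrative assumptions, not values taken from the patch.

    import torch

    from mergekit.architecture import WeightInfo
    from mergekit.merge_methods.rectify_embed import rectify_embed_sizes

    # Hypothetical word-embedding matrices from two checkpoints whose vocab
    # sizes disagree (e.g. one model added extra tokens); hidden size matches.
    emb_a = torch.randn(30522, 768)
    emb_b = torch.randn(30528, 768)

    # is_embed=True marks this as an embedding weight, matching the flag set on
    # embeddings.word_embeddings.weight in bert.json.
    info = WeightInfo(name="embeddings.word_embeddings.weight", is_embed=True)

    tensors = [emb_a, emb_b]
    rectify_embed_sizes(info, tensors)  # truncates the list entries in place

    # Both tensors now share the largest common submatrix, (30522, 768), so
    # downstream merge methods (linear, slerp, model_stock) can stack them.
    print([tuple(t.shape) for t in tensors])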