diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py
index 1ab00f3f..e2e8b14c 100644
--- a/mergekit/merge_methods/generalized_task_arithmetic.py
+++ b/mergekit/merge_methods/generalized_task_arithmetic.py
@@ -64,6 +64,13 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
             ConfigParameterDef(name="random_mask", required=False, default_value=0.0),
             ConfigParameterDef(name="random_mask_seed", required=False, default_value=None),
         ]
+        if self.sparsification_method == SparsificationMethod.magnitude_outliers:
+            res.append(
+                ConfigParameterDef(
+                    name="gamma",
+                    default_value=0.01,
+                )
+            )
         return res
 
     def make_task(
@@ -83,7 +90,7 @@ def make_task(
             normalize=parameters["normalize"],
             rescale=parameters["rescale"],
             swapping=parameters["swapping"],
-            out_tensor_name=output_weight.name,
+            weight_info=output_weight,
         )
 
 
@@ -91,7 +98,7 @@ class GTATask(Task[torch.Tensor]):
     method: GeneralizedTaskArithmeticMerge
     tensors: GatherTensors
    base_model: ModelReference
-    out_tensor_name: str
+    weight_info: WeightInfo
     tensor_parameters: ImmutableMap[ModelReference, Any]
     int8_mask: bool
     normalize: bool
@@ -111,7 +118,7 @@ def execute(
     ) -> torch.Tensor:
         # collect task vectors
         tvs, base = get_task_vectors(
-            self.out_tensor_name,
+            self.weight_info,
             self.base_model,
             tensors,
             tensor_parameters=self.tensor_parameters.data,
@@ -123,11 +130,15 @@ def execute(
         # sparsify
         if self.method.sparsification_method:
             for tv_info in tvs:
+                kwargs = {}
+                if "gamma" in tv_info:
+                    kwargs["gamma"] = tv_info["gamma"]
                 tv_info["delta"] = sparsify(
                     tv_info["delta"],
                     density=tv_info["density"],
                     method=self.method.sparsification_method,
                     rescale=self.rescale,
+                    **kwargs,
                 )
 
         deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
@@ -218,7 +229,7 @@ def rand_mask(base, x, percent, seed=None):
 
 
 def get_task_vectors(
-    parameter_name: str,
+    weight_info: WeightInfo,
     base_model: ModelReference,
     tensors: ImmutableMap[ModelReference, torch.Tensor],
     tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
@@ -226,6 +237,7 @@ def get_task_vectors(
 ) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
     keys = list(tensors.keys())
     base = tensors[base_model]
+    parameter_name = weight_info.name
 
     res = []
     for model in keys:
@@ -235,7 +247,7 @@
 
         x = tensors[model].to(base.dtype)
         if x.shape != base.shape:
-            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
+            if weight_info.is_embed:
                 x = x[: base.shape[0], : base.shape[1]]
                 logging.warning(f"Using submatrix of {model}:{parameter_name}")
             else: