Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Ar57m authored May 7, 2024
1 parent 09a4fb9 commit cbbd2cc
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
ConfigParameterDef(name="random_mask", required=False, default_value= 0.0),
ConfigParameterDef(name="random_mask_seed", required=False, default_value= None),
]
if self.sparsification_method == SparsificationMethod.magnitude_outliers:
res.append(
ConfigParameterDef(
name="gamma",
default_value=0.01,
)
)
return res

def make_task(
Expand All @@ -83,15 +90,15 @@ def make_task(
normalize=parameters["normalize"],
rescale=parameters["rescale"],
swapping=parameters["swapping"],
out_tensor_name=output_weight.name,
weight_info=output_weight,
)


class GTATask(Task[torch.Tensor]):
method: GeneralizedTaskArithmeticMerge
tensors: GatherTensors
base_model: ModelReference
out_tensor_name: str
weight_info: WeightInfo
tensor_parameters: ImmutableMap[ModelReference, Any]
int8_mask: bool
normalize: bool
Expand All @@ -111,7 +118,7 @@ def execute(
) -> torch.Tensor:
# collect task vectors
tvs, base = get_task_vectors(
self.out_tensor_name,
self.weight_info,
self.base_model,
tensors,
tensor_parameters=self.tensor_parameters.data,
Expand All @@ -123,11 +130,15 @@ def execute(
# sparsify
if self.method.sparsification_method:
for tv_info in tvs:
kwargs = {}
if "gamma" in tv_info:
kwargs["gamma"] = tv_info["gamma"]
tv_info["delta"] = sparsify(
tv_info["delta"],
density=tv_info["density"],
method=self.method.sparsification_method,
rescale=self.rescale,
**kwargs,
)

deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
Expand Down Expand Up @@ -218,14 +229,15 @@ def rand_mask(base, x, percent, seed=None):


def get_task_vectors(
parameter_name: str,
weight_info: WeightInfo,
base_model: ModelReference,
tensors: ImmutableMap[ModelReference, torch.Tensor],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
swapping: bool,
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
keys = list(tensors.keys())
base = tensors[base_model]
parameter_name = weight_info.name

res = []
for model in keys:
Expand All @@ -235,7 +247,7 @@ def get_task_vectors(
x = tensors[model].to(base.dtype)

if x.shape != base.shape:
if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
if weight_info.is_embed:
x = x[: base.shape[0], : base.shape[1]]
logging.warning(f"Using submatrix of {model}:{parameter_name}")
else:
Expand Down

0 comments on commit cbbd2cc

Please sign in to comment.