Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a new method to shuffle/swap values #167

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
28 changes: 28 additions & 0 deletions mergekit/merge_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,55 @@ def get(method: str) -> MergeMethod:
sparsification_method=None,
default_normalize=False,
default_rescale=False,
default_swapping=False,
)
elif method == "ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
default_swapping=False,
)
elif method == "dare_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_swapping=False,
)
elif method == "dare_linear":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_swapping=False,
)
elif method == "task_swapping":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=None,
default_normalize=False,
default_rescale=False,
default_swapping=True,
)
elif method == "task_swapping_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
default_swapping=True,
)
elif method == "task_swapping_dare_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.rescaled_random,
default_normalize=False,
default_rescale=True,
default_swapping=True,
)
elif method == "breadcrumbs":
return GeneralizedTaskArithmeticMerge(
Expand Down
73 changes: 72 additions & 1 deletion mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
sparsification_method: Optional[SparsificationMethod]
default_normalize: bool
default_rescale: bool
default_swapping: bool

def parameters(self) -> List[ConfigParameterDef]:
return [
Expand All @@ -52,12 +53,19 @@ def parameters(self) -> List[ConfigParameterDef]:
ConfigParameterDef(
name="rescale", required=False, default_value=self.default_rescale
),
ConfigParameterDef(
name="swapping", required=False, default_value=self.default_swapping
),
]

def tensor_parameters(self) -> List[ConfigParameterDef]:
res = [
ConfigParameterDef(name="weight", required=True),
ConfigParameterDef(name="density", required=False, default_value=1.0),
ConfigParameterDef(name="diagonal_offset", required=False),
ConfigParameterDef(name="invert_offset", required=False, default_value= False),
ConfigParameterDef(name="random_mask", required=False, default_value= 0.0),
ConfigParameterDef(name="random_mask_seed", required=False, default_value= None),
]
if self.sparsification_method == SparsificationMethod.magnitude_outliers:
res.append(
Expand Down Expand Up @@ -97,6 +105,7 @@ def make_task(
int8_mask=parameters["int8_mask"],
normalize=parameters["normalize"],
rescale=parameters["rescale"],
swapping=parameters["swapping"],
weight_info=output_weight,
)

Expand All @@ -110,6 +119,7 @@ class GTATask(Task[torch.Tensor]):
int8_mask: bool
normalize: bool
rescale: bool
swapping: bool

def uses_accelerator(self) -> bool:
return True
Expand All @@ -128,6 +138,7 @@ def execute(
self.base_model,
tensors,
tensor_parameters=self.tensor_parameters.data,
swapping=self.swapping,
)
if not tvs:
return base
Expand All @@ -139,9 +150,11 @@ def execute(
if "gamma" in tv_info:
kwargs["gamma"] = tv_info["gamma"]


if "epsilon" in tv_info:
kwargs["epsilon"] = tv_info["epsilon"]


tv_info["delta"] = sparsify(
tv_info["delta"],
density=tv_info["density"],
Expand Down Expand Up @@ -191,15 +204,68 @@ def group_label(self) -> Optional[str]:
return self.tensors.group_label()


def swapping_method(base, x, parameters):
    """Blend elements of ``x`` into ``base`` according to a mask.

    Two masking modes are supported:

    * **Striped (diagonal) mode** — keep ``x`` where ``(row + col) % n == 0``
      (or ``index % n == 0`` for 1-D tensors) with ``n = diagonal_offset``,
      and ``base`` elsewhere; ``invert_offset`` swaps the two roles.
    * **Random mode** — keep a ``random_mask`` fraction of elements from
      ``x`` chosen uniformly at random, optionally reproducibly via
      ``random_mask_seed``.

    Parameters
    ----------
    base : torch.Tensor
        Base-model tensor; supplies values where the mask is off.
    x : torch.Tensor
        Task-model tensor; supplies values where the mask is on.
        Assumed to have the same shape as ``base`` — TODO confirm against
        the caller (get_task_vectors skips/crops mismatched shapes).
    parameters : dict
        Per-tensor options: ``diagonal_offset`` (int >= 2, required only
        for the striped path), ``invert_offset`` (bool), ``random_mask``
        (float in (0, 1); 0.0 disables the random path) and
        ``random_mask_seed`` (optional int).

    Returns
    -------
    torch.Tensor with the same dtype as ``base``.
    """

    def _striped_swap(shape, n, off_tensor, on_tensor):
        # Mask is True on every n-th stripe; take `on_tensor` there and
        # `off_tensor` everywhere else.
        if on_tensor.dim() == 2:
            rows, cols = shape
            row_idx = torch.arange(rows).view(-1, 1)
            col_idx = torch.arange(cols).view(1, -1)
            mask = ((row_idx + col_idx) % n == 0).to(on_tensor.device)
        else:
            mask = (torch.arange(shape[0]) % n == 0).to(on_tensor.device)
        return torch.where(mask, on_tensor, off_tensor)

    def _random_swap(off_tensor, on_tensor, fraction, seed=None):
        # Take `fraction` of elements at random from `on_tensor`.
        # When a seed is given, save and restore the global RNG *state*
        # (not torch.seed(), which RE-seeds the generator and returns the
        # new seed) so the merge does not perturb other random draws.
        saved_state = None
        if seed is not None:
            saved_state = torch.get_rng_state()
            torch.manual_seed(seed)
        try:
            mask = (torch.rand(off_tensor.shape) <= fraction).to(off_tensor.device)
        finally:
            if saved_state is not None:
                torch.set_rng_state(saved_state)
        return torch.where(mask, on_tensor, off_tensor)

    original_dtype = base.dtype
    # Work in float32 on CPU, where low-precision ops are slow or unsupported.
    if x.device.type == "cpu":
        x = x.to(torch.float32)
        base = base.to(torch.float32)

    diagonal_offset = parameters.get("diagonal_offset")
    random_mask = parameters.get("random_mask") or 0.0
    random_mask_seed = parameters.get("random_mask_seed")
    if random_mask_seed is not None:
        random_mask_seed = int(random_mask_seed)

    if random_mask != 0.0:
        assert (
            0.0 < random_mask < 1.0
        ), "The random_mask parameter can't be empty, 0, 1, or None, it must be a number between 0 and 1."
        x = _random_swap(base, x, random_mask, random_mask_seed)
    else:
        # diagonal_offset is only needed on the striped path, so it is
        # validated here instead of unconditionally (a random-mask-only
        # config no longer has to supply it).
        assert (
            (diagonal_offset is not None)
            and (diagonal_offset % 1 == 0)
            and (diagonal_offset >= 2)
        ), "The diagonal_offset must be an integer greater than or equal to 2."
        if not parameters.get("invert_offset"):
            x = _striped_swap(x.shape, diagonal_offset, base, x)
        else:
            x = _striped_swap(x.shape, diagonal_offset, x, base)

    return x.to(original_dtype)


def get_task_vectors(
weight_info: WeightInfo,
base_model: ModelReference,
tensors: ImmutableMap[ModelReference, torch.Tensor],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
swapping: bool,
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
keys = list(tensors.keys())
base = tensors[base_model]

parameter_name = weight_info.name

res = []
Expand All @@ -208,6 +274,7 @@ def get_task_vectors(
continue

x = tensors[model].to(base.dtype)

if x.shape != base.shape:
if weight_info.is_embed:
x = x[: base.shape[0], : base.shape[1]]
Expand All @@ -218,6 +285,10 @@ def get_task_vectors(
)
continue

if swapping:
x = swapping_method(base, x, dict(tensor_parameters[model].items()))


delta = x - base
del x
del tensors[model]
Expand Down
Loading