refactor get_architecture_info into ArchitectureInfoUtils.get_architecture_info. Simplify implementations of get_architecture_info and infer_architecture_info
ElliotStein committed Dec 9, 2024
1 parent 9cfa073 commit 97985c4
Showing 10 changed files with 54 additions and 78 deletions.
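
Note on the caller-side change: the module-level get_architecture_info function is removed, and its replacement is a static method on ArchitectureInfoUtils that returns Optional[ArchitectureInfo] (None for unsupported architectures) rather than False. A minimal caller sketch under those assumptions; the model id and the fallback shown in the comment are illustrative, not part of this commit:

from transformers import AutoConfig

from mergekit.architecture import ArchitectureInfoUtils

config = AutoConfig.from_pretrained("example/model")  # illustrative model id
arch_info = ArchitectureInfoUtils.get_architecture_info(config)
if arch_info is None:
    # Unsupported architecture: callers such as mergekit/merge.py now fall back
    # to ArchitectureInfoUtils.infer_architecture_info(merge_config) instead.
    ...
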
85 changes: 28 additions & 57 deletions mergekit/architecture.py
@@ -24,7 +24,6 @@
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from pydantic import BaseModel, Field
from transformers import PretrainedConfig
from typing_extensions import Literal
@@ -473,40 +472,12 @@ def _load_all_architectures() -> (
QWEN2_INFO = _load_json_arch("qwen2.json")


def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
if len(config.architectures) != 1:
raise RuntimeError("More than one architecture in config?")

arch_name = config.architectures[0]

if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
return MixtralTensorNames.from_config(config)

if arch_name not in NAME_TO_ARCH:
warnings.warn(
f"Unsupported architecture {arch_name}, attempting automatic architecture generation"
)
return False

candidates = list(NAME_TO_ARCH[arch_name])
if len(candidates) == 1:
return candidates[0]

for c in candidates:
if c.definition.expected_model_type == config.model_type:
return c

warnings.warn(
f"Unsupported model_type {config.model_type} for architecture {arch_name}"
)
return False


class ArchitectureInfoUtils:
"""Functions for inferring architecture information from a merge configuration."""

@staticmethod
def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
def get_architecture_info(config: PretrainedConfig) -> Optional[ArchitectureInfo]:
"""Get architecture info from an existing model config."""
if len(config.architectures) != 1:
raise RuntimeError("More than one architecture in config?")

@@ -515,28 +486,30 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
return MixtralTensorNames.from_config(config)

if arch_name not in NAME_TO_ARCH:
warnings.warn(
f"Unsupported architecture {arch_name}, attempting automatic architecture generation"
)
return False

candidates = list(NAME_TO_ARCH[arch_name])
if len(candidates) == 1:
return candidates[0]
if arch_name in NAME_TO_ARCH:
candidates = list(NAME_TO_ARCH[arch_name])
if len(candidates) == 1:
return candidates[0]

for c in candidates:
if c.definition.expected_model_type == config.model_type:
return c
for c in candidates:
if c.definition.expected_model_type == config.model_type:
return c

warnings.warn(
f"Unsupported model_type {config.model_type} for architecture {arch_name}"
)
return False
warnings.warn(f"No architecture config available for: {arch_name}.")
return None

@staticmethod
def infer_architecture_info(merge_config):
"""Infer architecture info and prefixes for alignment."""
def infer_architecture_info(merge_config) -> AutomaticArchitectureInfo:
"""
Infer architecture info and prefixes for alignment.
Prefixes typically denote where a model is used as a subcomponent of another model.
        e.g., [layer.0, layer.1, ...] and ['vision_tower.layer.0', 'vision_tower.layer.1', ...];
        inferring prefix = 'vision_tower' is required to align the two models.
Usage:
Similar to `get_architecture_info`, but requires a merge configuration object rather than a model config.
This is so the common parameter names between all models can be inferred.
"""
param_names = [
ParameterNamesUtils.get_model_parameter_names(source_model.model.path)
for source_model in merge_config.referenced_models()
Expand All @@ -550,7 +523,7 @@ def infer_architecture_info(merge_config):
paired_list.insert(0, paired_list.pop(i))
break
param_names, referenced_models = zip(*paired_list)
print(f"Base model selected: {referenced_models[0].model.path}")
logging.info(f"Base model selected: {referenced_models[0].model.path}")

prefixes = [""]
for i in range(1, len(param_names)):
@@ -578,13 +551,11 @@ def infer_architecture_info(merge_config):
arch_name = referenced_models[0].model.path
parameter_names = common_names

return [
AutomaticArchitectureInfo(
arch_name=arch_name,
parameter_names=parameter_names,
prefix_tracker=prefix_tracker,
)
]
return AutomaticArchitectureInfo(
arch_name=arch_name,
parameter_names=parameter_names,
prefix_tracker=prefix_tracker,
)

@staticmethod
def log_info(common_names, param_names, referenced_models):
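
The docstring added to infer_architecture_info above describes prefix detection: when one model embeds another as a subcomponent, its parameter names carry an extra dotted prefix (e.g. 'vision_tower.') that has to be identified before the two name lists can be aligned. A minimal, self-contained sketch of that idea; find_prefix is a hypothetical helper, not the code introduced in this commit:

from typing import List, Optional


def find_prefix(base_names: List[str], other_names: List[str]) -> Optional[str]:
    # Find a dotted prefix that, once stripped from other_names, makes them
    # match base_names; "" if they already align, None if no prefix works.
    base = set(base_names)
    if set(other_names) <= base:
        return ""
    candidates = {name.split(".", 1)[0] for name in other_names if "." in name}
    for prefix in sorted(candidates):
        stripped = {n[len(prefix) + 1:] for n in other_names if n.startswith(prefix + ".")}
        if stripped and stripped <= base:
            return prefix
    return None


# e.g. find_prefix(["layer.0.weight"], ["vision_tower.layer.0.weight"]) == "vision_tower"
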
6 changes: 4 additions & 2 deletions mergekit/evo/actors.py
@@ -35,7 +35,7 @@
vllm = None


from mergekit.architecture import ConfiguredArchitectureInfo, get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils, ConfiguredArchitectureInfo
from mergekit.config import MergeConfiguration
from mergekit.evo.config import EvolMergeConfiguration
from mergekit.evo.genome import InvalidGenotypeError, ModelGenome
@@ -144,7 +144,9 @@ def __init__(
super().__init__(*args, vllm=vllm, **kwargs)

def _maybe_init_model(self, config: MergeConfiguration):
ai = get_architecture_info(self.genome._input_config_example)
ai = ArchitectureInfoUtils.get_architecture_info(
self.genome._input_config_example
)
cfg_out = _model_out_config(
config,
ai,
7 changes: 5 additions & 2 deletions mergekit/merge.py
@@ -18,6 +18,7 @@
import logging
import os
import shutil
import warnings
from collections import Counter
from pathlib import Path
from typing import List, Optional
@@ -291,17 +292,19 @@ def _load_arch_info(
for m in merge_config.referenced_models()
]

if not any(a is False for a in model_arch_info):
if all(a is not None for a in model_arch_info):
if not options.allow_crimes and not all(
a == model_arch_info[0] for a in model_arch_info[1:]
):
raise RuntimeError(
"Must specify --allow-crimes to attempt to mix different architectures"
)
return model_arch_info[0]
else:
warnings.warn("Attempting Automatic Merge.")
model_arch_info = ArchitectureInfoUtils.infer_architecture_info(merge_config)

return model_arch_info[0]
return model_arch_info


__all__ = ["MergeOptions", "run_merge"]
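
Read end to end, the updated _load_arch_info now treats None as the "unsupported architecture" signal and falls back to automatic inference. A sketch assembled from the visible hunks above; the body of the list comprehension and the option field names are assumptions, since those lines are elided in this view:

import warnings

from mergekit.architecture import ArchitectureInfoUtils


def _load_arch_info(merge_config, options):
    model_arch_info = [
        ArchitectureInfoUtils.get_architecture_info(
            m.config(trust_remote_code=options.trust_remote_code)  # assumed call
        )
        for m in merge_config.referenced_models()
    ]

    if all(a is not None for a in model_arch_info):
        if not options.allow_crimes and not all(
            a == model_arch_info[0] for a in model_arch_info[1:]
        ):
            raise RuntimeError(
                "Must specify --allow-crimes to attempt to mix different architectures"
            )
        return model_arch_info[0]
    else:
        warnings.warn("Attempting Automatic Merge.")
        model_arch_info = ArchitectureInfoUtils.infer_architecture_info(merge_config)

    return model_arch_info
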
4 changes: 2 additions & 2 deletions mergekit/moe/deepseek.py
@@ -22,7 +22,7 @@
import tqdm
import transformers

from mergekit.architecture import get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
@@ -138,7 +138,7 @@ def write_model(
loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
shared_loader = loaders.get(shared_def.source_model) if shared_def else None
for weight_info in tqdm.tqdm(
get_architecture_info(base_cfg).all_weights(base_cfg),
ArchitectureInfoUtils.get_architecture_info(base_cfg).all_weights(base_cfg),
desc="Weights",
):
tensor_name = weight_info.name
4 changes: 2 additions & 2 deletions mergekit/scripts/ABM/activations_based_merge.py
@@ -8,7 +8,7 @@
import tqdm
from transformers import AutoTokenizer

from mergekit.architecture import get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io.tasks import LoaderCache
from mergekit.io.tensor_writer import TensorWriter
@@ -62,7 +62,7 @@ def main(
)

model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
model_arch_info = get_architecture_info(
model_arch_info = ArchitectureInfoUtils.get_architecture_info(
model.config(trust_remote_code=merge_options.trust_remote_code)
)

4 changes: 2 additions & 2 deletions mergekit/scripts/ABM/extract_activations.py
@@ -11,7 +11,7 @@
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, DefaultDataCollator

from mergekit.architecture import _template_substitution, get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils, _template_substitution
from mergekit.common import ModelReference

logging.basicConfig(level=logging.INFO)
@@ -130,7 +130,7 @@ def main(
model = ModelReference.model_validate(model_path)

model_config = model.config()
model_arch_info = get_architecture_info(model_config)
model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)

_json = model_arch_info.definition

4 changes: 2 additions & 2 deletions mergekit/scripts/ABM/extract_permutation_matrices.py
@@ -8,7 +8,7 @@
import scipy
import torch

from mergekit.architecture import _template_substitution, get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils, _template_substitution
from mergekit.common import ModelReference


@@ -147,7 +147,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):

model_config = model.config()

model_arch_info = get_architecture_info(model_config)
model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)

_json = model_arch_info.definition

4 changes: 2 additions & 2 deletions mergekit/scripts/layershuffle.py
@@ -19,7 +19,7 @@
import click
import yaml

from mergekit.architecture import get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils
from mergekit.common import ModelReference
from mergekit.config import (
InputSliceDefinition,
@@ -76,7 +76,7 @@ def main(
models = [ModelReference.parse(m) for m in model]

m0_cfg = models[0].config()
arch_info = get_architecture_info(m0_cfg)
arch_info = ArchitectureInfoUtils.get_architecture_info(m0_cfg)
total_num_layers = arch_info.num_layers(m0_cfg)

out_slices: List[OutputSliceDefinition] = []
8 changes: 4 additions & 4 deletions mergekit/scripts/tokensurgeon.py
@@ -25,9 +25,9 @@
from typing_extensions import TypeAlias

from mergekit.architecture import (
ArchitectureInfoUtils,
ConfiguredArchitectureInfo,
WeightInfo,
get_architecture_info,
)
from mergekit.common import ModelReference
from mergekit.io import TensorWriter
@@ -281,7 +281,7 @@ def get_embedding_info(
) -> Tuple[WeightInfo, WeightInfo]:
"""Get WeightInfo for the input and output embeddings of a model."""
cfg = model.config(trust_remote_code=options.trust_remote_code)
arch_info = get_architecture_info(cfg)
arch_info = ArchitectureInfoUtils.get_architecture_info(cfg)

embed, lm_head = None, None
for weight_info in arch_info.pre_weights(cfg):
@@ -596,8 +596,8 @@ def validate_architecture(
"""
model_cfg = model.config(trust_remote_code=options.trust_remote_code)
donor_cfg = donor.config(trust_remote_code=options.trust_remote_code)
model_arch_info = get_architecture_info(model_cfg)
donor_arch_info = get_architecture_info(donor_cfg)
model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg)
donor_arch_info = ArchitectureInfoUtils.get_architecture_info(donor_cfg)
if donor_arch_info != model_arch_info:
report_issue(
f"Model architectures do not match: {model_arch_info.name()} vs {donor_arch_info.name()}",
6 changes: 3 additions & 3 deletions tests/common.py
@@ -11,7 +11,7 @@
LlavaForConditionalGeneration,
)

from mergekit.architecture import ArchitectureInfoUtils, get_architecture_info
from mergekit.architecture import ArchitectureInfoUtils
from mergekit.config import MergeConfiguration
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
from mergekit.merge import MergeOptions, run_merge
@@ -50,9 +50,9 @@ def run_and_check_merge(
if check_tensors:
model_config = AutoConfig.from_pretrained(tmpdir)
if auto_arch:
arch_info = ArchitectureInfoUtils.infer_architecture_info(config)[0]
arch_info = ArchitectureInfoUtils.infer_architecture_info(config)
else:
arch_info = get_architecture_info(model_config)
arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)

index = ShardedTensorIndex.from_disk(tmpdir)
for weight_info in arch_info.all_weights(model_config):
