Skip to content

Commit

Permalink
Address PR feedback: resolve merge architecture error
Browse files Browse the repository at this point in the history
Updated the infer_architecture logic to handle cases where architectures appear mismatched
  • Loading branch information
ElliotStein committed Dec 2, 2024
1 parent f081a0b commit 84260f0
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

import importlib.resources
import logging
import re
import string
import warnings
Expand Down Expand Up @@ -499,7 +500,7 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:

def strip_prefix(name: str, prefixes: List[str]) -> str:
"""Remove any prefix in prefixes from the start of the name."""
prefixes = [prefixes] if isinstance(prefixes, str) else prefixes
prefixes = [prefixes] if not isinstance(prefixes, list) else prefixes
for prefix in prefixes:
if name.startswith(prefix + "."):
return name[len(prefix) + 1 :]
Expand All @@ -510,28 +511,29 @@ def is_ordered_sublist_with_prefix(
list1: List[str], list2: List[str], prefixes: List[str]
) -> bool:
"""
Check if list1 matches a subset of list2 in the correct order after optional prefix removal.
Check if list2 is a sublist of list1, after optional prefix removal from list1.
"""
stripped_list2 = [strip_prefix(name, prefixes) for name in list2]
stripped_list1 = [strip_prefix(name, prefixes) for name in list1]

try:
start_index = stripped_list2.index(list1[0])
for i, item in enumerate(list1):
if stripped_list2[start_index + i] != item:
return False
start_index = stripped_list1.index(list2[0])
stripped_list1 = stripped_list1[start_index : start_index + len(list2)]
if stripped_list1 != list2:
return False
return True
except (ValueError, IndexError):
except (ValueError, IndexError) as e:
logging.error(f"Failed to find common parameter names between models: {e}")
return False


def find_prefix_and_check_sublist(list1: List[str], list2: List[str]) -> Optional[str]:
"""
Attempts to find a prefix from elements in list2 that makes list1 an ordered sublist of list2.
Attempts to find a prefix from elements in list1 that makes list2 an ordered sublist of list1.
"""
if len(list1) > len(list2):
list1, list2 = list2, list1
assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"

possible_prefixes = {item.split(".")[0] for item in list2 if "." in item}
possible_prefixes = [""] + list(possible_prefixes)

for prefix in possible_prefixes:
if is_ordered_sublist_with_prefix(list1, list2, [prefix]):
Expand All @@ -544,9 +546,12 @@ def find_prefixes_for_alignment(param_names: List[List[str]]) -> List[str]:
"""Determine prefixes needed to align parameter names in order of the longest list."""
prefixes = [""]
for i in range(1, len(param_names)):
assert len(param_names[0]) >= len(
param_names[i]
), "params name list1 can't be shorter than list2"
if param_names[0] != param_names[i]:
prefix = find_prefix_and_check_sublist(param_names[0], param_names[i])
if not prefix:
if prefix is None:
raise ValueError("Could not resolve model architecture automatically.")
else:
prefix = ""
Expand Down

0 comments on commit 84260f0

Please sign in to comment.