Commit

update tests
Samoed committed Dec 13, 2024
1 parent d643976 commit fb2a4c0
Showing 2 changed files with 66 additions and 32 deletions.
33 changes: 13 additions & 20 deletions mteb/evaluation/MTEB.py
@@ -462,11 +462,8 @@ def run(
                 task_eval_splits = (
                     eval_splits if eval_splits is not None else task.eval_splits
                 )
-                task_eval_langs = (
-                    task.metadata.eval_langs
-                    if isinstance(task.metadata.eval_langs, dict)
-                    else None
-                )
+                task_eval_langs = task.metadata.hf_subsets_to_langscripts

                 existing_results = None
                 save_path = None
@@ -680,7 +677,7 @@ def get_last_evaluated_splits(self):
 def _get_missing_evaluations(
     existing_results: TaskResult | None,
     task_eval_splits: list[str],
-    task_eval_langs: dict[str, list[str]] | None,
+    task_eval_langs: dict[str, list[str]],
     eval_langs: list[str] | None,
 ) -> dict[str, dict[str, Any]]:
     """Return a dictionary for each split, indicating if the whole split is missing and which subsets are missing."""
@@ -690,21 +687,17 @@
     }

     # Determine subsets to consider if multilingual
-    if task_eval_langs is not None:
-        if eval_langs is None:
-            # If no eval_langs specified, consider all subsets
-            subsets_to_consider = list(task_eval_langs.keys())
-        else:
-            subsets_to_consider = []
-            for hf_subset, lang_list in task_eval_langs.items():
-                # lang_list are like ["eng-Latn", "deu-Latn"]
-                iso_langs = [l.split("-")[0] for l in lang_list]
-                # If any requested language is present in this subset, consider it
-                if any(run_lang in iso_langs for run_lang in eval_langs):
-                    subsets_to_consider.append(hf_subset)
+    if eval_langs is None:
+        # If no eval_langs specified, consider all subsets
+        subsets_to_consider = list(task_eval_langs.keys())
     else:
-        # Not multilingual, just one "default" subset
-        subsets_to_consider = ["default"]
+        subsets_to_consider = []
+        for hf_subset, lang_list in task_eval_langs.items():
+            # lang_list are like ["eng-Latn", "deu-Latn"]
+            iso_langs = [l.split("-")[0] for l in lang_list]
+            # If any requested language is present in this subset, consider it
+            if any(run_lang in iso_langs for run_lang in eval_langs):
+                subsets_to_consider.append(hf_subset)

     # If no existing results, all splits and subsets are missing
     if existing_results is None:
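For readers skimming the diff: together, the two hunks above drop the `None` branch because `task.metadata.hf_subsets_to_langscripts` is expected to yield a subset-to-languages mapping for every task (presumably a single "default" subset for monolingual tasks). Below is a minimal, hedged sketch of the surviving filtering logic as standalone Python; the mapping is invented for illustration and is not taken from any real task.

# Standalone sketch of the subset-selection step kept by this commit.
# The mapping is hypothetical; MTEB reads it from
# task.metadata.hf_subsets_to_langscripts.
task_eval_langs = {
    "en": ["eng-Latn"],
    "fr": ["fra-Latn"],
    "en-fr": ["eng-Latn", "fra-Latn"],
}

def subsets_to_consider(eval_langs: list[str] | None) -> list[str]:
    if eval_langs is None:
        # No languages requested: every subset is in scope.
        return list(task_eval_langs.keys())
    subsets = []
    for hf_subset, lang_list in task_eval_langs.items():
        # Compare on the ISO part only: "eng-Latn" -> "eng".
        iso_langs = [lang.split("-")[0] for lang in lang_list]
        if any(run_lang in iso_langs for run_lang in eval_langs):
            subsets.append(hf_subset)
    return subsets

print(subsets_to_consider(None))     # ['en', 'fr', 'en-fr']
print(subsets_to_consider(["eng"]))  # ['en', 'en-fr']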
65 changes: 53 additions & 12 deletions tests/test_evaluation/test_split_evaluation.py
@@ -40,6 +40,7 @@ def test_all_splits_evaluated(model, tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val", "test"}
     assert len(last_evaluated_splits["MockRetrievalTask"]) == 2
+    assert results[0].scores.keys() == {"val", "test"}


 def test_one_missing_split(model, tasks, tmp_path):
@@ -55,6 +56,7 @@ def test_one_missing_split(model, tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val"}
     assert len(last_evaluated_splits["MockRetrievalTask"]) == 1
+    assert results[0].scores.keys() == {"val"}

     results2 = evaluation.run(
         model,
@@ -68,11 +70,12 @@ def test_one_missing_split(model, tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockRetrievalTask"]) == {"test"}
     assert len(last_evaluated_splits["MockRetrievalTask"]) == 1
+    assert results2[0].scores.keys() == {"test", "val"}


 def test_no_missing_splits(model, tasks, tmp_path):
     evaluation = MTEB(tasks=tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val", "test"],
         output_folder=str(tmp_path / "testcase3"),
@@ -82,9 +85,10 @@ def test_no_missing_splits(model, tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert "MockRetrievalTask" in last_evaluated_splits
     assert len(last_evaluated_splits["MockRetrievalTask"]) == 2
+    assert results[0].scores.keys() == {"test", "val"}

     evaluation = MTEB(tasks=tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val", "test"],
         output_folder=str(tmp_path / "testcase3"),
@@ -95,6 +99,7 @@ def test_no_missing_splits(model, tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert "MockRetrievalTask" in last_evaluated_splits
     assert len(last_evaluated_splits["MockRetrievalTask"]) == 0
+    assert results[0].scores.keys() == {"test", "val"}


 def test_all_languages_evaluated(model, multilingual_tasks, tmp_path):
@@ -111,6 +116,8 @@ def test_all_languages_evaluated(model, multilingual_tasks, tmp_path):
     assert "MockMultilingualRetrievalTask" in last_evaluated_splits
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
     assert last_evaluated_splits["MockMultilingualRetrievalTask"] == ["test"]
+    assert results[0].scores.keys() == {"test"}
+    assert len(results[0].scores["test"]) == 2


 def test_missing_language(model, multilingual_tasks, tmp_path):
@@ -127,8 +134,10 @@ def test_missing_language(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
     assert last_evaluated_splits["MockMultilingualRetrievalTask"] == ["test"]
+    assert results[0].scores.keys() == {"test"}
+    assert results[0].languages == ["eng"]

-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "missing_lang_test"),
@@ -140,11 +149,14 @@ def test_missing_language(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
     assert last_evaluated_splits["MockMultilingualRetrievalTask"] == ["test"]
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"test"}
+    assert len(results[0].scores["test"]) == 2


 def test_no_missing_languages(model, multilingual_tasks, tmp_path):
     evaluation = MTEB(tasks=multilingual_tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "no_missing_lang_test"),
@@ -154,9 +166,12 @@ def test_no_missing_languages(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert "MockMultilingualRetrievalTask" in last_evaluated_splits
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
+    assert results[0].scores.keys() == {"test"}
+    assert len(results[0].scores["test"]) == 2
+    assert sorted(results[0].languages) == ["eng", "fra"]

     evaluation = MTEB(tasks=multilingual_tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "no_missing_lang_test"),
@@ -167,11 +182,14 @@ def test_no_missing_languages(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert "MockMultilingualRetrievalTask" in last_evaluated_splits
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 0
+    assert results[0].scores.keys() == {"test"}
+    assert len(results[0].scores["test"]) == 2
+    assert sorted(results[0].languages) == ["eng", "fra"]


 def test_partial_languages(model, multilingual_tasks, tmp_path):
     evaluation = MTEB(tasks=multilingual_tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "partial_lang_test"),
@@ -181,8 +199,10 @@ def test_partial_languages(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
     assert last_evaluated_splits["MockMultilingualRetrievalTask"] == ["test"]
+    assert results[0].languages == ["fra"]
+    assert results[0].scores.keys() == {"test"}

-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "partial_lang_test"),
@@ -193,13 +213,16 @@ def test_partial_languages(model, multilingual_tasks, tmp_path):
     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert len(last_evaluated_splits["MockMultilingualRetrievalTask"]) == 1
     assert last_evaluated_splits["MockMultilingualRetrievalTask"] == ["test"]
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"test"}
+    assert len(results[0].scores["test"]) == 2


 def test_multilingual_multiple_splits_partial_langs_partial_splits(
     model, multilingual_tasks, tmp_path
 ):
     evaluation = MTEB(tasks=multilingual_tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val"],
         output_folder=str(tmp_path / "partial_langs_partial_splits"),
@@ -209,8 +232,11 @@ def test_multilingual_multiple_splits_partial_langs_partial_splits(

     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockMultilingualRetrievalTask"]) == {"val"}
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"val"}
+    assert len(results[0].scores["val"]) == 2

-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val", "test"],
         output_folder=str(tmp_path / "partial_langs_partial_splits"),
@@ -221,13 +247,17 @@ def test_multilingual_multiple_splits_partial_langs_partial_splits(

     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockMultilingualRetrievalTask"]) == {"test"}
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"test", "val"}
+    assert len(results[0].scores["test"]) == 2
+    assert len(results[0].scores["val"]) == 2


 def test_multilingual_multiple_splits_missing_only_one_language_in_one_split(
     model, multilingual_tasks, tmp_path
 ):
     evaluation = MTEB(tasks=multilingual_tasks)
-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val"],
         output_folder=str(tmp_path / "one_lang_one_split"),
@@ -237,8 +267,11 @@ def test_multilingual_multiple_splits_missing_only_one_language_in_one_split(

     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockMultilingualRetrievalTask"]) == {"val"}
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"val"}
+    assert len(results[0].scores["val"]) == 2

-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["val", "test"],
         output_folder=str(tmp_path / "one_lang_one_split"),
@@ -249,8 +282,12 @@ def test_multilingual_multiple_splits_missing_only_one_language_in_one_split(

     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockMultilingualRetrievalTask"]) == {"test"}
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    assert results[0].scores.keys() == {"test", "val"}
+    assert len(results[0].scores["test"]) == 1
+    assert len(results[0].scores["val"]) == 2

-    _ = evaluation.run(
+    results = evaluation.run(
         model,
         eval_splits=["test"],
         output_folder=str(tmp_path / "one_lang_one_split"),
@@ -261,3 +298,7 @@ def test_multilingual_multiple_splits_missing_only_one_language_in_one_split(

     last_evaluated_splits = evaluation.get_last_evaluated_splits()
     assert set(last_evaluated_splits["MockMultilingualRetrievalTask"]) == {"test"}
+    assert sorted(results[0].languages) == ["eng", "fra"]
+    # output merged result with previous results
+    assert results[0].scores.keys() == {"test", "val"}
+    assert len(results[0].scores["test"]) == 2
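The assertions added throughout these tests all exercise the same merge behavior: re-running `MTEB.run` with the same `output_folder` evaluates only the missing splits and subsets, then returns a `TaskResult` whose scores are merged with what was already saved. A hedged sketch of that flow — the `model`/`tasks` fixtures are the mock objects from this test suite, and the folder name is hypothetical:

# Sketch of the incremental-evaluation pattern these tests assert.
from mteb import MTEB

def sketch_merge(model, tasks, tmp_path):
    evaluation = MTEB(tasks=tasks)
    out = str(tmp_path / "merge_demo")  # hypothetical output folder

    # First run evaluates only "val".
    first = evaluation.run(model, eval_splits=["val"], output_folder=out)
    assert first[0].scores.keys() == {"val"}

    # Second run requests both splits; only "test" is newly evaluated,
    # but the returned result also carries the saved "val" scores.
    second = evaluation.run(model, eval_splits=["val", "test"], output_folder=out)
    task_name = tasks[0].metadata.name
    assert set(evaluation.get_last_evaluated_splits()[task_name]) == {"test"}
    assert second[0].scores.keys() == {"val", "test"}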