Merge branch 'master' into bugfix/2d_dice_scores
Borda authored Nov 12, 2024
2 parents 0ee7a53 + e2543c8 commit 6f487cb
Showing 5 changed files with 68 additions and 16 deletions.
CHANGELOG.md: 4 changes (2 additions & 2 deletions)
@@ -51,13 +51,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))



 ### Fixed

 - Fixed segmentation `Dice` + `GeneralizedDice` for 2d index tensors ([#2832](https://github.com/Lightning-AI/torchmetrics/pull/2832))

-
----
+- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))
+

 ## [1.5.2] - 2024-11-07

requirements/_tests.txt: 4 changes (2 additions & 2 deletions)
@@ -5,15 +5,15 @@
 coverage ==7.6.*
 codecov ==2.1.13
 pytest ==8.3.*
-pytest-cov ==5.0.0
+pytest-cov ==6.0.0
 pytest-doctestplus ==1.2.1
 pytest-rerunfailures ==14.0
 pytest-timeout ==2.3.1
 pytest-xdist ==3.6.1
 phmdoctest ==1.4.0

 psutil ==6.*
-pyGithub >2.0.0, <2.5.0
+pyGithub >2.0.0, <2.6.0
 fire ==0.7.*

 cloudpickle >1.3, <=3.1.0
requirements/text.txt: 4 changes (2 additions & 2 deletions)
@@ -2,8 +2,8 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

 nltk >3.8.1, <=3.9.1
-tqdm <4.67.0
-regex >=2021.9.24, <=2024.9.11
+tqdm <4.68.0
+regex >=2021.9.24, <=2024.11.6
 transformers >4.4.0, <4.47.0
 mecab-python3 >=1.0.6, <1.1.0
 ipadic >=1.0.0, <1.1.0
src/torchmetrics/functional/text/rouge.py: 9 changes (3 additions & 6 deletions)
@@ -362,12 +362,9 @@ def _rouge_score_update(
             list_results.append(result_inner.copy())

         if accumulate == "best":
-            key_curr = rouge_keys_values[0]
-            all_fmeasure = torch.tensor([v[key_curr]["fmeasure"] for v in list_results])
-            highest_idx = int(torch.argmax(all_fmeasure).item())
-
-            for rouge_key in rouge_keys_values:
-                results[rouge_key].append(list_results[highest_idx][rouge_key])  # todo
+            for k in rouge_keys_values:
+                index = torch.argmax(torch.tensor([s[k]["fmeasure"] for s in list_results]))
+                results[k].append(list_results[index][k])

         elif accumulate == "avg":
             new_result_avg: dict[Union[int, str], dict[str, Tensor]] = {
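The substance of the fix is in `_rouge_score_update`: with `accumulate="best"`, the old code ran a single `argmax` over the fmeasures of the first ROUGE key and reused that one reference's scores for every key; the new code picks the best reference independently per key. A minimal sketch of the difference on toy data (the dict layout mirrors `list_results` above, but the keys and numbers are illustrative, not the torchmetrics internals):

```python
import torch

# Toy per-reference results: one dict per reference, keyed by ROUGE variant,
# mimicking the shape of `list_results` in `_rouge_score_update`.
list_results = [
    {"rouge1": {"fmeasure": 1.0}, "rouge2": {"fmeasure": 0.0}},  # e.g. reference "c b a"
    {"rouge1": {"fmeasure": 1.0}, "rouge2": {"fmeasure": 1.0}},  # e.g. reference "a b c"
]
rouge_keys_values = ["rouge1", "rouge2"]

# Old behavior: one argmax over the FIRST key only; the tie at 1.0 resolves to
# index 0, and that single index is then reused for every other key.
first = rouge_keys_values[0]
highest_idx = int(torch.argmax(torch.tensor([v[first]["fmeasure"] for v in list_results])))
old = {k: list_results[highest_idx][k]["fmeasure"] for k in rouge_keys_values}

# New behavior: an independent argmax per ROUGE key.
new = {}
for k in rouge_keys_values:
    index = int(torch.argmax(torch.tensor([s[k]["fmeasure"] for s in list_results])))
    new[k] = list_results[index][k]["fmeasure"]

print(old)  # {'rouge1': 1.0, 'rouge2': 0.0} -- rouge2 wrongly pinned to reference 0
print(new)  # {'rouge1': 1.0, 'rouge2': 1.0} -- each key takes its own best reference
```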
tests/unittests/text/test_rouge.py: 63 changes (59 additions & 4 deletions)
@@ -75,10 +75,12 @@ def _reference_rouge_score(
     aggregator_avg = BootstrapAggregator()

     if accumulate == "best":
-        key_curr = next(iter(list_results[0].keys()))
-        all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results])
-        highest_idx = torch.argmax(all_fmeasure).item()
-        aggregator.add_scores(list_results[highest_idx])
+        scores = {}
+        for rouge_key in list_results[0]:
+            all_fmeasure = torch.tensor([v[rouge_key].fmeasure for v in list_results])
+            highest_idx = torch.argmax(all_fmeasure).item()
+            scores[rouge_key] = list_results[highest_idx][rouge_key]
+        aggregator.add_scores(scores)
     elif accumulate == "avg":
         for _score in list_results:
             aggregator_avg.add_scores(_score)
@@ -270,3 +272,56 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
         use_stemmer=use_stemmer,
     )
     assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
+
+
+@pytest.mark.parametrize(
+    ("preds", "references", "expected_scores"),
+    [
+        (
+            "a b c",
+            ["a b c", "c b a"],
+            {
+                "rouge1_fmeasure": 1.0,
+                "rouge1_precision": 1.0,
+                "rouge1_recall": 1.0,
+                "rouge2_fmeasure": 1.0,
+                "rouge2_precision": 1.0,
+                "rouge2_recall": 1.0,
+                "rougeL_fmeasure": 1.0,
+                "rougeL_precision": 1.0,
+                "rougeL_recall": 1.0,
+                "rougeLsum_fmeasure": 1.0,
+                "rougeLsum_precision": 1.0,
+                "rougeLsum_recall": 1.0,
+            },
+        ),
+        (
+            "a b c",
+            ["c b a", "a b c"],
+            {
+                "rouge1_fmeasure": 1.0,
+                "rouge1_precision": 1.0,
+                "rouge1_recall": 1.0,
+                "rouge2_fmeasure": 1.0,
+                "rouge2_precision": 1.0,
+                "rouge2_recall": 1.0,
+                "rougeL_fmeasure": 1.0,
+                "rougeL_precision": 1.0,
+                "rougeL_recall": 1.0,
+                "rougeLsum_fmeasure": 1.0,
+                "rougeLsum_precision": 1.0,
+                "rougeLsum_recall": 1.0,
+            },
+        ),
+    ],
+)
+def test_rouge_score_accumulate_best(preds, references, expected_scores):
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2148."""
+    # Calculate ROUGE scores
+    result = rouge_score(preds, references, accumulate="best")
+
+    # Assert each expected score
+    for key in expected_scores:
+        assert torch.isclose(
+            result[key], torch.tensor(expected_scores[key])
+        ), f"Expected {expected_scores[key]} for {key}, but got {result[key]}"
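The two parametrizations use the same references in opposite orders because order is exactly what leaked through the old code path: rouge1 ties at 1.0 for both references, the single argmax then always picked index 0, and with `["c b a", "a b c"]` every metric was read from "c b a", whose bigram overlap with the prediction is zero. A quick way to see the asymmetry and the fixed behavior (a sketch assuming a torchmetrics install that includes this fix, plus the text extras such as `nltk`):

```python
import torch
from torchmetrics.functional.text import rouge_score

# Scored separately, the two references are not equivalent: "a b c" shares both
# bigrams with the prediction, "c b a" shares none, so rouge2 is 1.0 vs 0.0.
for ref in ("a b c", "c b a"):
    print(ref, "->", rouge_score("a b c", ref, rouge_keys="rouge2")["rouge2_fmeasure"])

# With the per-key argmax, accumulate="best" takes the maximum per metric over
# references, so rouge2 stays 1.0 regardless of reference order.
best = rouge_score("a b c", ["c b a", "a b c"], accumulate="best")
assert torch.isclose(best["rouge2_fmeasure"], torch.tensor(1.0))
```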
