Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/multiclass recall macro avg ignore index #2710

Draft
wants to merge 61 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
176711d
Fix: Corrected MulticlassRecall macro average calculation when ignore…
rittik9 Aug 31, 2024
df36d0f
style: format code to comply with pre-commit hooks
rittik9 Sep 1, 2024
0773bab
test: Add test for MulticlassRecall with ignore_index+macro (fixes #2…
rittik9 Sep 2, 2024
78177ac
chlog
Borda Sep 9, 2024
f7701ea
Merge branch 'master' into master
Borda Sep 16, 2024
d6f041b
Merge branch 'master' into master
mergify[bot] Sep 16, 2024
fb6c23d
Merge branch 'master' into master
mergify[bot] Sep 16, 2024
3ae861b
Merge branch 'master' into master
mergify[bot] Sep 16, 2024
a0401f6
Merge branch 'master' into master
Borda Sep 17, 2024
858e0d1
Merge branch 'master' into master
mergify[bot] Sep 17, 2024
bac6267
Merge branch 'master' into master
mergify[bot] Sep 24, 2024
2976947
Merge branch 'master' into master
mergify[bot] Sep 24, 2024
dbe1a5a
Merge branch 'master' into master
mergify[bot] Oct 1, 2024
bb36be4
Merge branch 'master' into master
Borda Oct 8, 2024
ead62fe
Merge branch 'master' into master
mergify[bot] Oct 9, 2024
e0ed7e7
Merge branch 'master' into master
mergify[bot] Oct 9, 2024
263548d
Merge branch 'master' into master
mergify[bot] Oct 10, 2024
8cc5bf1
Merge branch 'master' into master
mergify[bot] Oct 10, 2024
0483219
Merge branch 'master' into master
Borda Oct 10, 2024
982cfea
Merge branch 'master' into master
mergify[bot] Oct 11, 2024
d16c815
Merge branch 'master' into master
mergify[bot] Oct 11, 2024
c078bd2
Merge branch 'master' into master
mergify[bot] Oct 11, 2024
d61727e
Merge branch 'master' into master
mergify[bot] Oct 14, 2024
9aa5928
Merge branch 'master' into master
Borda Oct 15, 2024
581d3ec
Merge branch 'master' into master
mergify[bot] Oct 18, 2024
61a4b56
Merge branch 'master' into master
Borda Oct 18, 2024
a75c5e3
Merge branch 'master' into master
Borda Oct 21, 2024
005cc94
Merge branch 'master' into master
Borda Oct 22, 2024
4c8bea2
Merge branch 'master' into master
mergify[bot] Oct 22, 2024
6767a4c
Merge branch 'master' into master
mergify[bot] Oct 22, 2024
daa1006
Merge branch 'master' into master
mergify[bot] Oct 22, 2024
2642f1c
Merge branch 'master' into master
mergify[bot] Oct 22, 2024
52fda88
Merge branch 'master' into master
Borda Oct 23, 2024
8fa7dbf
Merge branch 'master' into master
Borda Oct 23, 2024
0b4818b
Merge branch 'master' into master
mergify[bot] Oct 24, 2024
c396f5b
Merge branch 'master' into master
mergify[bot] Oct 29, 2024
94e3b37
Merge branch 'master' into master
mergify[bot] Oct 29, 2024
844ad9c
Merge branch 'master' into master
Borda Oct 30, 2024
b63a661
Merge branch 'master' into master
mergify[bot] Oct 30, 2024
a1cbaad
Merge branch 'master' into master
mergify[bot] Oct 30, 2024
4b65013
Merge branch 'master' into master
mergify[bot] Oct 30, 2024
69c3a31
fix args
Borda Oct 31, 2024
48bea14
more kwargs
Borda Oct 31, 2024
3753de1
Merge branch 'master' into master
Borda Oct 31, 2024
043ca6b
Merge branch 'master' into master
mergify[bot] Oct 31, 2024
bd84496
Apply suggestions from code review
Borda Oct 31, 2024
214e78b
Merge branch 'master' into master
mergify[bot] Oct 31, 2024
34c88f7
dims + doctests
Borda Oct 31, 2024
a2b108a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
b3feb3f
args
Borda Oct 31, 2024
580ee6e
Merge branch 'master' of https://github.com/rittik9/torchmetrics into…
Borda Oct 31, 2024
8209416
kwargs
Borda Oct 31, 2024
daf6d72
Merge branch 'master' into master
Borda Oct 31, 2024
561f791
fix test
Borda Oct 31, 2024
401e923
update
Borda Oct 31, 2024
4568288
refactor: multiclass macro avg ignore index
rittik9 Nov 2, 2024
20e116a
fix:MultilabelRecall
rittik9 Nov 2, 2024
7d804bd
refactor: precision_recall.py
rittik9 Nov 2, 2024
9c25e0b
Merge branch 'master' into master
rittik9 Nov 4, 2024
08bc4dd
Merge branch 'Lightning-AI:master' into master
rittik9 Nov 7, 2024
62436fd
Merge branch 'master' into master
Borda Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed multiclass recall macro avg. ignore index ([#2710](https://github.com/Lightning-AI/torchmetrics/pull/2710))


- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805))


---
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ def compute(self) -> Tensor:
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
ignore_index=self.ignore_index,
)

def plot(
Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def _accuracy_reduce(
return _safe_divide(tp, tp + fn)

score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)
Borda marked this conversation as resolved.
Show resolved Hide resolved
return _adjust_weights_safe_divide(
score=score, average=average, multilabel=multilabel, tp=tp, fp=fp, fn=fn, top_k=top_k
)


def binary_accuracy(
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _fbeta_reduce(
return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division)

fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division)
return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn)
return _adjust_weights_safe_divide(score=fbeta_score, average=average, multilabel=multilabel, tp=tp, fp=fp, fn=fn)


def _binary_fbeta_score_arg_validation(
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _hamming_distance_reduce(
return 1 - _safe_divide(tp, tp + fn)

score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else 1 - _safe_divide(tp, tp + fn)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
return _adjust_weights_safe_divide(score=score, average=average, multilabel=multilabel, tp=tp, fp=fp, fn=fn)


def binary_hamming_distance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def _negative_predictive_value_reduce(
fn = fn.sum(dim=0 if multidim_average == "global" else 1)
return _safe_divide(tn, tn + fn, zero_division)
score = _safe_divide(tn, tn + fn, zero_division)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
return _adjust_weights_safe_divide(
score=score, average=average, multilabel=multilabel, tp=tp, fp=fp, fn=fn, top_k=top_k
)


def binary_negative_predictive_value(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _precision_recall_reduce(
multilabel: bool = False,
top_k: int = 1,
zero_division: float = 0,
ignore_index: Optional[int] = None,
) -> Tensor:
different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
if average == "binary":
Expand All @@ -56,7 +57,7 @@ def _precision_recall_reduce(
return _safe_divide(tp, tp + different_stat, zero_division)

score = _safe_divide(tp, tp + different_stat, zero_division)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k, ignore_index=ignore_index)


def binary_precision(
Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/functional/classification/specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def _specificity_reduce(
return _safe_divide(tn, tn + fp)

specificity_score = _safe_divide(tn, tn + fp)
return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn)
return _adjust_weights_safe_divide(
score=specificity_score, average=average, multilabel=multilabel, tp=tp, fp=fp, fn=fn
)


def binary_specificity(
Expand Down
14 changes: 13 additions & 1 deletion src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens


def _adjust_weights_safe_divide(
score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
score: Tensor,
average: Optional[str],
multilabel: bool,
tp: Tensor,
fp: Tensor,
fn: Tensor,
top_k: int = 1,
ignore_index: Optional[int] = None,
) -> Tensor:
if average is None or average == "none":
return score
Expand All @@ -78,6 +85,11 @@ def _adjust_weights_safe_divide(
weights = torch.ones_like(score)
if not multilabel:
weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0

# Add this line to handle ignore_index
if ignore_index is not None and 0 <= ignore_index < len(score):
weights[ignore_index] = 0.0

return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


Expand Down
28 changes: 28 additions & 0 deletions tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,34 @@ def test_corner_case():
assert res == 1.0


@pytest.mark.parametrize(
("ignore_index", "average", "expected"),
[
(0, "macro", 0.5),
(1, "macro", 1.0),
(None, "macro", 0.75),
(0, "none", torch.tensor([0.0, 0.5])),
],
)
def test_multiclass_recall_ignore_index(ignore_index, average, expected):
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
y_true = torch.tensor([0, 0, 1, 1])
y_pred = torch.tensor([
[0.9, 0.1],
[0.9, 0.1],
[0.9, 0.1],
[0.1, 0.9],
])

# Test with ignore_index=0 and average="macro"
metric_ignore = MulticlassRecall(num_classes=2, ignore_index=ignore_index, average=average)
res_ignore = metric_ignore(y_pred, y_true)
if isinstance(expected, float):
assert res_ignore == expected, f"Expected {expected}, but got {res_ignore}"
else:
assert torch.allclose(res_ignore, expected), f"Expected {expected}, but got {res_ignore}"


@pytest.mark.parametrize(
("metric", "kwargs", "base_metric"),
[
Expand Down
Loading