Skip to content

Commit

Permalink
tests: cleaning classif. (#2815)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Oct 31, 2024
1 parent d52cb48 commit dfc9e33
Show file tree
Hide file tree
Showing 21 changed files with 114 additions and 79 deletions.
26 changes: 25 additions & 1 deletion src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens
denom: denominator tensor, which may contain zeros
zero_division: value to replace elements divided by zero
Example:
>>> import torch
>>> num = torch.tensor([1.0, 2.0, 3.0])
>>> denom = torch.tensor([0.0, 1.0, 2.0])
>>> _safe_divide(num, denom)
tensor([0.0000, 2.0000, 1.5000])
"""
num = num if num.is_floating_point() else num.float()
denom = denom if denom.is_floating_point() else denom.float()
Expand Down Expand Up @@ -102,6 +109,16 @@ def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int


def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
"""Compute area under the curve using the trapezoidal rule.
Example:
>>> import torch
>>> x = torch.tensor([1, 2, 3, 4])
>>> y = torch.tensor([1, 2, 3, 4])
>>> _auc_compute(x, y)
tensor(7.5000)
"""
with torch.no_grad():
if reorder:
x, x_idx = torch.sort(x, stable=True)
Expand Down Expand Up @@ -139,7 +156,7 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
"""One-dimensional linear interpolation for monotonically increasing sample points.
Returns the one-dimensional piecewise linear interpolant to a function with
Returns the one-dimensional piecewise linear interpolation to a function with
given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964
Expand All @@ -152,6 +169,13 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
Returns:
the interpolated values, same size as `x`.
Example:
>>> x = torch.tensor([0.5, 1.5, 2.5])
>>> xp = torch.tensor([1, 2, 3])
>>> fp = torch.tensor([1, 2, 3])
>>> interp(x, xp, fp)
tensor([0.5000, 1.5000, 2.5000])
"""
m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1])
b = fp[:-1] - (m * xp[:-1])
Expand Down
11 changes: 10 additions & 1 deletion tests/unittests/_helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,16 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor:


def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]:
"""Remove samples that are equal to the ignore_index in comparison functions."""
"""Remove samples that are equal to the ignore_index in comparison functions.
Example:
>>> target = torch.tensor([0, 1, 2, 3, 4])
>>> preds = torch.tensor([0, 1, 2, 3, 4])
>>> ignore_index = 2
>>> remove_ignore_index(target, preds, ignore_index)
(tensor([0, 1, 3, 4]), tensor([0, 1, 3, 4]))
"""
if ignore_index is not None:
idx = target == ignore_index
target, preds = deepcopy(target[~idx]), deepcopy(preds[~idx])
Expand Down
16 changes: 8 additions & 8 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def _reference_sklearn_accuracy_binary(preds, target, ignore_index, multidim_ave
preds = (preds >= THRESHOLD).astype(np.uint8)

if multidim_average == "global":
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return _reference_sklearn_accuracy(target, preds)

res = []
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
res.append(_reference_sklearn_accuracy(true, pred))
return np.stack(res)

Expand Down Expand Up @@ -185,7 +185,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim
if multidim_average == "global":
preds = preds.numpy().flatten()
target = target.numpy().flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
if average == "micro":
return _reference_sklearn_accuracy(target, preds)
confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES)))
Expand All @@ -207,7 +207,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
if average == "micro":
res.append(_reference_sklearn_accuracy(true, pred))
else:
Expand Down Expand Up @@ -445,13 +445,13 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim
if average == "micro":
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return _reference_sklearn_accuracy(target, preds)

accuracy, weights = [], []
for i in range(preds.shape[1]):
pred, true = preds[:, i].flatten(), target[:, i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
accuracy.append(_reference_sklearn_accuracy(true, pred))
weights.append(confmat[1, 1] + confmat[1, 0])
Expand All @@ -472,15 +472,15 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim
for i in range(preds.shape[0]):
if average == "micro":
pred, true = preds[i].flatten(), target[i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
accuracy.append(_reference_sklearn_accuracy(true, pred))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
weights.append(confmat[1, 1] + confmat[1, 0])
else:
scores, w = [], []
for j in range(preds.shape[1]):
pred, true = preds[i, j], target[i, j]
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
scores.append(_reference_sklearn_accuracy(true, pred))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
w.append(confmat[1, 1] + confmat[1, 0])
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _reference_sklearn_auroc_binary(preds, target, max_fpr=None, ignore_index=No
target = target.flatten().numpy()
if not ((preds > 0) & (preds < 1)).all():
preds = sigmoid(preds)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_roc_auc_score(target, preds, max_fpr=max_fpr)


Expand Down Expand Up @@ -144,7 +144,7 @@ def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_i
target = target.numpy().flatten()
if not ((preds > 0) & (preds < 1)).all():
preds = softmax(preds, 1)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_roc_auc_score(target, preds, average=average, multi_class="ovr", labels=list(range(NUM_CLASSES)))


Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _reference_sklearn_avg_precision_binary(preds, target, ignore_index=None):
target = target.flatten().numpy()
if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all():
preds = sigmoid(preds)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_average_precision_score(target, preds)


Expand Down Expand Up @@ -156,7 +156,7 @@ def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro",
target = target.numpy().flatten()
if not ((preds > 0) & (preds < 1)).all():
preds = softmax(preds, 1)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)

res = []
for i in range(NUM_CLASSES):
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _reference_netcal_binary_calibration_error(preds, target, n_bins, norm, igno
target = target.numpy().flatten()
if not ((preds > 0) & (preds < 1)).all():
preds = sigmoid(preds)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
metric = ECE if norm == "l1" else MCE
return metric(n_bins).measure(preds, target)

Expand Down Expand Up @@ -149,7 +149,7 @@ def _reference_netcal_multiclass_calibration_error(preds, target, n_bins, norm,
if not ((preds > 0) & (preds < 1)).all():
preds = softmax(preds, 1)
preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1]))
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
metric = ECE if norm == "l1" else MCE
return metric(n_bins).measure(preds, target)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _reference_sklearn_cohen_kappa_binary(preds, target, weights=None, ignore_in
if not ((preds > 0) & (preds < 1)).all():
preds = sigmoid(preds)
preds = (preds >= THRESHOLD).astype(np.uint8)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_cohen_kappa(y1=target, y2=preds, weights=weights)


Expand Down Expand Up @@ -136,7 +136,7 @@ def _reference_sklearn_cohen_kappa_multiclass(preds, target, weights, ignore_ind
preds = np.argmax(preds, axis=1)
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_cohen_kappa(y1=target, y2=preds, weights=weights)


Expand Down
6 changes: 3 additions & 3 deletions tests/unittests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _reference_sklearn_confusion_matrix_binary(preds, target, normalize=None, ig
if not ((preds > 0) & (preds < 1)).all():
preds = sigmoid(preds)
preds = (preds >= THRESHOLD).astype(np.uint8)
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1], normalize=normalize)


Expand Down Expand Up @@ -147,7 +147,7 @@ def _reference_sklearn_confusion_matrix_multiclass(preds, target, normalize=None
preds = np.argmax(preds, axis=1)
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES)))


Expand Down Expand Up @@ -298,7 +298,7 @@ def _reference_sklearn_confusion_matrix_multilabel(preds, target, normalize=None
confmat = []
for i in range(preds.shape[1]):
pred, true = preds[:, i], target[:, i]
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
confmat.append(sk_confusion_matrix(true, pred, normalize=normalize, labels=[0, 1]))
return np.stack(confmat, axis=0)

Expand Down
16 changes: 8 additions & 8 deletions tests/unittests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, mu
preds = (preds >= THRESHOLD).astype(np.uint8)

if multidim_average == "global":
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_fn(target, preds, zero_division=zero_division)

res = []
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
res.append(sk_fn(true, pred, zero_division=zero_division))
return np.stack(res)

Expand Down Expand Up @@ -205,7 +205,7 @@ def _reference_sklearn_fbeta_score_multiclass(
if multidim_average == "global":
preds = preds.numpy().flatten()
target = target.numpy().flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_fn(
target,
preds,
Expand All @@ -220,7 +220,7 @@ def _reference_sklearn_fbeta_score_multiclass(
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)

if len(pred) == 0 and average == "weighted":
# The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division)
Expand Down Expand Up @@ -417,13 +417,13 @@ def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignor
if average == "micro":
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return sk_fn(target, preds, zero_division=zero_division)

fbeta_score, weights = [], []
for i in range(preds.shape[1]):
pred, true = preds[:, i].flatten(), target[:, i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
fbeta_score.append(sk_fn(true, pred, zero_division=zero_division))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
weights.append(confmat[1, 1] + confmat[1, 0])
Expand All @@ -446,15 +446,15 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore
for i in range(preds.shape[0]):
if average == "micro":
pred, true = preds[i].flatten(), target[i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
fbeta_score.append(sk_fn(true, pred, zero_division=zero_division))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
weights.append(confmat[1, 1] + confmat[1, 0])
else:
scores, w = [], []
for j in range(preds.shape[1]):
pred, true = preds[i, j], target[i, j]
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
scores.append(sk_fn(true, pred, zero_division=zero_division))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
w.append(confmat[1, 1] + confmat[1, 0])
Expand Down
16 changes: 8 additions & 8 deletions tests/unittests/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def _reference_sklearn_hamming_distance_binary(preds, target, ignore_index, mult
preds = (preds >= THRESHOLD).astype(np.uint8)

if multidim_average == "global":
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return _reference_sklearn_hamming_loss(target, preds)

res = []
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
res.append(_reference_sklearn_hamming_loss(true, pred))
return np.stack(res)

Expand Down Expand Up @@ -167,7 +167,7 @@ def test_binary_hamming_distance_dtype_gpu(self, inputs, dtype):
def _reference_sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average):
preds = preds.numpy().flatten()
target = target.numpy().flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
if average == "micro":
return _reference_sklearn_hamming_loss(target, preds)
confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES)))
Expand All @@ -191,7 +191,7 @@ def _reference_sklearn_hamming_distance_multiclass_local(preds, target, ignore_i
for pred, true in zip(preds, target):
pred = pred.flatten()
true = true.flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
if average == "micro":
res.append(_reference_sklearn_hamming_loss(true, pred))
else:
Expand Down Expand Up @@ -331,13 +331,13 @@ def _reference_sklearn_hamming_distance_multilabel_global(preds, target, ignore_
if average == "micro":
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index)
return _reference_sklearn_hamming_loss(target, preds)

hamming, weights = [], []
for i in range(preds.shape[1]):
pred, true = preds[:, i].flatten(), target[:, i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
hamming.append(_reference_sklearn_hamming_loss(true, pred))
weights.append(confmat[1, 1] + confmat[1, 0])
Expand All @@ -360,13 +360,13 @@ def _reference_sklearn_hamming_distance_multilabel_local(preds, target, ignore_i
for i in range(preds.shape[0]):
if average == "micro":
pred, true = preds[i].flatten(), target[i].flatten()
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
hamming.append(_reference_sklearn_hamming_loss(true, pred))
else:
scores, w = [], []
for j in range(preds.shape[1]):
pred, true = preds[i, j], target[i, j]
true, pred = remove_ignore_index(true, pred, ignore_index)
true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index)
scores.append(_reference_sklearn_hamming_loss(true, pred))
confmat = sk_confusion_matrix(true, pred, labels=[0, 1])
w.append(confmat[1, 1] + confmat[1, 0])
Expand Down
Loading

0 comments on commit dfc9e33

Please sign in to comment.