diff --git a/CHANGELOG.md b/CHANGELOG.md index c22a5a19710..de726f75ba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,8 +50,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) + ### Fixed +- Fixed segmentation `Dice` + `GeneralizedDice` for 2d index tensors ([#2832](https://github.com/Lightning-AI/torchmetrics/pull/2832)) + + - Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830)) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index de9a8068bf9..53c4b3d6622 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -49,13 +49,14 @@ def _dice_score_update( ) -> tuple[Tensor, Tensor, Tensor]: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) - if preds.ndim < 3: - raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + if preds.ndim < 3: + raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") + if not include_background: preds, target = _ignore_background(preds, target) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 8bfc9bab18a..efa6143d87a 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import torch from torch import Tensor from typing_extensions import Literal @@ -49,16 +51,17 @@ def _generalized_dice_update( include_background: bool, weight_type: Literal["square", "simple", "linear"] = "square", input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tensor: +) -> Tuple[Tensor, Tensor]: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) - if preds.ndim < 3: - raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + if preds.ndim < 3: + raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") + if not include_background: preds, target = _ignore_background(preds, target) @@ -67,7 +70,6 @@ def _generalized_dice_update( target_sum = torch.sum(target, dim=reduce_axis) pred_sum = torch.sum(preds, dim=reduce_axis) cardinality = target_sum + pred_sum - if weight_type == "simple": weights = 1.0 / target_sum elif weight_type == "linear": @@ -89,7 +91,7 @@ def _generalized_dice_update( numerator = 2.0 * intersection * weights denominator = cardinality * weights - return numerator, denominator # type:ignore[return-value] + return numerator, denominator def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor: diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py index b773ba29ebd..ce46c8b597c 100644 --- a/tests/unittests/segmentation/inputs.py +++ b/tests/unittests/segmentation/inputs.py @@ -34,3 +34,7 @@ preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), ) +_input4 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)), +) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index d5bfc08b4ae..b009401f481 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -22,7 +22,7 @@ from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 +from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3 seed_all(42) @@ -55,6 +55,7 @@ def _reference_dice_score( (_inputs1.preds, _inputs1.target, "one-hot"), (_inputs2.preds, _inputs2.target, "one-hot"), (_inputs3.preds, _inputs3.target, "index"), + (_input4.preds, _input4.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False]) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 6c353800379..c87fd6aa22e 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -24,7 +24,7 @@ from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 +from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3 seed_all(42) @@ -53,6 +53,7 @@ def _reference_generalized_dice( (_inputs1.preds, _inputs1.target, "one-hot"), (_inputs2.preds, _inputs2.target, "one-hot"), (_inputs3.preds, _inputs3.target, "index"), + (_input4.preds, _input4.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False])