Skip to content

Commit

Permalink
Fix segmentation Dice + GeneralizedDice for 2d index tensors (#2832)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Nov 12, 2024
1 parent e2543c8 commit 8f6936d
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Fixed

- Fixed segmentation `Dice` + `GeneralizedDice` for 2d index tensors ([#2832](https://github.com/Lightning-AI/torchmetrics/pull/2832))


- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))


Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def _dice_score_update(
) -> tuple[Tensor, Tensor, Tensor]:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if not include_background:
preds, target = _ignore_background(preds, target)

Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor
from typing_extensions import Literal
Expand Down Expand Up @@ -49,16 +51,17 @@ def _generalized_dice_update(
include_background: bool,
weight_type: Literal["square", "simple", "linear"] = "square",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if not include_background:
preds, target = _ignore_background(preds, target)

Expand All @@ -67,7 +70,6 @@ def _generalized_dice_update(
target_sum = torch.sum(target, dim=reduce_axis)
pred_sum = torch.sum(preds, dim=reduce_axis)
cardinality = target_sum + pred_sum

if weight_type == "simple":
weights = 1.0 / target_sum
elif weight_type == "linear":
Expand All @@ -89,7 +91,7 @@ def _generalized_dice_update(

numerator = 2.0 * intersection * weights
denominator = cardinality * weights
return numerator, denominator # type:ignore[return-value]
return numerator, denominator


def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:
Expand Down
4 changes: 4 additions & 0 deletions tests/unittests/segmentation/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
)
_input4 = _Input(
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
)
3 changes: 2 additions & 1 deletion tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from unittests import NUM_CLASSES
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3

seed_all(42)

Expand Down Expand Up @@ -55,6 +55,7 @@ def _reference_dice_score(
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
(_input4.preds, _input4.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from unittests import NUM_CLASSES
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3

seed_all(42)

Expand Down Expand Up @@ -53,6 +53,7 @@ def _reference_generalized_dice(
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
(_input4.preds, _input4.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
Expand Down

0 comments on commit 8f6936d

Please sign in to comment.