Skip to content

Commit

Permalink
Allow empty tile annotation (#4124)
Browse files Browse the repository at this point in the history
* Add warnings for empty annotations in OTXTileDetTestDataset and OTXTileInstSegTestDataset

* Fix empty annotation handling in tiling
  • Loading branch information
eugene123tw authored Nov 22, 2024
1 parent 5e18121 commit ec610a9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4105>)
- Disable tiling classifier toggle in configurable parameters
(<https://github.com/openvinotoolkit/training_extensions/pull/4107>)
- Fix empty annotation in tiling
(<https://github.com/openvinotoolkit/training_extensions/pull/4124>)

## \[v2.1.0\]

Expand Down
21 changes: 14 additions & 7 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import logging as log
import operator
import warnings
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Callable
Expand Down Expand Up @@ -372,14 +373,17 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
img = item.media_as(Image)
img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
gt_bboxes = [ann for ann in item.annotations if isinstance(ann, Bbox)]

if empty_anno := len(gt_bboxes) == 0:
warnings.warn(f"Empty annotation for image {item.id}!", stacklevel=2)

bboxes = (
np.stack([ann.points for ann in bbox_anns], axis=0).astype(np.float32)
if len(bbox_anns) > 0
else np.zeros((0, 4), dtype=np.float32)
np.empty((0, 4), dtype=np.float32)
if empty_anno
else np.stack([ann.points for ann in gt_bboxes], axis=0).astype(np.float32)
)
labels = torch.as_tensor([ann.label for ann in bbox_anns])
labels = torch.as_tensor([ann.label for ann in gt_bboxes])

tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

Expand Down Expand Up @@ -476,11 +480,14 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
else:
gt_masks.append(polygon_to_bitmap([annotation], *img_shape)[0])

if empty_anno := len(gt_bboxes) == 0:
warnings.warn(f"Empty annotation for image {item.id}", stacklevel=2)

# convert xywh to xyxy format
bboxes = np.array(gt_bboxes, dtype=np.float32)
bboxes = np.empty((0, 4), dtype=np.float32) if empty_anno else np.stack(gt_bboxes, dtype=np.float32)
bboxes[:, 2:] += bboxes[:, :2]

masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool)
masks = np.stack(gt_masks, axis=0) if gt_masks else np.empty((0, *img_shape), dtype=bool)
labels = np.array(gt_labels, dtype=np.int64)

tile_entities, tile_attrs = self.get_tiles(img_data, item, index)
Expand Down

0 comments on commit ec610a9

Please sign in to comment.