Update Label Info handling #4127

Merged: 19 commits, Dec 4, 2024
9 changes: 3 additions & 6 deletions src/otx/algo/classification/efficientnet.py
@@ -14,7 +14,7 @@
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:

return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
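
For orientation, here is a minimal sketch of the hierarchical-label model as it is wired after this change, assuming the constructor signatures shown in the diff; the EfficientNet arguments and the head_config contents are placeholders, not values taken from OTX defaults:

# Sketch only: mirrors the new _build_model() wiring, not a drop-in implementation.
from torch import nn

from otx.algo.classification.backbones.efficientnet import OTXEfficientNet
from otx.algo.classification.classifier import HLabelClassifier
from otx.algo.classification.heads import HierarchicalLinearClsHead
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling

backbone = OTXEfficientNet(version="b0", input_size=(224, 224))  # hypothetical arguments
head_config = {"num_multiclass_heads": 2, "num_multilabel_classes": 0}  # placeholder config

model = HLabelClassifier(
    backbone=backbone,
    neck=GlobalAveragePooling(dim=2),  # replaces the previous nn.Identity() neck
    head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.num_features),
    multiclass_loss=nn.CrossEntropyLoss(),
    multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)

Pooling now happens in the neck, so the linear head receives a flat feature vector rather than the spatial map the CBAM head consumed (hence the dropped step_size handling).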
10 changes: 4 additions & 6 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
@@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
in_channels = 960 if self.mode == "large" else 576

return HLabelClassifier(
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=960,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
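
The mobilenet-specific detail is the head width: instead of a hard-coded 960, the input channels now follow the backbone variant. A trivial sketch (the 576 value for the small variant is taken from this diff):

def head_in_channels(mode: str) -> int:
    # Feature width of OTXMobileNetV3 per variant, as used in _build_model() above.
    return 960 if mode == "large" else 576

assert head_in_channels("large") == 960
assert head_in_channels("small") == 576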
9 changes: 3 additions & 6 deletions src/otx/algo/classification/timm_model.py
@@ -15,12 +15,12 @@
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/torchvision_model.py
@@ -14,12 +14,12 @@
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.core.data.entity.classification import (
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.in_features,
**head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
8 changes: 2 additions & 6 deletions src/otx/algo/classification/vit.py
@@ -19,12 +19,12 @@
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
MultiLabelLinearClsHead,
SemiSLVisionTransformerClsHead,
VisionTransformerClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.utils import get_classification_layers
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
return HLabelClassifier(
backbone=vit_backbone,
neck=None,
head=HierarchicalCBAMClsHead(
in_channels=vit_backbone.embed_dim,
step_size=1,
**head_config,
),
head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
init_cfg=init_cfg,
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/action_classification.py
@@ -37,6 +37,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/anomaly.py
@@ -57,6 +57,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
self.task_type = task_type
super().__init__(
7 changes: 6 additions & 1 deletion src/otx/core/data/dataset/base.py
@@ -70,6 +70,7 @@
max_refetch: Maximum number of images to fetch in cache
image_color_channel: Color channel of images
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
data_format: Source data format, which was originally passed to datumaro (could be arrow for instance).

"""

@@ -83,6 +84,7 @@
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
self.dm_subset = dm_subset
self.transforms = transforms
@@ -92,8 +94,11 @@
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image
self.data_format = data_format

if self.dm_subset.categories():
if self.dm_subset.categories() and data_format == "arrow":
self.label_info = LabelInfo.from_dm_label_groups_arrow(self.dm_subset.categories()[AnnotationType.label])

Codecov warning: added line #L100 in src/otx/core/data/dataset/base.py was not covered by tests.
elif self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
self.label_info = NullLabelInfo()
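
Condensed, the new label-info resolution reads roughly as below, assuming LabelInfo and NullLabelInfo live in otx.core.types.label as in the OTX 2.x layout and that from_dm_label_groups_arrow is the classmethod added elsewhere in this PR; the helper function itself is illustrative:

# Sketch of the label-info selection performed in OTXDataset.__init__.
from datumaro.components.annotation import AnnotationType

from otx.core.types.label import LabelInfo, NullLabelInfo  # assumed OTX 2.x module layout


def resolve_label_info(dm_subset, data_format: str) -> LabelInfo:
    categories = dm_subset.categories()
    if categories and data_format == "arrow":
        # Arrow exports store label IDs where names are expected, so use the arrow-aware parser.
        return LabelInfo.from_dm_label_groups_arrow(categories[AnnotationType.label])
    if categories:
        return LabelInfo.from_dm_label_groups(categories[AnnotationType.label])
    return NullLabelInfo()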
64 changes: 42 additions & 22 deletions src/otx/core/data/dataset/classification.py
@@ -80,17 +84,21 @@
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
label_ids = set()
for ann in item.annotations:
# multilabel information stored in 'multi_label_ids' attribute when the source format is arrow
if "multi_label_ids" in ann.attributes:
for lbl_idx in ann.attributes["multi_label_ids"]:
label_ids.add(lbl_idx)

Codecov warning: added lines #L87-L88 in src/otx/core/data/dataset/classification.py were not covered by tests.

if isinstance(ann, Label):
label_anns.append(ann)
label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
labels = torch.as_tensor([ann.label for ann in label_anns])
label_ids.add(label.label)
labels = torch.as_tensor(list(label_ids))

entity = MultilabelClsDataEntity(
image=img_data,
@@ -128,13 +132,22 @@
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]

# Hlabel classification used HLabelInfo to insert the HLabelData.
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
if self.data_format == "arrow":
# arrow format stores label IDs as names, have to deal with that here
self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories)

Codecov warning: added line #L137 in src/otx/core/data/dataset/classification.py was not covered by tests.
else:
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)

self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names))
self.id_to_name_mapping[""] = ""

if self.label_info.num_multiclass_heads == 0:
msg = "The number of multiclass heads should be larger than 0."
raise ValueError(msg)

for dm_item in self.dm_subset:
self._add_ancestors(dm_item.annotations)
if self.data_format != "arrow":
for dm_item in self.dm_subset:
self._add_ancestors(dm_item.annotations)

def _add_ancestors(self, label_anns: list[Label]) -> None:
"""Add ancestors recursively if some label miss the ancestor information.
@@ -149,14 +162,16 @@
"""

def _label_idx_to_name(idx: int) -> str:
return self.label_info.label_names[idx]
return self.dm_categories[idx].name

def _label_name_to_idx(name: str) -> int:
indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name]
return indices[0]

def _get_label_group_idx(label_name: str) -> int:
if isinstance(self.label_info, HLabelInfo):
if self.data_format == "arrow":
return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0]

Codecov warning: added line #L174 in src/otx/core/data/dataset/classification.py was not covered by tests.
return self.label_info.class_to_group_idx[label_name][0]
msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}"
raise ValueError(msg)
Expand Down Expand Up @@ -197,17 +212,22 @@
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
label_ids = set()
for ann in item.annotations:
# in h-cls scenario multilabel information stored in 'multi_label_ids' attribute
if "multi_label_ids" in ann.attributes:
for lbl_idx in ann.attributes["multi_label_ids"]:
label_ids.add(lbl_idx)

Codecov warning: added lines #L219-L220 in src/otx/core/data/dataset/classification.py were not covered by tests.

if isinstance(ann, Label):
label_anns.append(ann)
label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)
label_ids.add(label.label)

hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels)

entity = HlabelClsDataEntity(
image=img_data,
@@ -256,18 +276,18 @@
class_indices[i] = -1

for ann in label_anns:
ann_name = self.dm_categories.items[ann.label].name
ann_parent = self.dm_categories.items[ann.label].parent
if self.data_format == "arrow":
# skips unknown labels for instance, the empty one
if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping:
continue
ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name]

Codecov warning: added lines #L281-L283 in src/otx/core/data/dataset/classification.py were not covered by tests.
else:
ann_name = self.dm_categories.items[ann.label].name
group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
(parent_group_idx, parent_in_group_idx) = (
self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
)

if group_idx < num_multiclass_heads:
class_indices[group_idx] = in_group_idx
if parent_group_idx is not None and parent_in_group_idx is not None:
class_indices[parent_group_idx] = parent_in_group_idx
elif not ignored_labels or ann.label not in ignored_labels:
elif ann.label not in ignored_labels:

Codecov warning: added line #L290 in src/otx/core/data/dataset/classification.py was not covered by tests.
class_indices[num_multiclass_heads + in_group_idx] = 1
else:
class_indices[num_multiclass_heads + in_group_idx] = -1
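
The annotation handling in both _get_item_impl methods boils down to the same gathering step; below is a small illustrative helper (the name collect_label_ids is ours, not part of the PR), assuming annotations expose .label and an optional arrow-specific 'multi_label_ids' attribute as shown above:

# Sketch: collect the unique label indices of one Datumaro item (illustrative helper, not part of the PR).
def collect_label_ids(annotations) -> list[int]:
    label_ids: set[int] = set()
    for ann in annotations:
        # Arrow exports carry multilabel / h-label information in this attribute.
        for lbl_idx in ann.attributes.get("multi_label_ids", []):
            label_ids.add(lbl_idx)
        # Label annotations and chained-task annotations (e.g. Bbox) both expose .label.
        label_ids.add(ann.label)
    return sorted(label_ids)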
4 changes: 3 additions & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
@@ -54,9 +54,11 @@ def __init__(
self.dm_subset = self._get_single_bbox_dataset(dm_subset)

if self.dm_subset.categories():
kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
self.label_info = LabelInfo(
label_names=self.dm_subset.categories()[AnnotationType.points][0].labels,
label_names=kp_labels,
label_groups=[],
label_ids=[str(i) for i in range(len(kp_labels))],
)
else:
self.label_info = NullLabelInfo()
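
Because LabelInfo now carries explicit label_ids, the keypoint dataset synthesizes them from the point-label indices; equivalently (the label names below are placeholders):

from otx.core.types.label import LabelInfo  # assumed OTX 2.x module layout

kp_labels = ["nose", "left_eye", "right_eye"]  # placeholder keypoint names
label_info = LabelInfo(
    label_names=kp_labels,
    label_groups=[],
    label_ids=[str(i) for i in range(len(kp_labels))],  # stringified indices stand in for real IDs
)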
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
@@ -167,6 +167,7 @@ def __init__(
stack_images: bool = True,
to_tv_image: bool = True,
ignore_index: int = 255,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
@@ -187,6 +188,7 @@ def __init__(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
label_ids=self.label_info.label_ids,
)
self.ignore_index = ignore_index

2 changes: 2 additions & 0 deletions src/otx/core/data/factory.py
@@ -73,6 +73,7 @@ def create( # noqa: PLR0911
dm_subset: DmDataset,
cfg_subset: SubsetConfig,
mem_cache_handler: MemCacheHandlerBase,
data_format: str,
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
@@ -85,6 +86,7 @@
common_kwargs = {
"dm_subset": dm_subset,
"transforms": transforms,
"data_format": data_format,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": mem_cache_img_max_size,
"image_color_channel": image_color_channel,
10 changes: 3 additions & 7 deletions src/otx/core/data/module.py
@@ -107,13 +107,6 @@ def __init__( # noqa: PLR0913
self.subsets: dict[str, OTXDataset] = {}
self.save_hyperparameters(ignore=["input_size"])

# TODO (Jaeguk): This is workaround for a bug in Datumaro.
# These lines should be removed after next datumaro release.
# https://github.com/openvinotoolkit/datumaro/pull/1223/files
from datumaro.plugins.data_formats.video import VIDEO_EXTENSIONS

VIDEO_EXTENSIONS.append(".mp4")

dataset = DmDataset.import_from(self.data_root, format=self.data_format)
if self.task != "H_LABEL_CLS":
dataset = pre_filtering(
@@ -193,6 +186,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset.as_dataset(),
cfg_subset=config_mapping[name],
mem_cache_handler=mem_cache_handler,
data_format=self.data_format,
mem_cache_img_max_size=mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
@@ -237,6 +231,7 @@ def __init__( # noqa: PLR0913
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[transform_key] = unlabeled_dataset
else:
@@ -251,6 +246,7 @@ def __init__( # noqa: PLR0913
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[name] = unlabeled_dataset

2 changes: 1 addition & 1 deletion src/otx/core/data/pre_filtering.py
@@ -84,7 +84,7 @@ def remove_unused_labels(dataset: DmDataset, data_format: str, ignore_index: int
used_labels = [0, *used_labels]
if data_format == "common_semantic_segmentation_with_subset_dirs" and len(original_categories) < len(used_labels):
msg = (
"There are labeles mismatch in dataset categories and actuall categories comes from semantic masks."
"There are labels mismatch in dataset categories and actual categories comes from semantic masks."
"Please, check `dataset_meta.json` file."
)
raise ValueError(msg)