Merge branch 'releases/2.2.0' into kp/update_converter
kprokofi authored Dec 4, 2024
2 parents 6d29193 + c6e2952 commit cd14c7b
Showing 32 changed files with 282 additions and 84 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -130,6 +130,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4124>)
- Fix patching early stopping in tools/converter.py, update headers in templates, change training schedule for classification
(<https://github.com/openvinotoolkit/training_extensions/pull/4131>)
+ - Fix tensor type compatibility in dynamic soft label assigner and RTMDet head
+ (<https://github.com/openvinotoolkit/training_extensions/pull/4140>)

## \[v2.1.0\]

9 changes: 3 additions & 6 deletions src/otx/algo/classification/efficientnet.py
@@ -14,7 +14,7 @@
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
- HierarchicalCBAMClsHead,
+ HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:

return HLabelClassifier(
backbone=backbone,
- neck=nn.Identity(),
- head=HierarchicalCBAMClsHead(
- in_channels=backbone.num_features,
- **copied_head_config,
- ),
+ neck=GlobalAveragePooling(dim=2),
+ head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
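For readers skimming the hunk: the swap pairs a pooling neck with a linear hierarchical head. The removed CBAM head apparently operated on the raw spatial feature map (which is why it sat behind an `nn.Identity()` neck and took a resolution-derived `step_size`), while `HierarchicalLinearClsHead` expects pooled per-image vectors, so `GlobalAveragePooling(dim=2)` is inserted in front of it. A minimal sketch of the assumed shape flow — the tensor sizes and the 42-class stand-in head are illustrative, not taken from the repo:

```python
import torch
import torch.nn as nn

# Hypothetical EfficientNet-style feature map: (N, C, H, W).
feats = torch.randn(8, 1280, 7, 7)

# GlobalAveragePooling(dim=2) collapses the two spatial dimensions, leaving (N, C)
# vectors -- exactly what a linear classification head consumes.
pooled = feats.mean(dim=(2, 3))  # shape: (8, 1280)

# Stand-in for HierarchicalLinearClsHead: one linear projection to logits.
head = nn.Linear(1280, 42)
logits = head(pooled)            # shape: (8, 42)
```

The same neck/head swap recurs in the MobileNetV3, timm, and torchvision models below; only the source of `in_channels` differs.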
10 changes: 4 additions & 6 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
- HierarchicalCBAMClsHead,
+ HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
@@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
+ in_channels = 960 if self.mode == "large" else 576

return HLabelClassifier(
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
- neck=nn.Identity(),
- head=HierarchicalCBAMClsHead(
- in_channels=960,
- **copied_head_config,
- ),
+ neck=GlobalAveragePooling(dim=2),
+ head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/timm_model.py
@@ -15,12 +15,12 @@
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
- HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
+ from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
- neck=nn.Identity(),
- head=HierarchicalCBAMClsHead(
- in_channels=backbone.num_features,
- **copied_head_config,
- ),
+ neck=GlobalAveragePooling(dim=2),
+ head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/torchvision_model.py
@@ -14,12 +14,12 @@
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
- HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
+ from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.core.data.entity.classification import (
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
return HLabelClassifier(
backbone=backbone,
- neck=nn.Identity(),
- head=HierarchicalCBAMClsHead(
- in_channels=backbone.in_features,
- **head_config,
- ),
+ neck=GlobalAveragePooling(dim=2),
+ head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
8 changes: 2 additions & 6 deletions src/otx/algo/classification/vit.py
@@ -19,12 +19,12 @@
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
- HierarchicalCBAMClsHead,
MultiLabelLinearClsHead,
SemiSLVisionTransformerClsHead,
VisionTransformerClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
+ from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.utils import get_classification_layers
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
return HLabelClassifier(
backbone=vit_backbone,
neck=None,
- head=HierarchicalCBAMClsHead(
- in_channels=vit_backbone.embed_dim,
- step_size=1,
- **head_config,
- ),
+ head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
init_cfg=init_cfg,
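The ViT variant keeps `neck=None`: the backbone presumably already emits a per-image embedding, so no pooling neck is needed and the CBAM-specific `step_size=1` workaround can simply be dropped. A tiny sketch under that assumption (the 768-dim embedding and 42 classes are illustrative):

```python
import torch
import torch.nn as nn

cls_embedding = torch.randn(8, 768)  # (N, embed_dim), as assumed to come from the ViT backbone
head = nn.Linear(768, 42)            # stand-in for HierarchicalLinearClsHead
logits = head(cls_embedding)         # no pooling neck or step_size involved
```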
2 changes: 1 addition & 1 deletion — dynamic soft label assigner (file path not captured in this view)
@@ -196,7 +196,7 @@ def assign(
assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
max_overlaps = assigned_gt_inds.new_full((num_bboxes,), -INF, dtype=torch.float32)
- max_overlaps[valid_mask] = matched_pred_ious
+ max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps)
return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

def dynamic_k_matching(
2 changes: 1 addition & 1 deletion src/otx/algo/detection/heads/rtmdet_head.py
@@ -574,7 +574,7 @@ def _get_targets_single( # type: ignore[override]
if len(pos_inds) > 0:
# point-based
pos_bbox_targets = sampling_result.pos_gt_bboxes
- bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_targets[pos_inds, :] = pos_bbox_targets.to(bbox_targets)

labels[pos_inds] = sampling_result.pos_gt_labels
if self.train_cfg["pos_weight"] <= 0:
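Both one-line fixes in this commit (here and in the dynamic soft label assigner above) rely on the same idiom: `Tensor.to(other)` casts to `other`'s dtype and device in a single call, which avoids the error that indexed assignment raises when source and destination dtypes differ (e.g. half-precision predictions written into a float32 buffer). A small illustrative sketch, with invented shapes and dtypes:

```python
import torch

max_overlaps = torch.full((6,), float("-inf"), dtype=torch.float32)
valid_mask = torch.tensor([True, False, True, True, False, False])
matched_pred_ious = torch.rand(3, dtype=torch.float16)  # e.g. produced under mixed precision

# Without the cast, this indexed assignment fails with a dtype-mismatch error;
# .to(max_overlaps) aligns both dtype and device with the destination buffer.
max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps)
```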
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/action_classification.py
@@ -37,6 +37,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
to_tv_image: bool = True,
+ data_format: str = "",
) -> None:
super().__init__(
dm_subset,
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/anomaly.py
@@ -57,6 +57,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
+ data_format: str = "",
) -> None:
self.task_type = task_type
super().__init__(
7 changes: 6 additions & 1 deletion src/otx/core/data/dataset/base.py
@@ -70,6 +70,7 @@ class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
max_refetch: Maximum number of images to fetch in cache
image_color_channel: Color channel of images
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
+ data_format: Source data format, which was originally passed to datumaro (could be arrow for instance).
"""

@@ -83,6 +84,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
+ data_format: str = "",
) -> None:
self.dm_subset = dm_subset
self.transforms = transforms
@@ -92,8 +94,11 @@
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image
+ self.data_format = data_format

- if self.dm_subset.categories():
+ if self.dm_subset.categories() and data_format == "arrow":
+ self.label_info = LabelInfo.from_dm_label_groups_arrow(self.dm_subset.categories()[AnnotationType.label])
+ elif self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
self.label_info = NullLabelInfo()
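In short, the dataset now remembers where its annotations came from, and the arrow branch exists because datumaro's arrow exports store label IDs where other formats store label names. A condensed sketch of the selection logic added above — the helper name and import locations are assumptions made for illustration, not verified against the repo:

```python
from datumaro.components.annotation import AnnotationType

from otx.core.types.label import LabelInfo, NullLabelInfo  # assumed import path


def resolve_label_info(dm_subset, data_format: str):
    """Mirror of the branch above: pick the right LabelInfo parser for the source format."""
    categories = dm_subset.categories()
    if categories and data_format == "arrow":
        # Arrow exports carry label IDs in place of names, so a dedicated parser is used.
        return LabelInfo.from_dm_label_groups_arrow(categories[AnnotationType.label])
    if categories:
        return LabelInfo.from_dm_label_groups(categories[AnnotationType.label])
    return NullLabelInfo()
```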
64 changes: 42 additions & 22 deletions src/otx/core/data/dataset/classification.py
@@ -80,17 +80,21 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

- label_anns = []
+ label_ids = set()
for ann in item.annotations:
+ # multilabel information stored in 'multi_label_ids' attribute when the source format is arrow
+ if "multi_label_ids" in ann.attributes:
+ for lbl_idx in ann.attributes["multi_label_ids"]:
+ label_ids.add(lbl_idx)
+
if isinstance(ann, Label):
- label_anns.append(ann)
+ label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
- if label not in label_anns:
- label_anns.append(label)
- labels = torch.as_tensor([ann.label for ann in label_anns])
+ label_ids.add(label.label)
+ labels = torch.as_tensor(list(label_ids))

entity = MultilabelClsDataEntity(
image=img_data,
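Replacing the list of `Label` annotations with a plain set of label IDs makes de-duplication automatic and lets the arrow-specific `multi_label_ids` attribute feed the same collection. A toy illustration of the collection step (plain dicts stand in for datumaro annotations, whose real objects carry more fields):

```python
# Hypothetical annotations as an arrow-backed item might provide them.
annotations = [
    {"label": 3, "attributes": {"multi_label_ids": [3, 7]}},
    {"label": 7, "attributes": {}},  # ID 7 was already collected above
]

label_ids: set[int] = set()
for ann in annotations:
    # arrow packs all multilabel targets into one attribute ...
    for lbl_idx in ann["attributes"].get("multi_label_ids", []):
        label_ids.add(lbl_idx)
    # ... while each annotation still contributes its own label index.
    label_ids.add(ann["label"])

print(sorted(label_ids))  # [3, 7] -- duplicates collapse for free
```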
@@ -128,13 +132,22 @@ def __init__(self, **kwargs) -> None:
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]

# Hlabel classification used HLabelInfo to insert the HLabelData.
- self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
+ if self.data_format == "arrow":
+ # arrow format stores label IDs as names, have to deal with that here
+ self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories)
+ else:
+ self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
+
+ self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names))
+ self.id_to_name_mapping[""] = ""
+
if self.label_info.num_multiclass_heads == 0:
msg = "The number of multiclass heads should be larger than 0."
raise ValueError(msg)

- for dm_item in self.dm_subset:
- self._add_ancestors(dm_item.annotations)
+ if self.data_format != "arrow":
+ for dm_item in self.dm_subset:
+ self._add_ancestors(dm_item.annotations)

def _add_ancestors(self, label_anns: list[Label]) -> None:
"""Add ancestors recursively if some label miss the ancestor information.
@@ -149,14 +162,16 @@ def _add_ancestors(self, label_anns: list[Label]) -> None:
"""

def _label_idx_to_name(idx: int) -> str:
- return self.label_info.label_names[idx]
+ return self.dm_categories[idx].name

def _label_name_to_idx(name: str) -> int:
indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name]
return indices[0]

def _get_label_group_idx(label_name: str) -> int:
if isinstance(self.label_info, HLabelInfo):
+ if self.data_format == "arrow":
+ return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0]
return self.label_info.class_to_group_idx[label_name][0]
msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}"
raise ValueError(msg)
@@ -197,17 +212,22 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

- label_anns = []
+ label_ids = set()
for ann in item.annotations:
+ # in h-cls scenario multilabel information stored in 'multi_label_ids' attribute
+ if "multi_label_ids" in ann.attributes:
+ for lbl_idx in ann.attributes["multi_label_ids"]:
+ label_ids.add(lbl_idx)
+
if isinstance(ann, Label):
- label_anns.append(ann)
+ label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
- if label not in label_anns:
- label_anns.append(label)
- hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)
+ label_ids.add(label.label)
+
+ hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels)

entity = HlabelClsDataEntity(
image=img_data,
@@ -256,18 +276,18 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
class_indices[i] = -1

for ann in label_anns:
- ann_name = self.dm_categories.items[ann.label].name
- ann_parent = self.dm_categories.items[ann.label].parent
+ if self.data_format == "arrow":
+ # skips unknown labels for instance, the empty one
+ if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping:
+ continue
+ ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name]
+ else:
+ ann_name = self.dm_categories.items[ann.label].name
group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
- (parent_group_idx, parent_in_group_idx) = (
- self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
- )

if group_idx < num_multiclass_heads:
class_indices[group_idx] = in_group_idx
- if parent_group_idx is not None and parent_in_group_idx is not None:
- class_indices[parent_group_idx] = parent_in_group_idx
- elif not ignored_labels or ann.label not in ignored_labels:
+ elif ann.label not in ignored_labels:
class_indices[num_multiclass_heads + in_group_idx] = 1
else:
class_indices[num_multiclass_heads + in_group_idx] = -1
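Because the label "names" arriving from arrow are really IDs, the head-group lookup now routes through `id_to_name_mapping` first and skips anything the mapping does not know (for instance an injected empty label). A toy rendering of that translation — the mapping contents and group layout are invented:

```python
id_to_name_mapping = {"0": "vehicle", "1": "car", "2": "truck"}           # arrow ID -> class name
class_to_group_idx = {"vehicle": (0, 0), "car": (0, 1), "truck": (0, 2)}  # name -> (group, index)


def resolve(ann_name: str, data_format: str) -> tuple[int, int] | None:
    if data_format == "arrow":
        if ann_name not in id_to_name_mapping:  # unknown/empty label: caller skips it
            return None
        ann_name = id_to_name_mapping[ann_name]
    return class_to_group_idx[ann_name]


print(resolve("1", "arrow"))   # (0, 1)
print(resolve("car", "coco"))  # (0, 1)
```

The removed lines also drop the per-annotation parent lookup, presumably redundant now that non-arrow data gets explicit ancestor labels via `_add_ancestors` at init time (see the `__init__` hunk above).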
4 changes: 3 additions & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
@@ -54,9 +54,11 @@ def __init__(
self.dm_subset = self._get_single_bbox_dataset(dm_subset)

if self.dm_subset.categories():
+ kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
self.label_info = LabelInfo(
- label_names=self.dm_subset.categories()[AnnotationType.points][0].labels,
+ label_names=kp_labels,
label_groups=[],
+ label_ids=[str(i) for i in range(len(kp_labels))],
)
else:
self.label_info = NullLabelInfo()
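`LabelInfo` now gains a `label_ids` field here (the segmentation dataset below receives the same addition), and datumaro's keypoint categories apparently expose only names, so positional string IDs are generated. Toy illustration with invented label names:

```python
kp_labels = ["nose", "left_eye", "right_eye"]        # names from the points category
label_ids = [str(i) for i in range(len(kp_labels))]  # synthetic positional IDs
print(label_ids)                                     # ['0', '1', '2']
```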
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
@@ -167,6 +167,7 @@ def __init__(
stack_images: bool = True,
to_tv_image: bool = True,
ignore_index: int = 255,
+ data_format: str = "",
) -> None:
super().__init__(
dm_subset,
@@ -187,6 +188,7 @@ def __init__(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
+ label_ids=self.label_info.label_ids,
)
self.ignore_index = ignore_index

2 changes: 2 additions & 0 deletions src/otx/core/data/factory.py
@@ -73,6 +73,7 @@ def create( # noqa: PLR0911
dm_subset: DmDataset,
cfg_subset: SubsetConfig,
mem_cache_handler: MemCacheHandlerBase,
+ data_format: str,
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
@@ -85,6 +86,7 @@
common_kwargs = {
"dm_subset": dm_subset,
"transforms": transforms,
"data_format": data_format,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": mem_cache_img_max_size,
"image_color_channel": image_color_channel,