Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Nov 21, 2024
1 parent d2ff90d commit 2cf8026
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
to_tv_image: bool = True,
data_format="",
) -> None:
super().__init__(
dm_subset,
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format="",
) -> None:
self.task_type = task_type
super().__init__(
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
stack_images: bool = True,
to_tv_image: bool = True,
ignore_index: int = 255,
data_format="",
) -> None:
super().__init__(
dm_subset,
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def fxt_seg_label_info() -> SegLabelInfo:
label_names,
["class2", "class3"],
],
label_ids=["0", "1", "2"],
)


Expand All @@ -382,6 +383,7 @@ def fxt_multiclass_labelinfo() -> LabelInfo:
label_names,
["class2", "class3"],
],
label_ids=["0", "1", "2"],
)


Expand All @@ -395,6 +397,7 @@ def fxt_multilabel_labelinfo() -> LabelInfo:
[label_names[1]],
[label_names[2]],
],
label_ids=["0", "1", "2"],
)


Expand Down Expand Up @@ -461,6 +464,7 @@ def fxt_hlabel_multilabel_info() -> HLabelInfo:
["Spade_A", "Spade"],
["Spade_King", "Spade"],
],
label_ids=[str(i) for i in range(9)],
)


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/algo/classification/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def fxt_hlabel_data() -> HLabelInfo:
["Heart_Queen", "Heart_King"],
["Spade_A", "Spade_King"],
],
label_ids=[str(i) for i in range(6)],
num_multiclass_heads=3,
num_multilabel_classes=0,
head_idx_to_logits_range={"0": (0, 2), "1": (2, 4), "2": (4, 6)},
Expand Down Expand Up @@ -149,6 +150,7 @@ def fxt_hlabel_cifar() -> HLabelInfo:
"aquatic_mammals",
"fish",
],
label_ids=[str(i) for i in range(12)],
label_groups=[
["beaver", "dolphin", "otter", "seal", "whale"],
["aquarium_fish", "flatfish", "ray", "shark", "trout"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def batch(self):
inputs = torch.randn(1, 3, 32, 32)
return DetBatchDataEntity(
batch_size=1,
imgs_info=[LabelInfo(["a"], [["a"]])],
imgs_info=[LabelInfo(["a"], ["0"], [["a"]])],
images=inputs,
bboxes=[torch.tensor([[0.5, 0.5, 0.5, 0.5]])],
labels=[torch.tensor([0])],
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/test_rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class TestRTDETR:
def test_customize_outputs(self, mocker):
label_info = LabelInfo(["a", "b", "c"], [["a", "b", "c"]])
label_info = LabelInfo(["a", "b", "c"], ["0", "1", "2"], [["a", "b", "c"]])
mocker.patch("otx.algo.detection.rtdetr.RTDETR._build_model", return_value=mocker.MagicMock())
model = RTDETR(label_info)
model.model.load_from = None
Expand Down
1 change: 1 addition & 0 deletions tests/unit/core/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def fxt_mock_hlabelinfo():
return HLabelInfo(
label_names=_LABEL_NAMES,
label_groups=[["Non-Rigid", "Rigid"], ["Rectangle", "Triangle"], ["Circle"], ["Lion"], ["Panda"]],
label_ids=_LABEL_NAMES,
num_multiclass_heads=2,
num_multilabel_classes=3,
head_idx_to_logits_range={"0": (0, 2), "1": (2, 4)},
Expand Down
1 change: 1 addition & 0 deletions tests/unit/core/data/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_create(
cfg_subset=cfg_subset,
vpm_config=vpm_config,
image_color_channel=image_color_channel,
data_format="",
),
dataset_cls,
)
13 changes: 11 additions & 2 deletions tests/unit/core/model/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def label_info(self):
return SegLabelInfo(
label_names=["Background", "label_0", "label_1"],
label_groups=[["Background", "label_0", "label_1"]],
label_ids=["0", "1", "2"]
)

@pytest.fixture()
Expand Down Expand Up @@ -64,8 +65,16 @@ def test_export_parameters(self, model):
("label_info", "expected_label_info"),
[
(
SegLabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
SegLabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
SegLabelInfo(
label_names=["label1", "label2", "label3"],
label_groups=[["label1", "label2", "label3"]],
label_ids=["0", "1", "2"],
),
SegLabelInfo(
label_names=["label1", "label2", "label3"],
label_groups=[["label1", "label2", "label3"]],
label_ids=["0", "1", "2"],
),
),
(SegLabelInfo.from_num_classes(num_classes=5), SegLabelInfo.from_num_classes(num_classes=5)),
],
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/engine/utils/test_auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def test_get_model(self, fxt_task: OTXTaskType) -> None:
# With label_info
label_names = ["class1", "class2", "class3"]
label_info = (
LabelInfo(label_names=label_names, label_groups=[label_names])
LabelInfo(label_names=label_names, label_groups=[label_names], label_ids=label_names)
if fxt_task != OTXTaskType.SEMANTIC_SEGMENTATION
else SegLabelInfo(label_names=label_names, label_groups=[label_names])
else SegLabelInfo(label_names=label_names, label_groups=[label_names], label_ids=label_names)
)
model = auto_configurator.get_model(label_info=label_info)
assert isinstance(model, OTXModel)
Expand All @@ -147,7 +147,7 @@ def test_get_model(self, fxt_task: OTXTaskType) -> None:
def test_get_model_set_input_size(self) -> None:
auto_configurator = AutoConfigurator(task=OTXTaskType.MULTI_CLASS_CLS)
label_names = ["class1", "class2", "class3"]
label_info = LabelInfo(label_names=label_names, label_groups=[label_names])
label_info = LabelInfo(label_names=label_names, label_groups=[label_names], label_ids=label_names)
input_size = 300

model = auto_configurator.get_model(label_info=label_info, input_size=input_size)
Expand Down

0 comments on commit 2cf8026

Please sign in to comment.