diff --git a/tests/unit/algo/classification/conftest.py b/tests/unit/algo/classification/conftest.py index bf953c9bc4..a283eff41b 100644 --- a/tests/unit/algo/classification/conftest.py +++ b/tests/unit/algo/classification/conftest.py @@ -81,6 +81,7 @@ def fxt_hlabel_multilabel_info() -> HLabelInfo: "Red_Joker", "Extra_Joker", ], + label_ids=[str(i) for i in range(9)], label_groups=[ ["Heart", "Spade"], ["Heart_Queen", "Heart_King"], diff --git a/tests/unit/core/metrics/test_accuracy.py b/tests/unit/core/metrics/test_accuracy.py index d3c43a8a08..73486330a3 100644 --- a/tests/unit/core/metrics/test_accuracy.py +++ b/tests/unit/core/metrics/test_accuracy.py @@ -52,7 +52,7 @@ def test_default_multi_class_cls_metric_callable(self, fxt_multiclass_labelinfo: metric = MultiClassClsMetricCallable(fxt_multiclass_labelinfo) assert isinstance(metric.accuracy, MulticlassAccuracy) - one_class_label_info = LabelInfo(label_names=["class1"], label_groups=[["class1"]]) + one_class_label_info = LabelInfo(label_names=["class1"], label_groups=[["class1"]], label_ids=["0"]) assert one_class_label_info.num_classes == 1 binary_metric = MultiClassClsMetricCallable(one_class_label_info) assert isinstance(binary_metric.accuracy, BinaryAccuracy) diff --git a/tests/unit/core/types/test_label.py b/tests/unit/core/types/test_label.py index 3ae1ae1f46..8b853e4c4c 100644 --- a/tests/unit/core/types/test_label.py +++ b/tests/unit/core/types/test_label.py @@ -17,9 +17,10 @@ def test_seg_label_info(): # Automatically insert background label at zero index assert SegLabelInfo.from_num_classes(3) == SegLabelInfo( ["label_0", "label_1", "label_2"], + ["0", "1", "2"], [["label_0", "label_1", "label_2"]], ) - assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], [["background", "label_0"]]) + assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], ["0", "1"], [["background", "label_0"]]) assert SegLabelInfo.from_num_classes(0) == NullLabelInfo()