From e51a4033a3e0795faf4e81197825500482d9aeb4 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Wed, 18 Dec 2024 23:21:55 +0900 Subject: [PATCH] update tests --- src/otx/algo/common/losses/cross_focal_loss.py | 2 +- tests/unit/core/model/test_detection_3d.py | 5 +++-- tests/unit/core/model/test_keypoint_detection.py | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/otx/algo/common/losses/cross_focal_loss.py b/src/otx/algo/common/losses/cross_focal_loss.py index e6311dd0ae0..7744a9e5117 100644 --- a/src/otx/algo/common/losses/cross_focal_loss.py +++ b/src/otx/algo/common/losses/cross_focal_loss.py @@ -9,7 +9,7 @@ import torch.nn.functional from otx.utils.device import get_available_device from torch import Tensor, nn -from torch.amp import custom_fwd +from torch.cuda.amp import custom_fwd from .focal_loss import py_sigmoid_focal_loss diff --git a/tests/unit/core/model/test_detection_3d.py b/tests/unit/core/model/test_detection_3d.py index f46dc212b8d..673133a4562 100644 --- a/tests/unit/core/model/test_detection_3d.py +++ b/tests/unit/core/model/test_detection_3d.py @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo: return LabelInfo( label_names=["label_0", "label_1"], label_groups=[["label_0", "label_1"]], + label_ids=["0", "1"], ) @pytest.fixture() @@ -61,8 +62,8 @@ def test_export_parameters(self, model): ("label_info", "expected_label_info"), [ ( - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), ), (LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)), ], diff --git a/tests/unit/core/model/test_keypoint_detection.py b/tests/unit/core/model/test_keypoint_detection.py index d3cc06fede7..77e19d605ea 100644 --- a/tests/unit/core/model/test_keypoint_detection.py +++ b/tests/unit/core/model/test_keypoint_detection.py @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo: return LabelInfo( label_names=["label_0", "label_1"], label_groups=[["label_0", "label_1"]], + label_ids=["0", "1"], ) @pytest.fixture() @@ -61,8 +62,8 @@ def test_export_parameters(self, model): ("label_info", "expected_label_info"), [ ( - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), ), (LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)), ],