diff --git a/src/otx/algo/common/losses/cross_focal_loss.py b/src/otx/algo/common/losses/cross_focal_loss.py index e6311dd0ae..7744a9e511 100644 --- a/src/otx/algo/common/losses/cross_focal_loss.py +++ b/src/otx/algo/common/losses/cross_focal_loss.py @@ -9,7 +9,7 @@ import torch.nn.functional from otx.utils.device import get_available_device from torch import Tensor, nn -from torch.amp import custom_fwd +from torch.cuda.amp import custom_fwd from .focal_loss import py_sigmoid_focal_loss diff --git a/tests/unit/core/model/test_detection_3d.py b/tests/unit/core/model/test_detection_3d.py index f46dc212b8..673133a456 100644 --- a/tests/unit/core/model/test_detection_3d.py +++ b/tests/unit/core/model/test_detection_3d.py @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo: return LabelInfo( label_names=["label_0", "label_1"], label_groups=[["label_0", "label_1"]], + label_ids=["0", "1"], ) @pytest.fixture() @@ -61,8 +62,8 @@ def test_export_parameters(self, model): ("label_info", "expected_label_info"), [ ( - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), ), (LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)), ], diff --git a/tests/unit/core/model/test_keypoint_detection.py b/tests/unit/core/model/test_keypoint_detection.py index d3cc06fede..77e19d605e 100644 --- a/tests/unit/core/model/test_keypoint_detection.py +++ b/tests/unit/core/model/test_keypoint_detection.py @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo: return LabelInfo( label_names=["label_0", "label_1"], label_groups=[["label_0", "label_1"]], + label_ids=["0", "1"], ) @pytest.fixture() @@ -61,8 +62,8 @@ def test_export_parameters(self, model): ("label_info", "expected_label_info"), [ ( - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), - LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), + LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]), ), (LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)), ],