Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi committed Dec 18, 2024
1 parent dd56135 commit e51a403
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/otx/algo/common/losses/cross_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn.functional
from otx.utils.device import get_available_device
from torch import Tensor, nn
from torch.amp import custom_fwd
from torch.cuda.amp import custom_fwd

from .focal_loss import py_sigmoid_focal_loss

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/core/model/test_detection_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo:
return LabelInfo(
label_names=["label_0", "label_1"],
label_groups=[["label_0", "label_1"]],
label_ids=["0", "1"],
)

@pytest.fixture()
Expand Down Expand Up @@ -61,8 +62,8 @@ def test_export_parameters(self, model):
("label_info", "expected_label_info"),
[
(
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]),
),
(LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)),
],
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/core/model/test_keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def label_info(self) -> LabelInfo:
return LabelInfo(
label_names=["label_0", "label_1"],
label_groups=[["label_0", "label_1"]],
label_ids=["0", "1"],
)

@pytest.fixture()
Expand Down Expand Up @@ -61,8 +62,8 @@ def test_export_parameters(self, model):
("label_info", "expected_label_info"),
[
(
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]),
LabelInfo(label_names=["label1", "label2", "label3"], label_groups=[["label1", "label2", "label3"]], label_ids=["0", "1", "2"]),
),
(LabelInfo.from_num_classes(num_classes=5), LabelInfo.from_num_classes(num_classes=5)),
],
Expand Down

0 comments on commit e51a403

Please sign in to comment.