From dff6ecc72cfd13c84378808871e14f09d9a2d4c5 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Fri, 20 Dec 2024 15:25:33 +0000 Subject: [PATCH] Update d-fine unit test --- tests/unit/algo/detection/test_dfine.py | 49 +++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/unit/algo/detection/test_dfine.py b/tests/unit/algo/detection/test_dfine.py index 6a48014643..91b26dd01c 100644 --- a/tests/unit/algo/detection/test_dfine.py +++ b/tests/unit/algo/detection/test_dfine.py @@ -9,13 +9,62 @@ import torch import torchvision from otx.algo.detection.backbones.hgnetv2 import HGNetv2 +from otx.algo.detection.d_fine import DFine from otx.algo.detection.heads.dfine_decoder import DFINETransformer from otx.algo.detection.losses.dfine_loss import DFINECriterion from otx.algo.detection.necks.dfine_hybrid_encoder import HybridEncoder from otx.algo.detection.rtdetr import DETR +from otx.core.data.entity.detection import DetBatchPredEntity class TestDFine: + @pytest.mark.parametrize( + "model", + [ + DFine(label_info=3, model_name="dfine_hgnetv2_n"), + DFine(label_info=3, model_name="dfine_hgnetv2_s"), + DFine(label_info=3, model_name="dfine_hgnetv2_m"), + DFine(label_info=3, model_name="dfine_hgnetv2_l"), + DFine(label_info=3, model_name="dfine_hgnetv2_x"), + ], + ) + def test_loss(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = torch.randn([2, 3, 640, 640]) + model(data) + + @pytest.mark.parametrize( + "model", + [ + DFine(label_info=3, model_name="dfine_hgnetv2_n"), + DFine(label_info=3, model_name="dfine_hgnetv2_s"), + DFine(label_info=3, model_name="dfine_hgnetv2_m"), + DFine(label_info=3, model_name="dfine_hgnetv2_l"), + DFine(label_info=3, model_name="dfine_hgnetv2_x"), + ], + ) + def test_predict(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = torch.randn(2, 3, 640, 640) + model.eval() + output = model(data) + assert isinstance(output, DetBatchPredEntity) + + @pytest.mark.parametrize( + "model", + [ + DFine(label_info=3, model_name="dfine_hgnetv2_n"), + DFine(label_info=3, model_name="dfine_hgnetv2_s"), + DFine(label_info=3, model_name="dfine_hgnetv2_m"), + DFine(label_info=3, model_name="dfine_hgnetv2_l"), + DFine(label_info=3, model_name="dfine_hgnetv2_x"), + ], + ) + def test_export(self, model): + model.eval() + output = model.forward_for_tracing(torch.randn(1, 3, 640, 640)) + assert len(output) == 3 + @pytest.fixture() def dfine_model(self): num_classes = 10