From 133e507faa0c9999ff731039f0b26f73924308a8 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 27 May 2024 11:57:22 +0300
Subject: [PATCH] Make YoloNASRDFLHead inherit a base YoloNASDFLHead

---
 .../detection_models/yolo_nas/dfl_heads.py    |  1 -
 .../yolo_nas_r/yolo_nas_r_dfl_head.py         | 67 ++++++------------
 .../unit_tests/export_detection_model_test.py | 13 ++++
 3 files changed, 34 insertions(+), 47 deletions(-)

diff --git a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py
index e8bba7e4a1..e15f00bf45 100644
--- a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py
+++ b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py
@@ -68,7 +68,6 @@ def __init__(
         self.cls_dropout_rate = nn.Dropout2d(cls_dropout_rate) if cls_dropout_rate > 0 else nn.Identity()
         self.reg_dropout_rate = nn.Dropout2d(reg_dropout_rate) if reg_dropout_rate > 0 else nn.Identity()
 
-        self.grid = torch.zeros(1)
         self.stride = stride
 
         self.prior_prob = 1e-2
diff --git a/src/super_gradients/training/models/detection_models/yolo_nas_r/yolo_nas_r_dfl_head.py b/src/super_gradients/training/models/detection_models/yolo_nas_r/yolo_nas_r_dfl_head.py
index 857c059bf2..fadef77229 100644
--- a/src/super_gradients/training/models/detection_models/yolo_nas_r/yolo_nas_r_dfl_head.py
+++ b/src/super_gradients/training/models/detection_models/yolo_nas_r/yolo_nas_r_dfl_head.py
@@ -1,18 +1,18 @@
-import math
-from typing import Tuple, Callable
+from typing import Tuple
 
 import torch
-from torch import nn, Tensor
-
 from super_gradients.common.registry import register_detection_module
-from super_gradients.module_interfaces import SupportsReplaceNumClasses
-from super_gradients.modules import ConvBNReLU
-from super_gradients.modules.base_modules import BaseDetectionModule
 from super_gradients.modules.utils import width_multiplier
+from super_gradients.training.models.detection_models.yolo_nas import YoloNASDFLHead
+from torch import nn, Tensor
 
 
 @register_detection_module()
-class YoloNASRDFLHead(BaseDetectionModule, SupportsReplaceNumClasses):
+class YoloNASRDFLHead(YoloNASDFLHead):
+    """
+    YoloNASRDFLHead is a YoloNASDFLHead with additional outputs for rotated bounding boxes.
+    """
+
     def __init__(
         self,
         in_channels: int,
@@ -37,41 +37,24 @@ def __init__(
         :param cls_dropout_rate: Dropout rate for the classification head
         :param reg_dropout_rate: Dropout rate for the regression head
         """
-        super().__init__(in_channels)
-
+        super().__init__(
+            in_channels=in_channels,
+            inter_channels=inter_channels,
+            width_mult=width_mult,
+            first_conv_group_size=first_conv_group_size,
+            num_classes=num_classes,
+            stride=stride,
+            reg_max=reg_max,
+            cls_dropout_rate=cls_dropout_rate,
+            reg_dropout_rate=reg_dropout_rate,
+        )
         inter_channels = width_multiplier(inter_channels, width_mult, 8)
 
-        if first_conv_group_size == 0:
-            groups = 0
-        elif first_conv_group_size == -1:
-            groups = 1
-        else:
-            groups = inter_channels // first_conv_group_size
-
-        self.num_classes = num_classes
-        self.stem = ConvBNReLU(in_channels, inter_channels, kernel_size=1, stride=1, padding=0, bias=False)
-
-        first_cls_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
-        self.cls_convs = nn.Sequential(*first_cls_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))
-
-        first_reg_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
-        self.reg_convs = nn.Sequential(*first_reg_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))
-
-        self.cls_pred = nn.Conv2d(inter_channels, self.num_classes, 1, 1, 0)
         self.reg_pred = nn.Conv2d(inter_channels, 2 * (reg_max + 1), 1, 1, 0)
         self.rot_pred = nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0)
         self.offset_pred = nn.Conv2d(inter_channels, 2, kernel_size=1, stride=1, padding=0)
-
-        self.cls_dropout_rate = nn.Dropout2d(cls_dropout_rate) if cls_dropout_rate > 0 else nn.Identity()
-        self.reg_dropout_rate = nn.Dropout2d(reg_dropout_rate) if reg_dropout_rate > 0 else nn.Identity()
-
-        self.stride = stride
-
-        self.prior_prob = 1e-2
-        self._initialize_biases()
-
-    def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
-        self.cls_pred = compute_new_weights_fn(self.cls_pred, num_classes)
-        self.num_classes = num_classes
+        torch.nn.init.zeros_(self.offset_pred.weight)
+        torch.nn.init.zeros_(self.offset_pred.bias)
 
     @property
     def out_channels(self):
@@ -100,11 +83,3 @@ def forward(self, x) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         offset_output = self.offset_pred(reg_feat)
 
         return reg_output, cls_output, offset_output, rot_output
-
-    def _initialize_biases(self):
-        prior_bias = -math.log((1 - self.prior_prob) / self.prior_prob)
-        torch.nn.init.zeros_(self.cls_pred.weight)
-        torch.nn.init.constant_(self.cls_pred.bias, prior_bias)
-
-        torch.nn.init.zeros_(self.offset_pred.weight)
-        torch.nn.init.zeros_(self.offset_pred.bias)
diff --git a/tests/unit_tests/export_detection_model_test.py b/tests/unit_tests/export_detection_model_test.py
index fdc12695ed..ec8fdd7fc0 100644
--- a/tests/unit_tests/export_detection_model_test.py
+++ b/tests/unit_tests/export_detection_model_test.py
@@ -785,6 +785,19 @@ def test_onnx_nms_batch_result(self):
         np.testing.assert_allclose(torch_result[2].numpy(), onnx_result[2], rtol=1e-3, atol=1e-3)
         np.testing.assert_allclose(torch_result[3].numpy(), onnx_result[3], rtol=1e-3, atol=1e-3)
 
+    def test_yolo_nas_r_export(self):
+        """
+        Test the most common export use case - export to ONNX with all default parameters
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            out_path = os.path.join(tmpdirname, "yolo_nas_r_s.onnx")
+
+            model = models.get(Models.YOLO_NAS_R_S, pretrained_weights="dota2")
+            result = model.export(out_path)
+            assert result.input_image_dtype == torch.uint8
+            assert result.input_image_shape == (1024, 1024)
+            assert result.input_image_channels == 3
+
     def _get_image_as_bchw(self, image_shape=(640, 640)):
         """
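
Usage sketch (not part of the patch): a minimal example of the export flow exercised by the new
test_yolo_nas_r_export test, assuming the Models enum is importable from
super_gradients.common.object_names (the test file's imports are not shown in this diff) and that
the "dota2" pretrained weights can be downloaded:

    from super_gradients.common.object_names import Models  # assumed import path for the Models enum
    from super_gradients.training import models

    # Load the small YoloNAS-R variant with the "dota2" pretrained weights used in the test
    model = models.get(Models.YOLO_NAS_R_S, pretrained_weights="dota2")

    # Export with all-default parameters; per the assertions in the test, the resulting
    # ONNX model expects a 3-channel, 1024x1024, uint8 input image
    export_result = model.export("yolo_nas_r_s.onnx")
    print(export_result.input_image_shape, export_result.input_image_dtype, export_result.input_image_channels)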