Skip to content

Commit

Permalink
Make YoloNASRDFLHead inherit a base YoloNASDFLHead
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed May 27, 2024
1 parent 5165c98 commit 133e507
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(
self.cls_dropout_rate = nn.Dropout2d(cls_dropout_rate) if cls_dropout_rate > 0 else nn.Identity()
self.reg_dropout_rate = nn.Dropout2d(reg_dropout_rate) if reg_dropout_rate > 0 else nn.Identity()

self.grid = torch.zeros(1)
self.stride = stride

self.prior_prob = 1e-2
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import math
from typing import Tuple, Callable
from typing import Tuple

import torch
from torch import nn, Tensor

from super_gradients.common.registry import register_detection_module
from super_gradients.module_interfaces import SupportsReplaceNumClasses
from super_gradients.modules import ConvBNReLU
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.modules.utils import width_multiplier
from super_gradients.training.models.detection_models.yolo_nas import YoloNASDFLHead
from torch import nn, Tensor


@register_detection_module()
class YoloNASRDFLHead(BaseDetectionModule, SupportsReplaceNumClasses):
class YoloNASRDFLHead(YoloNASDFLHead):
"""
YoloNASRDFLHead is a YoloNASDFLHead with additional outputs for rotated bounding boxes.
"""

def __init__(
self,
in_channels: int,
Expand All @@ -37,41 +37,24 @@ def __init__(
:param cls_dropout_rate: Dropout rate for the classification head
:param reg_dropout_rate: Dropout rate for the regression head
"""
super().__init__(in_channels)

super().__init__(
in_channels=in_channels,
inter_channels=inter_channels,
width_mult=width_mult,
first_conv_group_size=first_conv_group_size,
num_classes=num_classes,
stride=stride,
reg_max=reg_max,
cls_dropout_rate=cls_dropout_rate,
reg_dropout_rate=reg_dropout_rate,
)
inter_channels = width_multiplier(inter_channels, width_mult, 8)
if first_conv_group_size == 0:
groups = 0
elif first_conv_group_size == -1:
groups = 1
else:
groups = inter_channels // first_conv_group_size

self.num_classes = num_classes
self.stem = ConvBNReLU(in_channels, inter_channels, kernel_size=1, stride=1, padding=0, bias=False)

first_cls_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
self.cls_convs = nn.Sequential(*first_cls_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

first_reg_conv = [ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)] if groups else []
self.reg_convs = nn.Sequential(*first_reg_conv, ConvBNReLU(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False))

self.cls_pred = nn.Conv2d(inter_channels, self.num_classes, 1, 1, 0)
self.reg_pred = nn.Conv2d(inter_channels, 2 * (reg_max + 1), 1, 1, 0)
self.rot_pred = nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0)
self.offset_pred = nn.Conv2d(inter_channels, 2, kernel_size=1, stride=1, padding=0)

self.cls_dropout_rate = nn.Dropout2d(cls_dropout_rate) if cls_dropout_rate > 0 else nn.Identity()
self.reg_dropout_rate = nn.Dropout2d(reg_dropout_rate) if reg_dropout_rate > 0 else nn.Identity()

self.stride = stride

self.prior_prob = 1e-2
self._initialize_biases()

def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module]):
self.cls_pred = compute_new_weights_fn(self.cls_pred, num_classes)
self.num_classes = num_classes
torch.nn.init.zeros_(self.offset_pred.weight)
torch.nn.init.zeros_(self.offset_pred.bias)

@property
def out_channels(self):
Expand Down Expand Up @@ -100,11 +83,3 @@ def forward(self, x) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
offset_output = self.offset_pred(reg_feat)

return reg_output, cls_output, offset_output, rot_output

def _initialize_biases(self):
prior_bias = -math.log((1 - self.prior_prob) / self.prior_prob)
torch.nn.init.zeros_(self.cls_pred.weight)
torch.nn.init.constant_(self.cls_pred.bias, prior_bias)

torch.nn.init.zeros_(self.offset_pred.weight)
torch.nn.init.zeros_(self.offset_pred.bias)
13 changes: 13 additions & 0 deletions tests/unit_tests/export_detection_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,19 @@ def test_onnx_nms_batch_result(self):
np.testing.assert_allclose(torch_result[2].numpy(), onnx_result[2], rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(torch_result[3].numpy(), onnx_result[3], rtol=1e-3, atol=1e-3)

def test_yolo_nas_r_export(self):
"""
Test the most common export use case - export to ONNX with all default parameters
"""
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "yolo_nas_r_s.onnx")

model = models.get(Models.YOLO_NAS_R_S, pretrained_weights="dota2")
result = model.export(out_path)
assert result.input_image_dtype == torch.uint8
assert result.input_image_shape == (1024, 1024)
assert result.input_image_channels == 3

def _get_image_as_bchw(self, image_shape=(640, 640)):
"""
Expand Down

0 comments on commit 133e507

Please sign in to comment.