From edf22a07dfefc44ece2b8f088c9a155723e21ba3 Mon Sep 17 00:00:00 2001
From: Eugene Liu
Date: Thu, 19 Dec 2024 16:45:55 +0000
Subject: [PATCH] Refactor DFINE HybridEncoderModule to improve code clarity
 and remove redundant parameters

---
 .../detectors/dfine/hybrid_encoder.py | 372 ++++++++----------
 1 file changed, 165 insertions(+), 207 deletions(-)

diff --git a/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py b/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py
index 5c8c6a5e82..b796281113 100644
--- a/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py
+++ b/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py
@@ -1,246 +1,178 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-"""D-FINE Hybrid Encoder. Modified from D-FINE (https://github.com/Peterande/D-FINE)"""
+"""D-FINE Hybrid Encoder. Modified from D-FINE (https://github.com/Peterande/D-FINE)."""
 
 from __future__ import annotations
 
 import copy
 from collections import OrderedDict
+from functools import partial
 from typing import Any, Callable, ClassVar
 
 import torch
-import torch.nn.functional as F
+import torch.nn.functional as f
 from torch import Tensor, nn
 
 from otx.algo.common.layers.transformer_layers import TransformerEncoder, TransformerEncoderLayer
+from otx.algo.detection.layers.csp_layer import CSPRepLayer
+from otx.algo.detection.utils.utils import auto_pad
+from otx.algo.modules.activation import build_activation_layer
+from otx.algo.modules.conv_module import Conv2dModule
+from otx.algo.modules.norm import build_norm_layer
 
-from .utils import get_activation
 
+class SCDown(nn.Module):
+    """SCDown downsampling module.
 
-class FusedConvNormLayer(nn.Module):
-    def __init__(
-        self,
-        ch_in,
-        ch_out,
-        kernel_size,
-        stride,
-        g=1,
-        padding=None,
-        bias=False,
-        act=None,
-    ):
-        super().__init__()
-        padding = (kernel_size - 1) // 2 if padding is None else padding
-        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias)
-        self.norm = nn.BatchNorm2d(ch_out)
-        self.act = nn.Identity() if act is None else get_activation(act)
-        self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = (
-            ch_in,
-            ch_out,
-            kernel_size,
-            stride,
-            g,
-            padding,
-            bias,
-        )
-
-    def forward(self, x):
-        if hasattr(self, "conv_bn_fused"):
-            y = self.conv_bn_fused(x)
-        else:
-            y = self.norm(self.conv(x))
-        return self.act(y)
-
-    def convert_to_deploy(self):
-        if not hasattr(self, "conv_bn_fused"):
-            self.conv_bn_fused = nn.Conv2d(
-                self.ch_in,
-                self.ch_out,
-                self.kernel_size,
-                self.stride,
-                groups=self.g,
-                padding=self.padding,
-                bias=True,
-            )
-
-        kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv_bn_fused.weight.data = kernel
-        self.conv_bn_fused.bias.data = bias
-        self.__delattr__("conv")
-        self.__delattr__("norm")
-
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor()
-
-        return kernel3x3, bias3x3
-
-    def _fuse_bn_tensor(self):
-        kernel = self.conv.weight
-        running_mean = self.norm.running_mean
-        running_var = self.norm.running_var
-        gamma = self.norm.weight
-        beta = self.norm.bias
-        eps = self.norm.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
+    Args:
+        c1 (int): Number of channels in the input feature map.
+        c2 (int): Number of channels produced by the convolution.
+        k (int): Kernel size of the convolving kernel.
+        s (int): Stride of the convolution.
+        normalization (Callable[..., nn.Module] | None): Normalization layer module.
+    """
 
-class ConvNormLayer(nn.Module):
     def __init__(
         self,
-        ch_in,
-        ch_out,
-        kernel_size,
-        stride,
-        g=1,
-        padding=None,
-        bias=False,
-        act=None,
-    ):
+        c1: int,
+        c2: int,
+        k: int,
+        s: int,
+        normalization: Callable[..., nn.Module] | None = None,
+    ) -> None:
         super().__init__()
-        padding = (kernel_size - 1) // 2 if padding is None else padding
-        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias)
-        self.norm = nn.BatchNorm2d(ch_out)
-        self.act = nn.Identity() if act is None else get_activation(act)
+        self.cv1 = Conv2dModule(
+            c1,
+            c2,
+            1,
+            1,
+            normalization=build_norm_layer(normalization, num_features=c2),
+            activation=None,
+        )
+        self.cv2 = Conv2dModule(
+            c2,
+            c2,
+            k,
+            s,
+            padding=auto_pad(kernel_size=k),
+            groups=c2,
+            normalization=build_norm_layer(normalization, num_features=c2),
+            activation=None,
+        )
 
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        return self.cv2(self.cv1(x))
 
 
-class SCDown(nn.Module):
-    def __init__(self, c1, c2, k, s):
-        super().__init__()
-        self.cv1 = FusedConvNormLayer(c1, c2, 1, 1)
-        self.cv2 = FusedConvNormLayer(c2, c2, k, s, c2)
+class RepNCSPELAN4(nn.Module):
+    """GELANModule from YOLOv9.
 
-    def forward(self, x):
-        return self.cv2(self.cv1(x))
+    Note:
+        This block might not be interchangeable with the GELANModule in YOLOv9, as the layer implementation is very different.
 
+    Args:
+        c1 (int): c1 channel size. Refer to the GELAN paper.
+        c2 (int): c2 channel size. Refer to the GELAN paper.
+        c3 (int): c3 channel size. Refer to the GELAN paper.
+        c4 (int): c4 channel size. Refer to the GELAN paper.
+        num_blocks (int, optional): Number of blocks. Defaults to 3.
+        bias (bool, optional): Whether to use bias in the convolution layers. Defaults to False.
+        activation (Callable[..., nn.Module] | None, optional): Activation layer module. Defaults to None.
+        normalization (Callable[..., nn.Module] | None, optional): Normalization layer module. Defaults to None.
+ """ -class VGGBlock(nn.Module): def __init__( self, - ch_in, - ch_out, - act="relu", - ): + c1: int, + c2: int, + c3: int, + c4: int, + num_blocks: int = 3, + bias: bool = False, + activation: Callable[..., nn.Module] | None = None, + normalization: Callable[..., nn.Module] | None = None, + ) -> None: super().__init__() - self.ch_in = ch_in - self.ch_out = ch_out - self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) - self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) - self.act = nn.Identity() if act is None else act - - def forward(self, x): - if hasattr(self, "conv"): - y = self.conv(x) - else: - y = self.conv1(x) + self.conv2(x) - - return self.act(y) - - def convert_to_deploy(self): - if not hasattr(self, "conv"): - self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) - - kernel, bias = self.get_equivalent_kernel_bias() - self.conv.weight.data = kernel - self.conv.bias.data = bias - self.__delattr__("conv1") - self.__delattr__("conv2") - - def get_equivalent_kernel_bias(self): - kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) - kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) - - return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 - - def _pad_1x1_to_3x3_tensor(self, kernel1x1): - if kernel1x1 is None: - return 0 - else: - return F.pad(kernel1x1, [1, 1, 1, 1]) - - def _fuse_bn_tensor(self, branch: ConvNormLayer): - if branch is None: - return 0, 0 - kernel = branch.conv.weight - running_mean = branch.norm.running_mean - running_var = branch.norm.running_var - gamma = branch.norm.weight - beta = branch.norm.bias - eps = branch.norm.eps - std = (running_var + eps).sqrt() - t = (gamma / std).reshape(-1, 1, 1, 1) - return kernel * t, beta - running_mean * gamma / std + self.c = c3 // 2 + self.cv1 = Conv2dModule( + c1, + c3, + 1, + 1, + bias=bias, + activation=build_activation_layer(activation), + normalization=build_norm_layer(normalization, num_features=c3), + ) -class RepNCSPELAN4(nn.Module): - # csp-elan - def __init__( - self, - c1, - c2, - c3, - c4, - n=3, - bias=False, - act="silu", - ): - super().__init__() - self.c = c3 // 2 - self.cv1 = FusedConvNormLayer(c1, c3, 1, 1, bias=bias, act=act) self.cv2 = nn.Sequential( - CSPLayer(c3 // 2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), - FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act), + CSPRepLayer( + c3 // 2, + c4, + num_blocks, + 1, + bias=bias, + activation=activation, + normalization=normalization, + ), + Conv2dModule( + c4, + c4, + 3, + 1, + padding=auto_pad(kernel_size=3), + bias=bias, + activation=build_activation_layer(activation), + normalization=build_norm_layer(normalization, num_features=c4), + ), ) + self.cv3 = nn.Sequential( - CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), - FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act), + CSPRepLayer( + c4, + c4, + num_blocks, + 1, + bias=bias, + activation=activation, + normalization=normalization, + ), + Conv2dModule( + c4, + c4, + 3, + 1, + padding=auto_pad(kernel_size=3), + bias=bias, + activation=build_activation_layer(activation), + normalization=build_norm_layer(normalization, num_features=c4), + ), + ) + + self.cv4 = Conv2dModule( + c3 + (2 * c4), + c2, + 1, + 1, + bias=bias, + activation=build_activation_layer(activation), + normalization=build_norm_layer(normalization, num_features=c2), ) - self.cv4 = FusedConvNormLayer(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" y = 
         y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
         return self.cv4(torch.cat(y, 1))
 
 
-class CSPLayer(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        num_blocks=3,
-        expansion=1.0,
-        bias=False,
-        act="silu",
-        bottletype=VGGBlock,
-    ):
-        super().__init__()
-        hidden_channels = int(out_channels * expansion)
-        self.conv1 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
-        self.conv2 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
-        self.bottlenecks = nn.Sequential(
-            *[bottletype(hidden_channels, hidden_channels, act=get_activation(act)) for _ in range(num_blocks)],
-        )
-        if hidden_channels != out_channels:
-            self.conv3 = FusedConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
-        else:
-            self.conv3 = nn.Identity()
-
-    def forward(self, x):
-        x_1 = self.conv1(x)
-        x_1 = self.bottlenecks(x_1)
-        x_2 = self.conv2(x)
-        return self.conv3(x_1 + x_2)
-
-
 class HybridEncoderModule(nn.Module):
     """HybridEncoder for DFine.
 
+    TODO(Eugene): Merge with current rtdetr.HybridEncoderModule in next PR.
+
     Args:
         in_channels (list[int], optional): List of input channels for each feature map.
             Defaults to [512, 1024, 2048].
@@ -274,18 +206,20 @@ class HybridEncoderModule(nn.Module):
     def __init__(
         self,
-        in_channels: list[int] = [512, 1024, 2048],
-        feat_strides: list[int] = [8, 16, 32],
+        in_channels: list[int] = [512, 1024, 2048],  # noqa: B006
+        feat_strides: list[int] = [8, 16, 32],  # noqa: B006
         hidden_dim: int = 256,
         nhead: int = 8,
         dim_feedforward: int = 1024,
         dropout: float = 0.0,
         enc_activation: Callable[..., nn.Module] = nn.GELU,
-        use_encoder_idx: list[int] = [2],
+        normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm"),
+        use_encoder_idx: list[int] = [2],  # noqa: B006
         num_encoder_layers: int = 1,
         pe_temperature: int = 10000,
         expansion: float = 1.0,
         depth_mult: float = 1.0,
+        activation: Callable[..., nn.Module] = nn.SiLU,
         eval_spatial_size: tuple[int, int] | None = None,
     ):
         super().__init__()
@@ -326,12 +260,20 @@ def __init__(
             [TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))],
         )
 
-        # NOTE: Below code start to divert from rtdetr.HybridEncoder
         # top-down fpn
         self.lateral_convs = nn.ModuleList()
         self.fpn_blocks = nn.ModuleList()
         for _ in range(len(in_channels) - 1, 0, -1):
-            self.lateral_convs.append(FusedConvNormLayer(hidden_dim, hidden_dim, 1, 1))
+            self.lateral_convs.append(
+                Conv2dModule(
+                    hidden_dim,
+                    hidden_dim,
+                    1,
+                    1,
+                    normalization=build_norm_layer(normalization, num_features=hidden_dim),
+                    activation=None,
+                ),
+            )
             self.fpn_blocks.append(
                 RepNCSPELAN4(
                     hidden_dim * 2,
@@ -339,6 +281,8 @@ def __init__(
                     hidden_dim * 2,
                     round(expansion * hidden_dim // 2),
                     round(3 * depth_mult),
+                    activation=activation,
+                    normalization=normalization,
                 ),
             )
 
@@ -348,7 +292,13 @@ def __init__(
         for _ in range(len(in_channels) - 1):
             self.downsample_convs.append(
                 nn.Sequential(
-                    SCDown(hidden_dim, hidden_dim, 3, 2),
+                    SCDown(
+                        hidden_dim,
+                        hidden_dim,
+                        3,
+                        2,
+                        normalization=normalization,
+                    ),
                 ),
             )
             self.pan_blocks.append(
@@ -358,6 +308,8 @@ def __init__(
                     hidden_dim * 2,
                     round(expansion * hidden_dim // 2),
                     round(3 * depth_mult),
+                    activation=activation,
+                    normalization=normalization,
                 ),
             )
 
@@ -417,7 +369,7 @@ def forward(self, feats: torch.Tensor) -> list[torch.Tensor]:
                         src_flatten.device,
                     )
                 else:
-                    pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device)
+                    pos_embed = getattr(self, f"pos_embed{enc_ind}", None)
 
                 memory = self.encoder[i](src_flatten, pos_embed=pos_embed)
                 proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
@@ -429,7 +381,7 @@ def forward(self, feats: torch.Tensor) -> list[torch.Tensor]:
             feat_low = proj_feats[idx - 1]
             feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
             inner_outs[0] = feat_heigh
-            upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
+            upsample_feat = f.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
             inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
             inner_outs.insert(0, inner_out)
 
@@ -445,6 +397,8 @@ def forward(self, feats: torch.Tensor) -> list[torch.Tensor]:
 
 
 class HybridEncoder:
+    """HybridEncoder factory for D-Fine detection."""
+
     encoder_cfg: ClassVar[dict[str, Any]] = {
         "dfine_hgnetv2_n": {
             "in_channels": [512, 1024],
@@ -476,5 +430,9 @@ class HybridEncoder:
         },
     }
 
-    def __new__(cls, model_name) -> HybridEncoderModule:
+    def __new__(cls, model_name: str) -> HybridEncoderModule:
+        """Constructor for HybridEncoder."""
+        if model_name not in cls.encoder_cfg:
+            msg = f"model type '{model_name}' is not supported"
+            raise KeyError(msg)
         return HybridEncoderModule(**cls.encoder_cfg[model_name])
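
Reviewer note (not part of the patch): a minimal usage sketch of the refactored
encoder, based on the constructor defaults visible above (in_channels
[512, 1024, 2048], feat_strides [8, 16, 32], hidden_dim 256). The batch size and
the 640x640 input resolution implied by the feature-map shapes below are
illustrative assumptions, not values taken from this patch.

    import torch

    from otx.algo.detection.detectors.dfine.hybrid_encoder import HybridEncoder, HybridEncoderModule

    # Default configuration: three feature levels at strides 8 / 16 / 32.
    encoder = HybridEncoderModule()
    feats = [
        torch.randn(2, 512, 80, 80),   # stride 8 feature map for an assumed 640x640 input
        torch.randn(2, 1024, 40, 40),  # stride 16
        torch.randn(2, 2048, 20, 20),  # stride 32
    ]
    outs = encoder(feats)
    print([o.shape for o in outs])  # each output level carries hidden_dim (256) channels

    # The factory now validates the model name before instantiating the module:
    encoder_n = HybridEncoder("dfine_hgnetv2_n")
    # HybridEncoder("unknown") raises KeyError("model type 'unknown' is not supported").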