From a88a9d19f7fbcf79f6cc1b3a1dee3f9f12d7b93d Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Thu, 19 Dec 2024 14:40:54 +0000 Subject: [PATCH] update hybrid encoder --- .../detectors/dfine/hybrid_encoder.py | 241 +++++++++--------- src/otx/algo/detection/layers/csp_layer.py | 8 +- 2 files changed, 133 insertions(+), 116 deletions(-) diff --git a/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py b/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py index 98f44d33cf9..5c8c6a5e820 100644 --- a/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py +++ b/src/otx/algo/detection/detectors/dfine/hybrid_encoder.py @@ -7,17 +7,29 @@ import copy from collections import OrderedDict -from typing import Any, ClassVar +from typing import Any, Callable, ClassVar import torch import torch.nn.functional as F -from torch import nn +from torch import Tensor, nn + +from otx.algo.common.layers.transformer_layers import TransformerEncoder, TransformerEncoderLayer from .utils import get_activation -class ConvNormLayer_fuse(nn.Module): - def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None): +class FusedConvNormLayer(nn.Module): + def __init__( + self, + ch_in, + ch_out, + kernel_size, + stride, + g=1, + padding=None, + bias=False, + act=None, + ): super().__init__() padding = (kernel_size - 1) // 2 if padding is None else padding self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias) @@ -76,7 +88,17 @@ def _fuse_bn_tensor(self): class ConvNormLayer(nn.Module): - def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None): + def __init__( + self, + ch_in, + ch_out, + kernel_size, + stride, + g=1, + padding=None, + bias=False, + act=None, + ): super().__init__() padding = (kernel_size - 1) // 2 if padding is None else padding self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias) @@ -90,15 +112,20 @@ def forward(self, x): class SCDown(nn.Module): def __init__(self, c1, c2, k, s): super().__init__() - self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1) - self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2) + self.cv1 = FusedConvNormLayer(c1, c2, 1, 1) + self.cv2 = FusedConvNormLayer(c2, c2, k, s, c2) def forward(self, x): return self.cv2(self.cv1(x)) class VGGBlock(nn.Module): - def __init__(self, ch_in, ch_out, act="relu"): + def __init__( + self, + ch_in, + ch_out, + act="relu", + ): super().__init__() self.ch_in = ch_in self.ch_out = ch_out @@ -152,19 +179,28 @@ def _fuse_bn_tensor(self, branch: ConvNormLayer): class RepNCSPELAN4(nn.Module): # csp-elan - def __init__(self, c1, c2, c3, c4, n=3, bias=False, act="silu"): + def __init__( + self, + c1, + c2, + c3, + c4, + n=3, + bias=False, + act="silu", + ): super().__init__() self.c = c3 // 2 - self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act) + self.cv1 = FusedConvNormLayer(c1, c3, 1, 1, bias=bias, act=act) self.cv2 = nn.Sequential( CSPLayer(c3 // 2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), - ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act), ) self.cv3 = nn.Sequential( CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), - ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act), ) - self.cv4 = ConvNormLayer_fuse(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) + self.cv4 = FusedConvNormLayer(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) def forward(self, x): y = list(self.cv1(x).split((self.c, self.c), 1)) @@ -185,13 +221,13 @@ def __init__( ): super().__init__() hidden_channels = int(out_channels * expansion) - self.conv1 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act) - self.conv2 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.conv1 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.conv2 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) self.bottlenecks = nn.Sequential( *[bottletype(hidden_channels, hidden_channels, act=get_activation(act)) for _ in range(num_blocks)], ) if hidden_channels != out_channels: - self.conv3 = ConvNormLayer_fuse(hidden_channels, out_channels, 1, 1, bias=bias, act=act) + self.conv3 = FusedConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act) else: self.conv3 = nn.Identity() @@ -202,90 +238,55 @@ def forward(self, x): return self.conv3(x_1 + x_2) -class TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False, - ): - super().__init__() - self.normalize_before = normalize_before - - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True) - - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.activation = get_activation(activation) - - @staticmethod - def with_pos_embed(tensor, pos_embed): - return tensor if pos_embed is None else tensor + pos_embed - - def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: - residual = src - if self.normalize_before: - src = self.norm1(src) - q = k = self.with_pos_embed(src, pos_embed) - src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) - - src = residual + self.dropout1(src) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerEncoder(nn.Module): - def __init__(self, encoder_layer, num_layers, norm=None): - super().__init__() - self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)]) - self.num_layers = num_layers - self.norm = norm - - def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: - output = src - for layer in self.layers: - output = layer(output, src_mask=src_mask, pos_embed=pos_embed) - - if self.norm is not None: - output = self.norm(output) - - return output - - class HybridEncoderModule(nn.Module): + """HybridEncoder for DFine. + + Args: + in_channels (list[int], optional): List of input channels for each feature map. + Defaults to [512, 1024, 2048]. + feat_strides (list[int], optional): List of stride values for + each feature map. Defaults to [8, 16, 32]. + hidden_dim (int, optional): Hidden dimension size. Defaults to 256. + nhead (int, optional): Number of attention heads in the transformer encoder. + Defaults to 8. + dim_feedforward (int, optional): Dimension of the feedforward network + in the transformer encoder. Defaults to 1024. + dropout (float, optional): Dropout rate. Defaults to 0.0. + enc_activation (Callable[..., nn.Module]): Activation layer module. + Defaults to ``nn.GELU``. + normalization (Callable[..., nn.Module]): Normalization layer module. + Defaults to ``partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm")``. + use_encoder_idx (list[int], optional): List of indices of the encoder to use. + Defaults to [2]. + num_encoder_layers (int, optional): Number of layers in the transformer encoder. + Defaults to 1. + pe_temperature (float, optional): Temperature parameter for positional encoding. + Defaults to 10000. + expansion (float, optional): Expansion factor for the CSPRepLayer. + Defaults to 1.0. + depth_mult (float, optional): Depth multiplier for the CSPRepLayer. + Defaults to 1.0. + activation (Callable[..., nn.Module]): Activation layer module. + Defaults to ``nn.SiLU``. + eval_spatial_size (tuple[int, int] | None, optional): Spatial size for + evaluation. Defaults to None. + """ + def __init__( self, - in_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - hidden_dim=256, - nhead=8, - dim_feedforward=1024, - dropout=0.0, - enc_act="gelu", - use_encoder_idx=[2], - num_encoder_layers=1, - pe_temperature=10000, - expansion=1.0, - depth_mult=1.0, - eval_spatial_size=None, + in_channels: list[int] = [512, 1024, 2048], + feat_strides: list[int] = [8, 16, 32], + hidden_dim: int = 256, + nhead: int = 8, + dim_feedforward: int = 1024, + dropout: float = 0.0, + enc_activation: Callable[..., nn.Module] = nn.GELU, + use_encoder_idx: list[int] = [2], + num_encoder_layers: int = 1, + pe_temperature: int = 10000, + expansion: float = 1.0, + depth_mult: float = 1.0, + eval_spatial_size: tuple[int, int] | None = None, ): super().__init__() self.in_channels = in_channels @@ -301,35 +302,36 @@ def __init__( # channel projection self.input_proj = nn.ModuleList() for in_channel in in_channels: - proj = nn.Sequential( - OrderedDict( - [ - ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), - ("norm", nn.BatchNorm2d(hidden_dim)), - ], + self.input_proj.append( + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), + ("norm", nn.BatchNorm2d(hidden_dim)), + ], + ), ), ) - self.input_proj.append(proj) - # encoder transformer encoder_layer = TransformerEncoderLayer( hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, - activation=enc_act, + activation=enc_activation, ) self.encoder = nn.ModuleList( [TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))], ) + # NOTE: Below code start to divert from rtdetr.HybridEncoder # top-down fpn self.lateral_convs = nn.ModuleList() self.fpn_blocks = nn.ModuleList() for _ in range(len(in_channels) - 1, 0, -1): - self.lateral_convs.append(ConvNormLayer_fuse(hidden_dim, hidden_dim, 1, 1)) + self.lateral_convs.append(FusedConvNormLayer(hidden_dim, hidden_dim, 1, 1)) self.fpn_blocks.append( RepNCSPELAN4( hidden_dim * 2, @@ -361,7 +363,8 @@ def __init__( self._reset_parameters() - def _reset_parameters(self): + def _reset_parameters(self) -> None: + """Reset parameters.""" if self.eval_spatial_size: for idx in self.use_encoder_idx: stride = self.feat_strides[idx] @@ -374,12 +377,19 @@ def _reset_parameters(self): setattr(self, f"pos_embed{idx}", pos_embed) @staticmethod - def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): - """ """ + def build_2d_sincos_position_embedding( + w: int, + h: int, + embed_dim: int = 256, + temperature: float = 10000.0, + ) -> Tensor: + """Build 2D sin-cos position embedding.""" grid_w = torch.arange(int(w), dtype=torch.float32) grid_h = torch.arange(int(h), dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") - assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + if embed_dim % 4 != 0: + msg = "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + raise ValueError(msg) pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1.0 / (temperature**omega) @@ -389,8 +399,11 @@ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0) return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] - def forward(self, feats): - assert len(feats) == len(self.in_channels) + def forward(self, feats: torch.Tensor) -> list[torch.Tensor]: + """Forward pass.""" + if len(feats) != len(self.in_channels): + msg = f"Input feature size {len(feats)} does not match the number of input channels {len(self.in_channels)}" + raise ValueError(msg) proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] # encoder @@ -406,7 +419,7 @@ def forward(self, feats): else: pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device) - memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed) + memory = self.encoder[i](src_flatten, pos_embed=pos_embed) proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() # broadcasting and fusion diff --git a/src/otx/algo/detection/layers/csp_layer.py b/src/otx/algo/detection/layers/csp_layer.py index 6a1c32f5693..35e2f9ddc60 100644 --- a/src/otx/algo/detection/layers/csp_layer.py +++ b/src/otx/algo/detection/layers/csp_layer.py @@ -193,7 +193,11 @@ def __init__( normalization=build_norm_layer(normalization, num_features=ch_out), activation=None, ) - self.act = activation() if activation else nn.Identity() + if isinstance(activation, type): + activation = activation() + if activation is None: + activation = nn.Identity() + self.act = activation def forward(self, x: Tensor) -> Tensor: """Forward function.""" @@ -378,7 +382,7 @@ def __init__( RepVggBlock( hidden_channels, hidden_channels, - activation=activation, + activation=build_activation_layer(activation), normalization=normalization, ) for _ in range(num_blocks)