
Commit

update hybrid encoder
eugene123tw committed Dec 19, 2024
1 parent 7ee9fa2 commit a88a9d1
Showing 2 changed files with 133 additions and 116 deletions.
241 changes: 127 additions & 114 deletions src/otx/algo/detection/detectors/dfine/hybrid_encoder.py
@@ -7,17 +7,29 @@

import copy
from collections import OrderedDict
from typing import Any, ClassVar
from typing import Any, Callable, ClassVar

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor, nn

from otx.algo.common.layers.transformer_layers import TransformerEncoder, TransformerEncoderLayer

from .utils import get_activation


class ConvNormLayer_fuse(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
class FusedConvNormLayer(nn.Module):
def __init__(
self,
ch_in,
ch_out,
kernel_size,
stride,
g=1,
padding=None,
bias=False,
act=None,
):
super().__init__()
padding = (kernel_size - 1) // 2 if padding is None else padding
self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias)
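
Note on the rename: "fused" here refers to folding the BatchNorm that follows the convolution into the conv weights, which is what the _fuse_bn_tensor helper referenced in the next hunk does. A minimal sketch of that standard fusion, as an illustration rather than a copy of this file's implementation:

import torch
from torch import nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold BN statistics into the conv so inference runs a single layer (illustrative).
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        groups=conv.groups,
        bias=True,
    )
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std  # per-channel gamma / sqrt(var + eps)
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = bn.bias + (conv_bias - bn.running_mean) * scale
    return fused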
@@ -76,7 +88,17 @@ def _fuse_bn_tensor(self):


class ConvNormLayer(nn.Module):
def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
def __init__(
self,
ch_in,
ch_out,
kernel_size,
stride,
g=1,
padding=None,
bias=False,
act=None,
):
super().__init__()
padding = (kernel_size - 1) // 2 if padding is None else padding
self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias)
@@ -90,15 +112,20 @@ def forward(self, x):
class SCDown(nn.Module):
def __init__(self, c1, c2, k, s):
super().__init__()
self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1)
self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)
self.cv1 = FusedConvNormLayer(c1, c2, 1, 1)
self.cv2 = FusedConvNormLayer(c2, c2, k, s, c2)

def forward(self, x):
return self.cv2(self.cv1(x))


class VGGBlock(nn.Module):
def __init__(self, ch_in, ch_out, act="relu"):
def __init__(
self,
ch_in,
ch_out,
act="relu",
):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
@@ -152,19 +179,28 @@ def _fuse_bn_tensor(self, branch: ConvNormLayer):

class RepNCSPELAN4(nn.Module):
# csp-elan
def __init__(self, c1, c2, c3, c4, n=3, bias=False, act="silu"):
def __init__(
self,
c1,
c2,
c3,
c4,
n=3,
bias=False,
act="silu",
):
super().__init__()
self.c = c3 // 2
self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
self.cv1 = FusedConvNormLayer(c1, c3, 1, 1, bias=bias, act=act)
self.cv2 = nn.Sequential(
CSPLayer(c3 // 2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock),
ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act),
FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act),
)
self.cv3 = nn.Sequential(
CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock),
ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act),
FusedConvNormLayer(c4, c4, 3, 1, bias=bias, act=act),
)
self.cv4 = ConvNormLayer_fuse(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act)
self.cv4 = FusedConvNormLayer(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act)

def forward(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
@@ -185,13 +221,13 @@ def __init__(
):
super().__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv2 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv1 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.conv2 = FusedConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
self.bottlenecks = nn.Sequential(
*[bottletype(hidden_channels, hidden_channels, act=get_activation(act)) for _ in range(num_blocks)],
)
if hidden_channels != out_channels:
self.conv3 = ConvNormLayer_fuse(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
self.conv3 = FusedConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
else:
self.conv3 = nn.Identity()

@@ -202,90 +238,55 @@ def forward(self, x):
return self.conv3(x_1 + x_2)


class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.normalize_before = normalize_before

self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = get_activation(activation)

@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed

def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
residual = src
if self.normalize_before:
src = self.norm1(src)
q = k = self.with_pos_embed(src, pos_embed)
src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask)

src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)

residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src


class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm

def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
output = src
for layer in self.layers:
output = layer(output, src_mask=src_mask, pos_embed=pos_embed)

if self.norm is not None:
output = self.norm(output)

return output


class HybridEncoderModule(nn.Module):
"""HybridEncoder for DFine.
Args:
in_channels (list[int], optional): List of input channels for each feature map.
Defaults to [512, 1024, 2048].
feat_strides (list[int], optional): List of stride values for
each feature map. Defaults to [8, 16, 32].
hidden_dim (int, optional): Hidden dimension size. Defaults to 256.
nhead (int, optional): Number of attention heads in the transformer encoder.
Defaults to 8.
dim_feedforward (int, optional): Dimension of the feedforward network
in the transformer encoder. Defaults to 1024.
dropout (float, optional): Dropout rate. Defaults to 0.0.
enc_activation (Callable[..., nn.Module]): Activation layer module.
Defaults to ``nn.GELU``.
normalization (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``partial(build_norm_layer, nn.BatchNorm2d, layer_name="norm")``.
use_encoder_idx (list[int], optional): List of indices of the encoder to use.
Defaults to [2].
num_encoder_layers (int, optional): Number of layers in the transformer encoder.
Defaults to 1.
pe_temperature (float, optional): Temperature parameter for positional encoding.
Defaults to 10000.
expansion (float, optional): Expansion factor for the CSPRepLayer.
Defaults to 1.0.
depth_mult (float, optional): Depth multiplier for the CSPRepLayer.
Defaults to 1.0.
activation (Callable[..., nn.Module]): Activation layer module.
Defaults to ``nn.SiLU``.
eval_spatial_size (tuple[int, int] | None, optional): Spatial size for
evaluation. Defaults to None.
"""

def __init__(
self,
in_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
hidden_dim=256,
nhead=8,
dim_feedforward=1024,
dropout=0.0,
enc_act="gelu",
use_encoder_idx=[2],
num_encoder_layers=1,
pe_temperature=10000,
expansion=1.0,
depth_mult=1.0,
eval_spatial_size=None,
in_channels: list[int] = [512, 1024, 2048],
feat_strides: list[int] = [8, 16, 32],
hidden_dim: int = 256,
nhead: int = 8,
dim_feedforward: int = 1024,
dropout: float = 0.0,
enc_activation: Callable[..., nn.Module] = nn.GELU,
use_encoder_idx: list[int] = [2],
num_encoder_layers: int = 1,
pe_temperature: int = 10000,
expansion: float = 1.0,
depth_mult: float = 1.0,
eval_spatial_size: tuple[int, int] | None = None,
):
super().__init__()
self.in_channels = in_channels
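
A minimal usage sketch based on the defaults documented in the docstring above; the batch size and spatial sizes are illustrative assumptions, not values taken from this commit:

import torch

# Three backbone feature maps for a 640x640 input at strides 8/16/32 (assumed shapes).
encoder = HybridEncoderModule(
    in_channels=[512, 1024, 2048],
    feat_strides=[8, 16, 32],
    hidden_dim=256,
    eval_spatial_size=(640, 640),
)
feats = [
    torch.randn(1, 512, 80, 80),
    torch.randn(1, 1024, 40, 40),
    torch.randn(1, 2048, 20, 20),
]
outs = encoder(feats)  # list of fused feature maps, each with hidden_dim channels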
@@ -301,35 +302,36 @@ def __init__(
# channel projection
self.input_proj = nn.ModuleList()
for in_channel in in_channels:
proj = nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)),
("norm", nn.BatchNorm2d(hidden_dim)),
],
self.input_proj.append(
nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)),
("norm", nn.BatchNorm2d(hidden_dim)),
],
),
),
)

self.input_proj.append(proj)

# encoder transformer
encoder_layer = TransformerEncoderLayer(
hidden_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=enc_act,
activation=enc_activation,
)

self.encoder = nn.ModuleList(
[TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))],
)

# NOTE: Below code starts to diverge from rtdetr.HybridEncoder
# top-down fpn
self.lateral_convs = nn.ModuleList()
self.fpn_blocks = nn.ModuleList()
for _ in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(ConvNormLayer_fuse(hidden_dim, hidden_dim, 1, 1))
self.lateral_convs.append(FusedConvNormLayer(hidden_dim, hidden_dim, 1, 1))
self.fpn_blocks.append(
RepNCSPELAN4(
hidden_dim * 2,
Expand Down Expand Up @@ -361,7 +363,8 @@ def __init__(

self._reset_parameters()

def _reset_parameters(self):
def _reset_parameters(self) -> None:
"""Reset parameters."""
if self.eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
@@ -374,12 +377,19 @@ def _reset_parameters(self):
setattr(self, f"pos_embed{idx}", pos_embed)

@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
""" """
def build_2d_sincos_position_embedding(
w: int,
h: int,
embed_dim: int = 256,
temperature: float = 10000.0,
) -> Tensor:
"""Build 2D sin-cos position embedding."""
grid_w = torch.arange(int(w), dtype=torch.float32)
grid_h = torch.arange(int(h), dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
if embed_dim % 4 != 0:
msg = "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
raise ValueError(msg)
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
@@ -389,8 +399,11 @@ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0)

return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

def forward(self, feats):
assert len(feats) == len(self.in_channels)
def forward(self, feats: torch.Tensor) -> list[torch.Tensor]:
"""Forward pass."""
if len(feats) != len(self.in_channels):
msg = f"Input feature size {len(feats)} does not match the number of input channels {len(self.in_channels)}"
raise ValueError(msg)
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]

# encoder
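
For reference, the build_2d_sincos_position_embedding staticmethod shown above emits one embed_dim-sized vector per grid cell, using frequencies omega_i = 1 / temperature**(i / pos_dim) and sin/cos pairs for the w and h coordinates. A small shape check with illustrative values:

# Expected: torch.Size([1, 400, 256]) -> a [1, w*h, embed_dim] tensor, one vector per cell.
pe = HybridEncoderModule.build_2d_sincos_position_embedding(w=20, h=20, embed_dim=256)
print(pe.shape)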
@@ -406,7 +419,7 @@ def forward(self, feats):
else:
pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device)

memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
memory = self.encoder[i](src_flatten, pos_embed=pos_embed)
proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()

# broadcasting and fusion
8 changes: 6 additions & 2 deletions src/otx/algo/detection/layers/csp_layer.py
@@ -193,7 +193,11 @@ def __init__(
normalization=build_norm_layer(normalization, num_features=ch_out),
activation=None,
)
self.act = activation() if activation else nn.Identity()
if isinstance(activation, type):
activation = activation()
if activation is None:
activation = nn.Identity()
self.act = activation

def forward(self, x: Tensor) -> Tensor:
"""Forward function."""
@@ -378,7 +382,7 @@ def __init__(
RepVggBlock(
hidden_channels,
hidden_channels,
activation=activation,
activation=build_activation_layer(activation),
normalization=normalization,
)
for _ in range(num_blocks)
