Skip to content

Commit

Permalink
feat: add vit_transformer parameter in transunet
Browse files Browse the repository at this point in the history
  • Loading branch information
black0017 committed Mar 23, 2021
1 parent edab5e1 commit 25622d5
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions self_attention_cv/transunet/trans_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, *, img_dim, in_channels, classes,
vit_blocks=12,
vit_heads=4,
vit_dim_linear_mhsa_block=1024,
vit_transformer=None,
vit_channels = None
):
"""
My reimplementation of TransUnet based on the paper:
Expand All @@ -28,32 +30,33 @@ def __init__(self, *, img_dim, in_channels, classes,
vit_blocks: MHSA blocks of ViT
vit_heads: number of MHSA heads
vit_dim_linear_mhsa_block: MHSA MLP dimension
vit_transformer: optional pre-built ViT module; when given, it is used instead of constructing the default ViT
vit_channels: number of channels of your pretrained ViT. Defaults to 128*8 when None
"""
super().__init__()
self.inplanes = 128
vit_channels = self.inplanes * 8
vit_channels = self.inplanes * 8 if vit_channels is None else vit_channels

# Not clear how they used resnet arch. since the first input after conv
# must be 128 channels and half spat dims.
in_conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
bn1 = nn.BatchNorm2d(self.inplanes)
self.init_conv = nn.Sequential(in_conv1, bn1, nn.ReLU(inplace=True))
# self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

self.conv1 = Bottleneck(self.inplanes, self.inplanes * 2, stride=2)
self.conv2 = Bottleneck(self.inplanes * 2, self.inplanes * 4, stride=2)
self.conv3 = Bottleneck(self.inplanes * 4, vit_channels, stride=2)

self.img_dim_vit = img_dim // 16

self.vit = ViT(img_dim=self.img_dim_vit,
in_channels=vit_channels, # encoder channels
patch_dim=1,
dim=vit_channels, # vit out channels for decoding
blocks=vit_blocks,
heads=vit_heads,
dim_linear_block=vit_dim_linear_mhsa_block,
classification=False)
classification=False) if vit_transformer is None else vit_transformer

self.vit_conv = SignleConv(in_ch=vit_channels, out_ch=512)

Expand Down

0 comments on commit 25622d5

Please sign in to comment.