Skip to content

Commit

Permalink
transunet model first implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
black0017 committed Feb 15, 2021
1 parent 085ecaf commit 31df688
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 38 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4,
y = bottleneck_block(inp)
```


## References

1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
Expand Down
9 changes: 5 additions & 4 deletions examples/test_TransUnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from self_attention_cv.transunet import TransUnet
import torch

a = torch.rand(2,3,128,128)
from self_attention_cv.transunet import TransUnet

a = torch.rand(2, 3, 128, 128)

model = TransUnet(in_channels=3,img_dim=128)
model = TransUnet(in_channels=3, img_dim=128, classes=5)
y = model(a)
print(y.shape)
print('final out shape:', y.shape)
2 changes: 0 additions & 2 deletions self_attention_cv/transunet/bottleneck_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, inplanes, planes, stride=1, groups=1,
else:
self.downsample = nn.Identity()



width = int(planes * (base_width / 64.)) * groups

self.conv1 = conv1x1(inplanes, width)
Expand Down
61 changes: 61 additions & 0 deletions self_attention_cv/transunet/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn


class SignleConv(nn.Module):
"""
Double convolution block that keeps that spatial sizes the same
"""

def __init__(self, in_ch, out_ch, norm_layer=None):
super(SignleConv, self).__init__()

if norm_layer is None:
norm_layer = nn.BatchNorm2d

self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
norm_layer(out_ch),
nn.ReLU(inplace=True))

def forward(self, x):
return self.conv(x)


class DoubleConv(nn.Module):
"""
Double convolution block that keeps that spatial sizes the same
"""
def __init__(self, in_ch, out_ch, norm_layer=None):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(SignleConv(in_ch, out_ch, norm_layer),
SignleConv(out_ch, out_ch, norm_layer))

def forward(self, x):
return self.conv(x)


class Up(nn.Module):
"""
Doubles spatial size with bilinear upsampling
Skip connections and double convs
"""

def __init__(self, in_ch, out_ch):
super(Up, self).__init__()
mode = "bilinear"
self.up = nn.Upsample(scale_factor=2, mode=mode, align_corners=True)
self.conv = DoubleConv(in_ch, out_ch)

def forward(self, x1, x2=None):
"""
Args:
x1: [b,c, h, w]
x2: [b,c, 2*h,2*w]
Returns: 2x upsampled double conv reselt
"""
x = self.up(x1)
if x2 is not None:
x = torch.cat([x2, x], dim=1)
return self.conv(x)
71 changes: 43 additions & 28 deletions self_attention_cv/transunet/trans_unet.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,66 @@
import torch
import torch.nn as nn
from einops import rearrange

from .bottleneck_layer import Bottleneck
from .decoder import Up, SignleConv
from ..vit import ViT


class TransUnet(nn.Module):
def __init__(self, *, img_dim, in_channels,
def __init__(self, *, img_dim, in_channels, classes,
vit_blocks=1,
vit_heads=4,
vit_dim_linear_mhsa_block=512,
):
super().__init__()
self.inplanes = 64
self.inplanes = 128
resnet_7x7_conv = True

if resnet_7x7_conv:
in_conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
bn1 = nn.BatchNorm2d(self.inplanes)
self.init_conv = nn.Sequential(in_conv1, bn1, nn.ReLU(inplace=True))
#self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
else:
self.init_conv = Bottleneck(in_channels, self.inplanes, stride=2)

in_conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
bn1 = nn.BatchNorm2d(self.inplanes)
self.init_conv = nn.Sequential(in_conv1, bn1, nn.ReLU(inplace=True))
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

self.conv1 = Bottleneck(self.inplanes, 128)
self.conv2 = Bottleneck(128, 256, stride=2)
self.conv3 = Bottleneck(256, 512, stride=2)
self.conv1 = Bottleneck(self.inplanes, self.inplanes*2,stride=2)
self.conv2 = Bottleneck(self.inplanes*2, self.inplanes*4, stride=2)
self.conv3 = Bottleneck(self.inplanes*4, self.inplanes*8,stride=2)

self.img_dim = img_dim//16
self.vit = ViT(img_dim=self.img_dim,
in_channels=512, # based on resnet channels
vit_channels = self.inplanes*8

self.img_dim_vit = img_dim // 16
self.vit = ViT(img_dim=self.img_dim_vit,
in_channels=vit_channels, # based on resnet channels
patch_dim=1,
dim=512, # out channels for decoding
dim=vit_channels, # vit out channels for decoding
blocks=vit_blocks,
heads=vit_heads,
dim_linear_block=vit_dim_linear_mhsa_block,
classification=False)

def forward(self, x):
# ResNet 50 encoder
x1 = self.init_conv(x)
x2 = self.pool(x1)
x2 = self.conv1(x2)
x3 = self.conv2(x2)
x4 = self.conv3(x3)

# Vision Transformer ViT
x6 = self.vit(x4)
x7 = rearrange(x6, ' b (x y) dim -> b dim x y ', x=self.img_dim, y=self.img_dim)
self.vit_conv = SignleConv(in_ch=vit_channels, out_ch=512)

# Decoder
self.dec1 = Up(1024, 256)
self.dec2 = Up(512, 128)
self.dec3 = Up(256, 64)
self.dec4 = Up(64, 16)
self.conv1x1 = nn.Conv2d(16,classes,kernel_size=1)

return x7
def forward(self, x):
# ResNet 50 encoder
x2 = self.init_conv(x) # 128,64,64
x4 = self.conv1(x2) # 256,32,32
x8 = self.conv2(x4) # 512,16,16
x16 = self.conv3(x8) # 1024,8,8
y = self.vit(x16)
y = rearrange(y, 'b (x y) dim -> b dim x y ', x=self.img_dim_vit, y=self.img_dim_vit)
y = self.vit_conv(y)
y = self.dec1(y, x8) # 256,16,16
y = self.dec2(y, x4)
y = self.dec3(y, x2)
y = self.dec4(y)
return self.conv1x1(y)
5 changes: 2 additions & 3 deletions self_attention_cv/vit/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ def __init__(self, *,
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

self.mlp_head = nn.Linear(dim, num_classes)

if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
Expand Down Expand Up @@ -93,4 +92,4 @@ def forward(self, img, mask=None):
# we index only the cls token for classification. nlp tricks :P
return self.mlp_head(y[:, 0, :])
else:
return y
return y

0 comments on commit 31df688

Please sign in to comment.