Commit
add tests on cuda if available, UNETR paper, absolute_positional_encodings_1d, update readme for running pytests
Showing 21 changed files with 326 additions and 38 deletions.
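The commit message mentions tests that run on CUDA when it is available and a README update for running pytest. A minimal sketch of such a device-aware check is given here for illustration against the UNETR module added below; the test name, input shape, and batch size are assumptions, not the repository's actual test code:

# hypothetical device-aware test sketch (not the repository's test file); run with pytest
import torch
from self_attention_cv import UNETR


def test_unetr_forward():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UNETR(img_shape=(64, 64, 64), input_dim=4, output_dim=3).to(device)
    volume = torch.rand(1, 4, 64, 64, 64, device=device)  # (batch, modalities, D, H, W)
    out = model(volume)
    assert out.shape == (1, 3, 64, 64, 64)  # per-voxel class scores at input resolution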
.gitignore
@@ -11,3 +11,4 @@ dist/
self_attention_cv.egg-info/
.idea/modules.xml
.idea/self_attention.iml
venv
self_attention_cv/UnetTr/UnetTr.py
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv.UnetTr.modules import TranspConv3DBlock, BlueBlock, Conv3DBlock
from self_attention_cv.UnetTr.volume_embedding import Embeddings3D
from self_attention_cv.transformer_vanilla import TransformerBlock


class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
        super().__init__()
        self.layer = nn.ModuleList()
        self.extract_layers = extract_layers

        # registering the blocks in an nn.ModuleList keeps the TransformerBlocks device-aware
        self.block_list = nn.ModuleList()
        for _ in range(num_layers):
            self.block_list.append(TransformerBlock(dim=embed_dim, heads=num_heads,
                                                    dim_linear_block=1024, dropout=dropout, prenorm=True))

    def forward(self, x):
        extract_layers = []
        for depth, layer_block in enumerate(self.block_list):
            x = layer_block(x)
            if (depth + 1) in self.extract_layers:
                extract_layers.append(x)

        return extract_layers


# based on https://arxiv.org/abs/2103.10504
class UNETR(nn.Module):
    def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.1,
                 num_layers=12, ext_layers=[3, 6, 9, 12], version='light'):
        """
        Args:
            img_shape: volume shape, provided as a tuple
            input_dim: number of input modalities/channels
            output_dim: number of output classes
            embed_dim: transformer embedding dimension
            patch_size: size of the non-overlapping cubic patches
            num_heads: number of heads in the transformer encoder
            dropout: dropout probability
            num_layers: number of transformer layers (fixed for the current architecture)
            ext_layers: transformer layers whose outputs are extracted for the decoder
            version: 'light' saves some parameters in the decoding part
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = num_layers
        self.ext_layers = ext_layers
        self.patch_dim = [int(x / patch_size) for x in img_shape]
        self.base_filters = 64
        self.prelast_filters = 32

        # cheap way to reduce the number of parameters in the decoding part
        self.yellow_conv_channels = [256, 128, 64] if version == 'light' else [512, 256, 128]

        self.embed = Embeddings3D(input_dim=input_dim, embed_dim=embed_dim,
                                  cube_size=img_shape, patch_size=patch_size, dropout=dropout)

        self.transformer = TransformerEncoder(embed_dim, num_heads, num_layers, dropout, ext_layers)

        self.init_conv = Conv3DBlock(input_dim, self.base_filters, double=True)

        # blue blocks in Fig.1
        self.z3_blue_conv = nn.Sequential(
            BlueBlock(in_planes=embed_dim, out_planes=512),
            BlueBlock(in_planes=512, out_planes=256),
            BlueBlock(in_planes=256, out_planes=128))

        self.z6_blue_conv = nn.Sequential(
            BlueBlock(in_planes=embed_dim, out_planes=512),
            BlueBlock(in_planes=512, out_planes=256))

        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=512)

        # green blocks in Fig.1
        self.z12_deconv = TranspConv3DBlock(embed_dim, 512)

        self.z9_deconv = TranspConv3DBlock(self.yellow_conv_channels[0], 256)
        self.z6_deconv = TranspConv3DBlock(self.yellow_conv_channels[1], 128)
        self.z3_deconv = TranspConv3DBlock(self.yellow_conv_channels[2], 64)

        # yellow blocks in Fig.1
        self.z9_conv = Conv3DBlock(1024, self.yellow_conv_channels[0], double=True)
        self.z6_conv = Conv3DBlock(512, self.yellow_conv_channels[1], double=True)
        self.z3_conv = Conv3DBlock(256, self.yellow_conv_channels[2], double=True)
        # output convolutions
        self.out_conv = nn.Sequential(
            # last yellow conv block
            Conv3DBlock(128, self.prelast_filters, double=True),
            # grey block, final classification layer
            Conv3DBlock(self.prelast_filters, output_dim, kernel_size=1, double=False))

    def forward(self, x):
        transf_input = self.embed(x)
        z3, z6, z9, z12 = map(lambda t: rearrange(t, 'b (x y z) d -> b d x y z',
                                                  x=self.patch_dim[0], y=self.patch_dim[1], z=self.patch_dim[2]),
                              self.transformer(transf_input))

        # blue convs
        z0 = self.init_conv(x)
        z3 = self.z3_blue_conv(z3)
        z6 = self.z6_blue_conv(z6)
        z9 = self.z9_blue_conv(z9)

        # green block for z12
        z12 = self.z12_deconv(z12)
        # concat + yellow conv
        y = torch.cat([z12, z9], dim=1)
        y = self.z9_conv(y)

        # green block towards the z6 scale
        y = self.z9_deconv(y)
        # concat + yellow conv
        y = torch.cat([y, z6], dim=1)
        y = self.z6_conv(y)

        # green block towards the z3 scale
        y = self.z6_deconv(y)
        # concat + yellow conv
        y = torch.cat([y, z3], dim=1)
        y = self.z3_conv(y)

        y = self.z3_deconv(y)
        y = torch.cat([y, z0], dim=1)
        return self.out_conv(y)
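For reference, a short usage sketch of the class above, using the export from the subpackage __init__ shown just below; the 64-cube input shape and batch size are illustrative choices (the constructor defaults to 128-cube volumes), not part of the commit:

# illustrative usage of the UNETR class defined above
import torch
from self_attention_cv.UnetTr import UNETR

model = UNETR(img_shape=(64, 64, 64), input_dim=4, output_dim=3,
              embed_dim=768, patch_size=16, num_heads=12, version='light')
volume = torch.rand(1, 4, 64, 64, 64)  # (batch, modalities, D, H, W)
# internally: 64^3 / 16^3 = 64 patch tokens; layers 3, 6, 9, 12 feed the decoder
logits = model(volume)                 # (1, 3, 64, 64, 64): one score per class and voxel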
self_attention_cv/UnetTr/__init__.py
@@ -0,0 +1 @@
from .UnetTr import UNETR
self_attention_cv/UnetTr/modules.py
@@ -0,0 +1,45 @@
import torch.nn as nn


# yellow block in Fig.1
class Conv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3, double=True):
        super().__init__()
        if double:
            self.conv_block = nn.Sequential(
                nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True)
            )
        else:
            self.conv_block = nn.Sequential(
                nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv_block(x)


# green block in Fig.1
class TranspConv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2,
                                        padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


# blue box in Fig.1
class BlueBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.Sequential(TranspConv3DBlock(in_planes, out_planes),
                                   Conv3DBlock(out_planes, out_planes, double=False))

    def forward(self, x):
        return self.block(x)
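As a rough shape check (not from the commit): each BlueBlock upsamples by a factor of two through its stride-2 transposed convolution before the conv-BN-ReLU, which is why the z3 skip path in UnetTr.py above stacks three of them while z9 needs only one. The feature sizes below are illustrative:

# illustrative check: a single BlueBlock doubles each spatial dimension
import torch
from self_attention_cv.UnetTr.modules import BlueBlock

block = BlueBlock(in_planes=768, out_planes=512)
feat = torch.rand(1, 768, 4, 4, 4)  # e.g. a reshaped transformer output
print(block(feat).shape)            # torch.Size([1, 512, 8, 8, 8])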
self_attention_cv/UnetTr/volume_embedding.py
@@ -0,0 +1,24 @@
import torch.nn as nn
from einops import rearrange

from self_attention_cv.pos_embeddings import AbsPositionalEncoding1D


class Embeddings3D(nn.Module):
    def __init__(self, input_dim, embed_dim, cube_size, patch_size=16, dropout=0.1):
        super().__init__()
        self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,
                                          kernel_size=patch_size, stride=patch_size, bias=False)
        self.position_embeddings = AbsPositionalEncoding1D(self.n_patches, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x is a 5D tensor of shape (batch, channels, D, H, W)
        """
        x = rearrange(self.patch_embeddings(x), 'b d x y z -> b (x y z) d')
        embeddings = self.dropout(self.position_embeddings(x))
        return embeddings
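A brief shape sketch for the embedding module above; the numbers are chosen for illustration and are not taken from the commit:

# illustrative: a 64^3 volume split into 16^3 patches yields 4^3 = 64 tokens of size embed_dim
import torch
from self_attention_cv.UnetTr.volume_embedding import Embeddings3D

embed = Embeddings3D(input_dim=4, embed_dim=768, cube_size=(64, 64, 64), patch_size=16)
volume = torch.rand(2, 4, 64, 64, 64)  # (batch, channels, D, H, W)
tokens = embed(volume)                 # shape (2, 64, 768) -> (batch, n_patches, embed_dim)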
self_attention_cv/__init__.py
@@ -1,7 +1,9 @@
from self_attention_cv.pos_embeddings import *
from .pos_embeddings import *
from .MSA_transformer import MSATransformerBlock, MSATransformerEncoder, TiedRowAxialAttention
from .axial_attention_deeplab import *
from .linformer import LinformerAttention, LinformerBlock, LinformerEncoder
from .timesformer import Timesformer, SpacetimeMHSA
from .transformer_vanilla import *
from .vit import ViT, ResNet50ViT
from .transunet import TransUnet
from .UnetTr import UNETR
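The last line above exposes the new model at the package root, alongside ViT, TransUnet, and the other top-level exports, so it can be imported directly:

# enabled by the export added above
from self_attention_cv import UNETR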
self_attention_cv/pos_embeddings/__init__.py
@@ -1,5 +1,6 @@
from .abs_pos_emb1D import AbsPosEmb1D
from .abs_pos_encoding_1d import AbsPositionalEncoding1D
from .relative_embeddings_1D import rel_pos_emb_1d, RelPosEmb1D
from .relative_embeddings_2D import RelPosEmb2D
from .relative_pos_enc_qkv import Relative2DPosEncQKV
from .pos_encoding_sin import PositionalEncoding1D
from .pos_encoding_sin import PositionalEncodingSin
self_attention_cv/pos_embeddings/abs_pos_encoding_1d.py
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn

from ..common import expand_to_batch


class AbsPositionalEncoding1D(nn.Module):
    def __init__(self, tokens, dim):
        super(AbsPositionalEncoding1D, self).__init__()
        self.abs_pos_enc = nn.Parameter(torch.randn(1, tokens, dim))

    def forward(self, x):
        batch = x.size()[0]
        return x + expand_to_batch(self.abs_pos_enc, desired_size=batch)
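A minimal sketch of the new learnable positional encoding used in isolation; the token count and dimension are illustrative, not taken from the commit:

# illustrative: a learnable (1, tokens, dim) parameter is expanded over the batch and added
import torch
from self_attention_cv.pos_embeddings import AbsPositionalEncoding1D

pos_enc = AbsPositionalEncoding1D(tokens=64, dim=768)
tokens = torch.rand(2, 64, 768)  # (batch, tokens, dim)
out = pos_enc(tokens)            # same shape, with positional parameters added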