add tests on cuda if available, UNETR paper, absolute_positional_encodings_1d, update readme for running pytest
black0017 committed Jun 30, 2021
1 parent 376f9f9 commit 0b9d3e4
Showing 21 changed files with 326 additions and 38 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ dist/
self_attention_cv.egg-info/
.idea/modules.xml
.idea/self_attention.iml
venv
3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,8 @@ Focused on computer vision self-attention modules.

```$ pip install self-attention-cv```

It would be nice to pre-install pytorch in your environment, in case you don't have a GPU.
It would be nice to pre-install pytorch in your environment, in case you don't have a GPU. To run the tests from the terminal with
```$ pytest```, you may need to run ``` export PYTHONPATH=$PYTHONPATH:`pwd` ``` first.


## Related articles
4 changes: 2 additions & 2 deletions examples/pos_emb_1d.py
@@ -1,7 +1,7 @@
import torch

from self_attention_cv.pos_embeddings import AbsPosEmb1D, RelPosEmb1D
from self_attention_cv.pos_embeddings import PositionalEncoding1D
from self_attention_cv.pos_embeddings import PositionalEncodingSin

model = AbsPosEmb1D(tokens=20, dim_head=64)
# batch heads tokens dim_head
@@ -16,7 +16,7 @@
print('abs and pos emb ok')

a = torch.rand(3, 64, 128)
pos_enc = PositionalEncoding1D(dim=128, max_tokens=64)
pos_enc = PositionalEncodingSin(dim=128, max_tokens=64)
b = pos_enc(a)
assert a.shape == b.shape
print('sinusoidal pos enc 1D ok')
136 changes: 136 additions & 0 deletions self_attention_cv/UnetTr/UnetTr.py
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv.UnetTr.modules import TranspConv3DBlock, BlueBlock, Conv3DBlock
from self_attention_cv.UnetTr.volume_embedding import Embeddings3D
from self_attention_cv.transformer_vanilla import TransformerBlock


class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
        super().__init__()
        self.extract_layers = extract_layers

        # nn.ModuleList registers the TransformerBlocks so they follow the parent module's device
        self.block_list = nn.ModuleList()
        for _ in range(num_layers):
            self.block_list.append(TransformerBlock(dim=embed_dim, heads=num_heads,
                                                    dim_linear_block=1024, dropout=dropout, prenorm=True))

    def forward(self, x):
        extracted = []
        for depth, layer_block in enumerate(self.block_list):
            x = layer_block(x)
            if (depth + 1) in self.extract_layers:
                extracted.append(x)

        return extracted

# based on https://arxiv.org/abs/2103.10504
class UNETR(nn.Module):
    def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.1,
                 num_layers=12, ext_layers=[3, 6, 9, 12], version='light'):
        """
        Args:
            img_shape: volume shape, provided as a tuple
            input_dim: number of input modalities/channels
            output_dim: number of output classes
            embed_dim: transformer embedding dim
            patch_size: side length of the non-overlapping cubic patches
            num_heads: number of attention heads for the transformer encoder
            dropout: dropout probability
            num_layers: number of transformer layers; effectively fixed by the current architecture
            ext_layers: indices of the transformer layers whose outputs are passed to the decoder
            version: 'light' saves some parameters in the decoding part
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = num_layers
        self.ext_layers = ext_layers
        self.patch_dim = [int(x / patch_size) for x in img_shape]
        self.base_filters = 64
        self.prelast_filters = 32

        # cheap way to reduce the number of parameters in the decoding part
        self.yellow_conv_channels = [256, 128, 64] if version == 'light' else [512, 256, 128]

        self.embed = Embeddings3D(input_dim=input_dim, embed_dim=embed_dim,
                                  cube_size=img_shape, patch_size=patch_size, dropout=dropout)

        self.transformer = TransformerEncoder(embed_dim, num_heads, num_layers, dropout, ext_layers)

        self.init_conv = Conv3DBlock(input_dim, self.base_filters, double=True)

        # blue blocks in Fig.1
        self.z3_blue_conv = nn.Sequential(
            BlueBlock(in_planes=embed_dim, out_planes=512),
            BlueBlock(in_planes=512, out_planes=256),
            BlueBlock(in_planes=256, out_planes=128))

        self.z6_blue_conv = nn.Sequential(
            BlueBlock(in_planes=embed_dim, out_planes=512),
            BlueBlock(in_planes=512, out_planes=256))

        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=512)

        # Green blocks in Fig.1
        self.z12_deconv = TranspConv3DBlock(embed_dim, 512)

        self.z9_deconv = TranspConv3DBlock(self.yellow_conv_channels[0], 256)
        self.z6_deconv = TranspConv3DBlock(self.yellow_conv_channels[1], 128)
        self.z3_deconv = TranspConv3DBlock(self.yellow_conv_channels[2], 64)

        # Yellow blocks in Fig.1
        self.z9_conv = Conv3DBlock(1024, self.yellow_conv_channels[0], double=True)
        self.z6_conv = Conv3DBlock(512, self.yellow_conv_channels[1], double=True)
        self.z3_conv = Conv3DBlock(256, self.yellow_conv_channels[2], double=True)
        # out convolutions
        self.out_conv = nn.Sequential(
            # last yellow conv block
            Conv3DBlock(128, self.prelast_filters, double=True),
            # grey block, final classification layer
            Conv3DBlock(self.prelast_filters, output_dim, kernel_size=1, double=False))

    def forward(self, x):
        transf_input = self.embed(x)
        z3, z6, z9, z12 = map(lambda t: rearrange(t, 'b (x y z) d -> b d x y z',
                                                  x=self.patch_dim[0], y=self.patch_dim[1], z=self.patch_dim[2]),
                              self.transformer(transf_input))

        # Blue convs
        z0 = self.init_conv(x)
        z3 = self.z3_blue_conv(z3)
        z6 = self.z6_blue_conv(z6)
        z9 = self.z9_blue_conv(z9)

        # Green block for z12
        z12 = self.z12_deconv(z12)
        # Concat + yellow conv
        y = torch.cat([z12, z9], dim=1)
        y = self.z9_conv(y)

        # Green block for z6
        y = self.z9_deconv(y)
        # Concat + yellow conv
        y = torch.cat([y, z6], dim=1)
        y = self.z6_conv(y)

        # Green block for z3
        y = self.z6_deconv(y)
        # Concat + yellow conv
        y = torch.cat([y, z3], dim=1)
        y = self.z3_conv(y)

        y = self.z3_deconv(y)
        y = torch.cat([y, z0], dim=1)
        return self.out_conv(y)
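The decoder mirrors Fig.1 of the paper: each extracted transformer output (z3, z6, z9, z12) is reshaped to a 3D feature map, upsampled by the blue/green blocks, and fused with skip connections through the yellow convolutions. A minimal sketch of exercising the new module, on CUDA when available; the 64^3 volume and batch size below are illustrative choices, not values taken from the repo's tests:

```python
import torch
from self_attention_cv import UNETR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# a 64^3 volume keeps the example light; the default img_shape is 128^3
model = UNETR(img_shape=(64, 64, 64), input_dim=4, output_dim=3, version='light').to(device)

x = torch.rand(1, 4, 64, 64, 64, device=device)  # (batch, modalities, D, H, W)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 3, 64, 64, 64]) - per-voxel class logits
```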
1 change: 1 addition & 0 deletions self_attention_cv/UnetTr/__init__.py
@@ -0,0 +1 @@
from .UnetTr import UNETR
45 changes: 45 additions & 0 deletions self_attention_cv/UnetTr/modules.py
@@ -0,0 +1,45 @@
import torch.nn as nn


# yellow block in Fig.1
class Conv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3, double=True):
        super().__init__()
        if double:
            self.conv_block = nn.Sequential(
                nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True)
            )
        else:
            self.conv_block = nn.Sequential(
                nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv_block(x)


# green block in Fig.1
class TranspConv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


# blue box in Fig.1
class BlueBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.Sequential(TranspConv3DBlock(in_planes, out_planes),
                                   Conv3DBlock(out_planes, out_planes, double=False))

    def forward(self, x):
        return self.block(x)
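A quick shape check of the three building blocks, assuming the package is importable from the repo root; the tensor sizes are illustrative:

```python
import torch
from self_attention_cv.UnetTr.modules import BlueBlock, Conv3DBlock, TranspConv3DBlock

feat = torch.rand(1, 768, 4, 4, 4)  # transformer tokens reshaped to a coarse 3D grid

print(Conv3DBlock(768, 64)(feat).shape)         # torch.Size([1, 64, 4, 4, 4])  - spatial size preserved
print(TranspConv3DBlock(768, 512)(feat).shape)  # torch.Size([1, 512, 8, 8, 8]) - x2 upsampling
print(BlueBlock(768, 512)(feat).shape)          # torch.Size([1, 512, 8, 8, 8]) - upsample + single conv
```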
24 changes: 24 additions & 0 deletions self_attention_cv/UnetTr/volume_embedding.py
@@ -0,0 +1,24 @@
import torch.nn as nn
from einops import rearrange

from self_attention_cv.pos_embeddings import AbsPositionalEncoding1D


class Embeddings3D(nn.Module):
    def __init__(self, input_dim, embed_dim, cube_size, patch_size=16, dropout=0.1):
        super().__init__()
        self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,
                                          kernel_size=patch_size, stride=patch_size, bias=False)
        self.position_embeddings = AbsPositionalEncoding1D(self.n_patches, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x is a 5D tensor of shape (batch, channels, depth, height, width)
        """
        x = rearrange(self.patch_embeddings(x), 'b d x y z -> b (x y z) d')
        embeddings = self.dropout(self.position_embeddings(x))
        return embeddings
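The embedding maps a volume of shape (batch, channels, D, H, W) to (batch, (D/p)*(H/p)*(W/p), embed_dim) tokens. A small sanity check, with sizes chosen here purely for illustration:

```python
import torch
from self_attention_cv.UnetTr.volume_embedding import Embeddings3D

emb = Embeddings3D(input_dim=4, embed_dim=768, cube_size=(64, 64, 64), patch_size=16)
vol = torch.rand(2, 4, 64, 64, 64)
tokens = emb(vol)
print(tokens.shape)  # torch.Size([2, 64, 768]) -> (batch, (64/16)^3 patches, embed_dim)
```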
4 changes: 3 additions & 1 deletion self_attention_cv/__init__.py
@@ -1,7 +1,9 @@
from self_attention_cv.pos_embeddings import *
from .pos_embeddings import *
from .MSA_transformer import MSATransformerBlock, MSATransformerEncoder, TiedRowAxialAttention
from .axial_attention_deeplab import *
from .linformer import LinformerAttention, LinformerBlock, LinformerEncoder
from .timesformer import Timesformer, SpacetimeMHSA
from .transformer_vanilla import *
from .vit import ViT, ResNet50ViT
from .transunet import TransUnet
from .UnetTr import UNETR
18 changes: 18 additions & 0 deletions self_attention_cv/common.py
@@ -1,9 +1,11 @@
import os
import random
from typing import List, Tuple

import numpy as np
import torch
from einops import repeat
from torch import Tensor, nn


def expand_to_batch(tensor, desired_size):
@@ -20,3 +22,19 @@ def init_random_seed(seed, gpu=False):
    os.environ['PYTHONHASHSEED'] = str(seed)
    if gpu:
        torch.backends.cudnn.deterministic = True


# from https://huggingface.co/transformers/_modules/transformers/modeling_utils.html
def get_module_device(parameter: nn.Module):
    try:
        return next(parameter.parameters()).device
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device
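A brief usage sketch of the new helper, in line with the commit's "cuda if available" testing; the Linear layer is just a stand-in module:

```python
import torch
import torch.nn as nn
from self_attention_cv.common import get_module_device

layer = nn.Linear(8, 8)
if torch.cuda.is_available():
    layer = layer.cuda()
print(get_module_device(layer))  # cpu, or cuda:0 when a GPU is available
```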
3 changes: 2 additions & 1 deletion self_attention_cv/pos_embeddings/__init__.py
@@ -1,5 +1,6 @@
from .abs_pos_emb1D import AbsPosEmb1D
from .abs_pos_encoding_1d import AbsPositionalEncoding1D
from .relative_embeddings_1D import rel_pos_emb_1d, RelPosEmb1D
from .relative_embeddings_2D import RelPosEmb2D
from .relative_pos_enc_qkv import Relative2DPosEncQKV
from .pos_encoding_sin import PositionalEncoding1D
from .pos_encoding_sin import PositionalEncodingSin
13 changes: 13 additions & 0 deletions self_attention_cv/pos_embeddings/abs_pos_encoding_1d.py
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn
from ..common import expand_to_batch


class AbsPositionalEncoding1D(nn.Module):
    def __init__(self, tokens, dim):
        super(AbsPositionalEncoding1D, self).__init__()
        self.abs_pos_enc = nn.Parameter(torch.randn(1, tokens, dim))

    def forward(self, x):
        batch = x.size()[0]
        return x + expand_to_batch(self.abs_pos_enc, desired_size=batch)
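Unlike the sinusoidal variant, this encoding is a learnable (1, tokens, dim) parameter that is broadcast over the batch and added to the input. A minimal check, with illustrative shapes:

```python
import torch
from self_attention_cv.pos_embeddings import AbsPositionalEncoding1D

pos_enc = AbsPositionalEncoding1D(tokens=64, dim=768)
x = torch.rand(2, 64, 768)  # (batch, tokens, dim)
assert pos_enc(x).shape == x.shape
```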
5 changes: 2 additions & 3 deletions self_attention_cv/pos_embeddings/pos_encoding_sin.py
@@ -3,12 +3,11 @@
from ..common import expand_to_batch



# adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding1D(nn.Module):
class PositionalEncodingSin(nn.Module):

    def __init__(self, dim, dropout=0.1, max_tokens=5000):
        super(PositionalEncoding1D, self).__init__()
        super(PositionalEncodingSin, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(1, max_tokens, dim)
20 changes: 14 additions & 6 deletions self_attention_cv/transformer_vanilla/transformer_block.py
@@ -1,7 +1,7 @@
from torch import nn

from .mhsa import MultiHeadSelfAttention

from ..common import get_module_device

class TransformerBlock(nn.Module):
"""
Expand All @@ -11,7 +11,7 @@ class TransformerBlock(nn.Module):

    def __init__(self, dim, heads=8, dim_head=None,
                 dim_linear_block=1024, dropout=0.1, activation=nn.GELU,
                 mhsa=None):
                 mhsa=None, prenorm=False):
        """
        Args:
            dim: token's vector length
@@ -20,9 +20,11 @@ def __init__(self, dim, heads=8, dim_head=None,
            dim_linear_block: the inner projection dim
            dropout: probability of dropping values
            mhsa: if provided, replaces the vanilla self-attention block
            prenorm: if True, layer norm is applied before the mhsa and the linear block rather than after
        """
        super().__init__()
        self.mhsa = mhsa if mhsa is not None else MultiHeadSelfAttention(dim=dim, heads=heads, dim_head=dim_head)
        self.prenorm = prenorm
        self.drop = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(dim)
        self.norm_2 = nn.LayerNorm(dim)
@@ -36,14 +38,20 @@ def __init__(self, dim, heads=8, dim_head=None,
        )

    def forward(self, x, mask=None):
        y = self.norm_1(self.drop(self.mhsa(x, mask)) + x)
        return self.norm_2(self.linear(y) + y)
        if self.prenorm:
            y = self.drop(self.mhsa(self.norm_1(x), mask)) + x
            out = self.linear(self.norm_2(y)) + y
        else:
            y = self.norm_1(self.drop(self.mhsa(x, mask)) + x)
            out = self.norm_2(self.linear(y) + y)
        return out


class TransformerEncoder(nn.Module):
    def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, dropout=0):
    def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, dropout=0, prenorm=False):
        super().__init__()
        self.block_list = [TransformerBlock(dim, heads, dim_head, dim_linear_block, dropout) for _ in range(blocks)]
        self.block_list = [TransformerBlock(dim, heads, dim_head,
                                            dim_linear_block, dropout, prenorm=prenorm) for _ in range(blocks)]
        self.layers = nn.ModuleList(self.block_list)

    def forward(self, x, mask=None):
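A minimal sketch of the new prenorm switch, which UNETR's encoder turns on; both variants keep the input shape, and the tensor sizes below are illustrative:

```python
import torch
from self_attention_cv.transformer_vanilla import TransformerBlock

x = torch.rand(2, 16, 128)                          # (batch, tokens, dim)
post_norm = TransformerBlock(dim=128)               # default: norm after the residual sum
pre_norm = TransformerBlock(dim=128, prenorm=True)  # norm before mhsa/linear, residual added after
assert post_norm(x).shape == pre_norm(x).shape == x.shape
```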
2 changes: 1 addition & 1 deletion self_attention_cv/version.py
@@ -1,6 +1,6 @@
import sys

__version__ = "1.1.0"
__version__ = "1.2.0"

msg = "Self_attention_cv is only compatible with Python 3.0 and newer."
