diff --git a/README.md b/README.md index a5a8aef..0751f97 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,39 @@ block = ConformerBlock( ) x = torch.randn(1, 1024, 512) + block(x) # (1, 1024, 512) ``` + +Conformer - just multiple `ConformerBlock` from above + +```python +import torch +from conformer import Conformer + +conformer = Conformer( + dim = 512, + depth = 12, # 12 blocks + dim_head = 64, + heads = 8, + ff_mult = 4, + conv_expansion_factor = 2, + conv_kernel_size = 31, + attn_dropout = 0., + ff_dropout = 0., + conv_dropout = 0. +) + +x = torch.randn(1, 1024, 512) + +conformer(x) # (1, 1024, 512) +``` + +## Todo + +- [ ] switch to a better relative positional encoding. shaw's is dated +- [ ] flash attention with a better RPE + ## Citations ```bibtex diff --git a/conformer/__init__.py b/conformer/__init__.py index 257f34f..1923db6 100644 --- a/conformer/__init__.py +++ b/conformer/__init__.py @@ -1 +1 @@ -from conformer.conformer import ConformerConvModule, ConformerBlock +from conformer.conformer import ConformerConvModule, ConformerBlock, Conformer diff --git a/conformer/conformer.py b/conformer/conformer.py index 9ce3ff2..62aced4 100644 --- a/conformer/conformer.py +++ b/conformer/conformer.py @@ -85,7 +85,13 @@ def __init__( self.dropout = nn.Dropout(dropout) - def forward(self, x, context = None, mask = None, context_mask = None): + def forward( + self, + x, + context = None, + mask = None, + context_mask = None + ): n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context) context = default(context, x) @@ -95,6 +101,7 @@ def forward(self, x, context = None, mask = None, context_mask = None): dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # shaw's relative positional embedding + seq = torch.arange(n, device = device) dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j') dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb @@ -199,3 +206,41 @@ def forward(self, x, mask = None): x = self.ff2(x) + x x = self.post_norm(x) return x + +# Conformer + +class Conformer(nn.Module): + def __init__( + self, + dim, + *, + depth, + dim_head = 64, + heads = 8, + ff_mult = 4, + conv_expansion_factor = 2, + conv_kernel_size = 31, + attn_dropout = 0., + ff_dropout = 0., + conv_dropout = 0. + ): + super().__init__() + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append(ConformerBlock( + dim = dim, + dim_head = dim_head, + heads = heads, + ff_mult = ff_mult, + conv_expansion_factor = conv_expansion_factor, + conv_kernel_size = conv_kernel_size, + + )) + + def forward(self, x): + + for block in self.layers: + x = block(x) + + return x diff --git a/setup.py b/setup.py index c415b38..e388ad7 100644 --- a/setup.py +++ b/setup.py @@ -3,15 +3,20 @@ setup( name = 'conformer', packages = find_packages(), - version = '0.2.5', + version = '0.3.0', license='MIT', description = 'The convolutional module from the Conformer paper', author = 'Phil Wang', author_email = 'lucidrains@gmail.com', url = 'https://github.com/lucidrains/conformer', - keywords = ['transformers', 'artificial intelligence', 'transformer'], + keywords = [ + 'artificial intelligence', + 'deep learning', + 'transformers', + 'audio' + ], install_requires=[ - 'einops', + 'einops>=0.6.1', 'torch' ], classifiers=[