forked from facebookresearch/xformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compositional.py
341 lines (287 loc) · 12.7 KB
/
compositional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Credits: this is heavily inspired by the official implementation, present in
# https://github.com/sarthmit/Compositional-Attention
# Original author: Sarthak Mittal
# This is a simplified version, for the sake of clarity, and because some features could be exposed later
# via the library directly.
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
# We're also following the same dimension ordering as in the rest of the xformers library
# which is to say [Batch, Sequence, Embedding] wherever possible
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.input_projection import InputProjection, InputProjectionConfig
def _either_or(a: Optional[int], b: int) -> int:
return a if a is not None else b
@dataclass
class CompositionalAttentionConfig(AttentionConfig):
    """Configuration for the attention registered as "compositional".

    Every field name matches a keyword argument of ``CompositionalAttention.__init__``.
    Optional sizes left to ``None`` are filled in by that constructor.
    """

    dim_model: int  # width of the incoming [Batch, Sequence, Embedding] tensors
    num_heads: int  # number of heads for the search step
    dim_attn: Optional[int] = None  # defaults to dim_model; split evenly across the rules
    num_rules: Optional[int] = None  # defaults to num_heads
    dim_key: Optional[int] = None  # defaults to dim_model
    dim_value: Optional[int] = None  # defaults to dim_model
    dim_selection: Optional[int] = None  # defaults to dim_model // num_heads
    # NOTE(review): `dropout` has no default yet follows defaulted fields. A dataclass
    # keeps a redeclared inherited field at its base-class position, so this is only
    # legal if AttentionConfig already declares `dropout` ahead of any defaulted
    # field -- presumably it does; confirm before moving or removing this line.
    dropout: float
    qk_rule: bool = False  # score retrievals with a query/key dot product
    nonlinear: bool = False  # MLP scorer instead of a single linear layer
    q_compose: bool = False  # derive retrieval queries from the projected queries
    bias: bool = True
    causal: Optional[bool] = False
    in_proj_container: Optional[InputProjection] = None  # pre-built input projection
    use_separate_proj_weight: Optional[bool] = False
@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
    """Compositional Attention, as proposed in
    "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.

    A key insight from this proposal is that the attention mechanism can be conceived as two steps:
    a search and a retrieval operation. When queried, the model can search for the most relevant information
    (Softmax(QKt)), then retrieve information given the Value.

    Contrary to the original attention proposal, which does not consider interactions in between heads,
    the compositional attention will consider all possible interactions and softmax over that dimension,
    so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
    use is thus typically smaller than for a comparable traditional Transformer, and asking for the same
    number of heads may not fit in memory.

    Args:
        dim_model: dimension of the incoming latent space
        num_heads: number of heads *for the search operation*
        dim_attn: total width of the retrieved values, split evenly across the rules
            (must be divisible by ``num_rules``). Defaults to ``dim_model``
        num_rules: number of rules to consider *for the retrieval operation*.
            Defaults to ``num_heads``
        dim_selection: dimension of the scoring/selection space for the retrievals.
            Defaults to ``dim_model // num_heads``
        dim_key: output width of the query and key input projections. Defaults to ``dim_model``
        dim_value: output width of the value input projection, used when
            ``use_separate_proj_weight`` is set. Defaults to ``dim_model``
        dropout: attention dropout probability
        qk_rule: score each retrieval with a dot product between a query-derived
            selection vector and a projection of the retrieved value
        nonlinear: score the retrievals with a two-layer MLP instead of a single linear
            layer (only used when ``qk_rule`` is False)
        q_compose: derive the retrieval queries from the projected per-head queries
            instead of from the raw input
        in_proj_container: optional pre-built input projection; one is created when ``None``
        use_separate_proj_weight: pass explicit key/value projection configs to the
            default input projection
        bias: use bias in the projection and scoring layers
        causal: causal computations (attend to the past only)

    .. _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_attn: Optional[int] = None,
        num_rules: Optional[int] = None,
        dim_selection: Optional[int] = None,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        dropout=0.0,
        qk_rule=False,
        nonlinear=False,
        q_compose=False,
        in_proj_container: Optional[InputProjection] = None,
        use_separate_proj_weight: Optional[bool] = False,
        bias=True,
        causal=False,
        *_,
        **__,
    ):
        super().__init__()

        # Define the inherited flags
        self.requires_skip_multi_head = (
            True  # This attention owns the multi-head mechanism
        )

        # Handle defaults / undefined values
        self.dim_model = dim_model
        num_rules = _either_or(num_rules, num_heads)
        dim_selection = _either_or(dim_selection, dim_model // num_heads)

        # All the initial definition plumbing
        dim_attn = _either_or(dim_attn, dim_model)
        dim_key = _either_or(dim_key, dim_model)
        dim_value = _either_or(dim_value, dim_model)

        self.in_proj_container = (
            in_proj_container
            if in_proj_container is not None
            else InputProjection(
                query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
                key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
                if use_separate_proj_weight
                else None,
                value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
                if use_separate_proj_weight
                else None,
            )
        )

        self.num_heads = num_heads
        self.num_rules = num_rules
        self.qk_rule = qk_rule
        self.dim_selection = dim_selection
        self.nonlinear = nonlinear
        self.q_compose = q_compose

        self.dropout_module = nn.Dropout(dropout)
        self.dim_head = dim_model // num_heads  # per-head width for the search step
        self.value_dim = dim_attn // num_rules  # per-rule width for the retrieval step
        assert (
            self.value_dim * num_rules == dim_attn
        ), "dim_attn must be divisible by num_rules"

        self.scaling = self.dim_head**-0.5
        self.scaling_values = self.dim_selection**-0.5

        # Every head ends up with one value_dim-wide retrieval, hence the input width
        self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

        # NOTE: value_k is created before value_q on purpose: module creation order
        # drives RNG consumption (bias init) and parameter ordering.
        if self.qk_rule:
            self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)

        # Retrieval queries come either from the per-head projected queries
        # (q_compose) or from the raw input, one dim_selection slice per head
        if self.q_compose:
            self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
        else:
            self.value_q = nn.Linear(
                dim_model, self.dim_selection * self.num_heads, bias=bias
            )

        # Scorer over [retrieved value, selection query] pairs (used when not qk_rule)
        if self.nonlinear:
            self.score_network: nn.Module = nn.Sequential(
                nn.Linear(
                    self.dim_selection + self.value_dim,
                    self.dim_selection,
                    bias=bias,
                ),
                nn.ReLU(),
                nn.Linear(self.dim_selection, 1, bias=bias),
            )
        else:
            self.score_network = nn.Linear(
                self.dim_selection + self.value_dim, 1, bias=bias
            )

        self.causal = causal
        # Causal mask cache, (re)built lazily in forward once the lengths are known
        self._causal_mask: Optional[AttentionMask] = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the selection, scoring and output weights; zero the output bias."""
        # NOTE: in_proj_container is already initialized
        if self.qk_rule:
            nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.value_q.weight)
        if self.nonlinear:
            nn.init.xavier_uniform_(self.score_network[0].weight)
            nn.init.xavier_uniform_(self.score_network[2].weight)
        else:
            nn.init.xavier_uniform_(self.score_network.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        att_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """Run the compositional search + retrieval.

        Args:
            q: queries, batch first: [Batch, Sq, dim_model]
            k, v: keys and values, batch first: [Batch, Sk, ...] -- presumably
                dim_model wide, as consumed by the input projection
            att_mask (optional): a boolean tensor (converted with
                ``AttentionMask.from_bool``), any other tensor (used as an additive
                mask), or an ``AttentionMask`` (used as is). Combined with the
                causal mask when the module is causal.

        Returns:
            Tensor of shape [Batch, Sq, dim_model]
        """
        B, Sq, E = q.shape
        _, Sk, _ = k.shape

        assert E == self.dim_model

        # First define projected query/key/values
        # We keep the projected and original tensors in flight,
        # depending on the options the original values could be reused
        q_unprojected = q
        q, k, v = self.in_proj_container(query=q, key=k, value=v)
        # Out-of-place on purpose: never rescale a tensor the projection may share
        q = q * self.scaling

        # (Re)build the causal mask when the context lengths or the device changed.
        # Scores are [., Sq, Sk], so the mask has to be (Sq, Sk) as well.
        if self.causal and (
            self._causal_mask is None
            or self._causal_mask.values.shape[-2:] != (Sq, Sk)
            or self._causal_mask.values.device != q.device
        ):
            self._causal_mask = AttentionMask.make_causal(Sq, Sk, device=q.device)

        # Sanitize the incoming mask, from now on it's an additive AttentionMask
        att_mask_additive: Optional[AttentionMask]
        if isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask_additive = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )
        elif isinstance(att_mask, AttentionMask):
            att_mask_additive = att_mask
        else:
            att_mask_additive = None

        # Optionally add the causal mask (out-of-place: never mutate the caller's mask)
        if self._causal_mask is not None:
            if att_mask_additive is not None:
                att_mask_additive = att_mask_additive + self._causal_mask
            else:
                att_mask_additive = self._causal_mask

        # Flatten the heads or the rules into the batch dimension
        q = (
            q.view(B, Sq, self.num_heads, self.dim_head)
            .movedim(2, 1)
            .flatten(0, 1)
        )  # [B * num_heads, Sq, dim_head]
        k = (
            k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
        )  # [B * num_heads, Sk, dim_head]
        # [B * num_rules, Sk, value_dim], assuming the projected v is dim_attn wide
        v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)

        # Compute the search: Softmax(QKt)
        attn_weights = torch.bmm(q, k.transpose(1, 2))  # [B * num_heads, Sq, Sk]
        if att_mask_additive is not None:
            attn_weights += att_mask_additive.values
        attn_weights = _softmax(attn_weights, causal=self.causal)

        attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
        attn_probs = self.dropout_module(attn_weights)

        # Now compute the information retrieval
        # keep all the heads in flight, we'll score the different possibilities
        # - compute all the possible retrievals: every head paired with every rule
        v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
        attn_probs = attn_probs.unsqueeze(2)  # [B, num_heads, 1, Sq, Sk]
        attn = torch.matmul(attn_probs, v).view(
            B, self.num_heads, self.num_rules, Sq, self.value_dim
        )
        attn = attn.movedim(3, 1)  # [B, Sq, H, Rules, Values]

        # - search the most appropriate retrieval among all the values
        if self.q_compose:
            # q is [B * num_heads, Sq, dim_head]: restore the batch/head split before
            # projecting, then move Sq ahead of the heads to line up with `attn`
            v_q = (
                self.value_q(q.reshape(B, self.num_heads, Sq, self.dim_head))
                .movedim(1, 2)
                .unsqueeze(3)
            )  # [B, Sq, H, 1, dim_selection]
        else:
            v_q = self.value_q(q_unprojected).view(
                B, Sq, self.num_heads, 1, self.dim_selection
            )

        if self.qk_rule:
            v_q = v_q * self.scaling_values
            v_k = (
                self.value_k(attn)
                .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
                .transpose(4, 3)
                .contiguous()
            )  # [B, Sq, H, dim_selection, Rules]
            v_score = torch.matmul(v_q, v_k).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )
        else:
            v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
            v_in = torch.cat([attn, v_q], dim=-1)
            v_score = self.score_network(v_in).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )

        # Softmax over the rules: each head picks a (soft) mix of retrievals
        v_score = F.softmax(v_score, dim=3)

        # - extracted values are the original attention (inc. all the values) weighted by value score
        attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)

        # Final attention projection, same as other mechanisms
        attn = self.out_proj(attn)

        return attn