forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FlashAttentionKernel.cpp
630 lines (591 loc) · 23.9 KB
/
FlashAttentionKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
namespace {
// Pointer selector used when the accumulation buffer and the gemm operand
// share one element type (the non-reduced case).
//
// ptr  - buffer in the accumulation type; this is what gets returned.
// ptr2 - reduced-precision staging buffer; it must be null here because no
//        staging copy exists when both pointer types are identical.
template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  // A non-null ptr2 means a staging buffer was allocated for a type that
  // does not need one, which is an internal logic error.
  TORCH_INTERNAL_ASSERT(ptr2 == nullptr);
  return ptr;
}
// Pointer selector for reduced floating-point element types: the float
// accumulation buffer is ignored and the reduced-precision copy is returned.
// Overload resolution only considers this version when scalar_t satisfies
// is_reduced_floating_point_v.
//
// ptr  - float accumulation buffer (not read).
// ptr2 - buffer holding the same values converted to scalar_t; returned.
template <
    typename scalar_t,
    std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) {
  return ptr2;
}
template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
using Vec = Vectorized<scalar_t>;
Vec data_vec = Vec(val);
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
data_vec.store(data + d);
}
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
for (; d < size; d++) {
data[d] = val;
}
}
// Blocked scaled-dot-product attention forward pass on CPU.
//
// The query sequence is cut into blocks of at most q_split_size rows and the
// key/value sequence into blocks of at most kv_split_size rows (both clamped
// to the actual sequence lengths). For every (batch, head, query block) the
// kernel walks the key/value blocks, keeping a running row maximum and a
// running row sum so that the softmax never has to be materialised for the
// full Q_seq_len x KV_seq_len matrix.
//
// Template parameters:
//   scalar_t       - element type of q/k/v/output.
//   q_split_size   - upper bound on the query block length.
//   kv_split_size  - upper bound on the key/value block length.
//
// Parameters that are read:
//   output     - written with the attention result; indexed with stride(0) as
//                batch, stride(1) as query position, stride(2) as head.
//   logsumexp  - written with max + log(sum) per query row (accumulation
//                type); indexed the same way as `output`.
//   q, k, v    - inputs laid out as (Batch x Num_heads x Seq_len x Dim_per_head).
//   is_causal  - when true, query row r only attends to key columns <= r.
//   scale      - optional softmax scale forwarded to sdp::calculate_scale.
//
// cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset,
// debug_attn_mask, dropout_p and return_debug_mask are accepted for
// signature compatibility but are not read by this function.
//
// Empty sequences: an empty query sequence returns immediately (nothing to
// write); an empty key/value sequence yields an all-zero output and -inf
// logsumexp.
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
    const Tensor& output,
    const Tensor& logsumexp,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    int64_t& max_q,
    int64_t& max_k,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& debug_attn_mask,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale) {
  // Query (Batch x Num_heads  x Q_seq_len  x Dim_per_head)
  //    -> (Batch x Q_seq_len  x Num_heads  x Dim_per_head)
  // Key   (Batch x Num_heads  x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
  // Value (Batch x Num_heads  x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
  at::Tensor query = q.transpose(1, 2);
  at::Tensor key = k.transpose(1, 2);
  at::Tensor value = v.transpose(1, 2);

  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
        "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  // An empty query sequence leaves nothing to write. Bailing out here also
  // keeps the qSlice computation below from dividing by qSplitSize == 0.
  if (qSize == 0) {
    return;
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = output.stride(0);
  int64_t oStrideM = output.stride(1);
  int64_t oStrideH = output.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);

  // Effective block sizes never exceed the sequence lengths.
  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  // Number of query blocks (ceil(qSize / qSplitSize)); qSplitSize >= 1 here.
  int64_t qSlice = (qSize - 1) / qSplitSize + 1;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
  // Staging area for the softmax block converted back to scalar_t; it is
  // zero-sized unless scalar_t is a reduced floating-point type.
  at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options());

  // Data ptrs
  scalar_t* q_data = query.data_ptr<scalar_t>();
  scalar_t* k_data = key.data_ptr<scalar_t>();
  scalar_t* v_data = value.data_ptr<scalar_t>();
  scalar_t* out_data = output.data_ptr<scalar_t>();
  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  // One work item per (batch, head, query block).
  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
    // i: batch index, j: head index, k: query-block index.
    // Note that this local `k` shadows the key tensor parameter inside the lambda.
    int64_t i = 0, j = 0, k = 0;
    data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    // Carve this thread's scratch region into its four sub-buffers.
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr;

    for (const auto z : c10::irange(begin, end)) {
      (void)z; // Suppress unused variable
      // First query row of this block and the number of rows it holds.
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      // Initialize max and sum
      fill_stub(qk_max_data,
          -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      fill_stub(qk_sum_data,
          static_cast<accum_t>(0), qBlockSize);
      // With causal masking, keys past the last query row of this block are
      // never visible, so the kv loop stops early.
      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
      if (num_keys == 0) {
        // Empty key/value sequence: the kv loop below does not execute, so
        // the accumulator would be copied to `output` uninitialised. Define
        // the result as zero; logsumexp below then evaluates to -inf.
        fill_stub(dst_data, static_cast<accum_t>(0), qBlockSize * headSize);
      }
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        // Calculate scale * q @ k.T
        // Result is stored in qk_data as qBlockSize rows of kvBlockSize values.
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            scaling_factor,
            k_data + i * kStrideB + j * kStrideH +
                n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            qk_data,
            kvBlockSize);
        // Apply causal mask, fill unused with -inf
        // Only the last kv block visited can contain masked columns.
        if (is_causal && num_keys - n <= kvSplitSize) {
          for (const auto row : c10::irange(qBlockSize)) {
            // Last visible column of this row, relative to the block start.
            int64_t last_col = m + row - n;
            accum_t* row_ptr = qk_data + row * kvBlockSize;
            // The count may be <= 0, in which case fill_stub does nothing.
            fill_stub(row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update coefficients with Softmax
        accum_t tmp_max = 0, tmp_sum = 0, sum_old = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          sum_old = qk_sum_data[row];
          // max per row
          tmp_max = vec::reduce_all<accum_t>(
              [](Vec& x, Vec& y) { return vec::maximum(x, y); },
              qk_data + row * kvBlockSize, kvBlockSize);
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          // qk <- exp(qk - max)
          vec::map<accum_t>(
              [tmp_max](Vec x) { return (x - Vec(tmp_max)).exp(); },
              qk_data + row * kvBlockSize, qk_data + row * kvBlockSize, kvBlockSize);
          // sum per row
          tmp_sum = vec::reduce_all<accum_t>(
              [](Vec& x, Vec& y) { return x + y; }, qk_data + row * kvBlockSize, kvBlockSize);
          // exp_tmp <- exp(max[row] - max)
          // Factor that rebases values computed against the old maximum.
          exp_tmp = std::exp(qk_max_data[row] - tmp_max);
          // sum[row] <- sum + exp_tmp * sum[row]
          qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
          // max[row] <- max
          qk_max_data[row] = tmp_max;
          // qk <- qk / sum[row]
          accum_t sum_new = qk_sum_data[row];
          vec::map<accum_t>(
              [sum_new](Vec x) { return x / Vec(sum_new); },
              qk_data + row * kvBlockSize, qk_data + row * kvBlockSize, kvBlockSize);
          if (is_reduced_type) {
            // The second gemm consumes scalar_t, so stage a converted copy.
            convert<accum_t, scalar_t>(
                qk_data + row * kvBlockSize,
                qk_reduced_data + row * kvBlockSize,
                kvBlockSize);
          }
          // dst <- dst * sum_old / sum_new * exp_tmp
          // Not needed for the first kv block: dst is overwritten there.
          if (n > 0) {
            accum_t sum_cor = sum_old / sum_new;
            vec::map<accum_t>(
                [sum_cor, exp_tmp](Vec x)
                { return x * Vec(sum_cor) * Vec(exp_tmp); },
                dst_data + row * headSize, dst_data + row * headSize, headSize);
          }
        }
        // Calculate Softmax(q @ k.T) @ v
        // beta is 0 for the first kv block (initialises dst) and 1 afterwards
        // (accumulates into dst).
        cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + j * vStrideH +
                n * vStrideN,
            vStrideN,
            conditional_data_ptr(qk_data, qk_reduced_data),
            kvBlockSize,
            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
            dst_data,
            headSize);
      }
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        vec::map<scalar_t>(
            [](Vec x) { return x; },
            out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
            dst_data + row * headSize,
            headSize);
      }
      // Store logsumexp for backward
      accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
      for (const auto row : c10::irange(qBlockSize)) {
        lse_ptr[row * lStrideM] = qk_max_data[row]
            + std::log(qk_sum_data[row]);
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
  });
}
// Backward pass of the blocked CPU attention kernel.
//
// Recomputes the attention probabilities block by block from query, key and
// the saved logsumexp, then accumulates gradients into grad_q, grad_k and
// grad_v.
//
// Template parameters:
//   scalar_t       - element type of the inputs and gradients.
//   q_split_size   - upper bound on the query block length.
//   kv_split_size  - upper bound on the key/value block length.
//
// Parameters that are read:
//   grad_q, grad_k, grad_v - gradient outputs, updated in place.
//   grad_out               - gradient of the attention output.
//   query, key, value, out - forward inputs/output, indexed as
//                            (Batch x Seq_len x Num_heads x Dim_per_head).
//   logsumexp              - per-row normalizer saved by the forward pass
//                            (accumulation type).
//   is_causal              - when true, masked positions contribute zero.
//   scale                  - optional softmax scale forwarded to
//                            sdp::calculate_scale.
//
// cumulative_sequence_length_q/k, max_seqlen_batch_q/k, dropout_p,
// philox_seed and philox_offset are not read by this function.
//
// NOTE(review): every gemm that writes grad_q/grad_k/grad_v uses beta = 1, so
// the kernel adds to whatever those tensors already hold - presumably the
// caller zero-initialises them; confirm at the call site.
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention_backward(
const at::Tensor& grad_q,
const at::Tensor& grad_k,
const at::Tensor& grad_v,
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
c10::optional<double> scale) {
constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor =
sdp::calculate_scale(query, scale).as_float_unchecked();
// Sizes
TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
"scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size");
// Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
int64_t batchSize = query.size(0);
int64_t qSize = query.size(1);
int64_t kvSize = value.size(1);
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);
// Strides
// Suffix B = batch (dim 0), M/N = query/key sequence position (dim 1),
// H = head (dim 2).
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
int64_t qStrideH = query.stride(2);
int64_t kStrideB = key.stride(0);
int64_t kStrideN = key.stride(1);
int64_t kStrideH = key.stride(2);
int64_t vStrideB = value.stride(0);
int64_t vStrideN = value.stride(1);
int64_t vStrideH = value.stride(2);
int64_t oStrideB = out.stride(0);
int64_t oStrideM = out.stride(1);
int64_t oStrideH = out.stride(2);
int64_t lStrideB = logsumexp.stride(0);
int64_t lStrideM = logsumexp.stride(1);
int64_t lStrideH = logsumexp.stride(2);
int64_t grad_qStrideB = grad_q.stride(0);
int64_t grad_qStrideM = grad_q.stride(1);
int64_t grad_qStrideH = grad_q.stride(2);
int64_t grad_kStrideB = grad_k.stride(0);
int64_t grad_kStrideN = grad_k.stride(1);
int64_t grad_kStrideH = grad_k.stride(2);
int64_t grad_vStrideB = grad_v.stride(0);
int64_t grad_vStrideN = grad_v.stride(1);
int64_t grad_vStrideH = grad_v.stride(2);
int64_t grad_oStrideB = grad_out.stride(0);
int64_t grad_oStrideM = grad_out.stride(1);
int64_t grad_oStrideH = grad_out.stride(2);
// Effective block sizes never exceed the sequence lengths.
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* attn */ qSplitSize * kvSplitSize +
/* grad_attn */ qSplitSize * kvSplitSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
// allocate per thread temp buf_reduced (scalar type)
// buf2 is only needed for bfloat16 and float16
int64_t size_per_thread_reduced =
/* attn_reduced */ qSplitSize * kvSplitSize +
/* grad_attn_reduced */ qSplitSize * kvSplitSize;
// Zero-sized unless scalar_t is a reduced floating-point type.
at::Tensor buf_reduced = at::empty({num_thread, is_reduced_type ? size_per_thread_reduced : 0}, query.options());
scalar_t* grad_q_data = grad_q.data_ptr<scalar_t>();
scalar_t* grad_k_data = grad_k.data_ptr<scalar_t>();
scalar_t* grad_v_data = grad_v.data_ptr<scalar_t>();
scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
scalar_t* q_data = query.data_ptr<scalar_t>();
scalar_t* k_data = key.data_ptr<scalar_t>();
scalar_t* v_data = value.data_ptr<scalar_t>();
scalar_t* out_data = out.data_ptr<scalar_t>();
accum_t* lse_data = logsumexp.data_ptr<accum_t>();
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
// The parallel range covers batchSize * num_head items; query blocks are
// walked sequentially inside each item.
at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
// i: batch index, j: head index.
// NOTE(review): presumably data_index_init decomposes `begin` into (i, j); confirm in cpu/utils.h.
int64_t i = 0, j = 0;
data_index_init(begin, i, batchSize, j, num_head);
// Carve this thread's scratch region into the attn and grad_attn blocks.
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* attn_data = buf_ptr;
accum_t* grad_attn_data = attn_data + qSplitSize * kvSplitSize;
scalar_t* buf_reduced_ptr = is_reduced_type ? buf_reduced_data + ompIdx * size_per_thread_reduced : nullptr;
scalar_t* attn_reduced_data = is_reduced_type ? buf_reduced_ptr : nullptr;
scalar_t* grad_attn_reduced_data = is_reduced_type ? attn_reduced_data + qSplitSize * kvSplitSize : nullptr;
// Per-row scratch for rowsum(grad_out * out), one entry per query row in a block.
at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype));
accum_t* dsum_data = dsum.data_ptr<accum_t>();
for (const auto z : c10::irange(begin, end)) {
(void)z; // Suppress unused variable
// rowsum of grad_out * out
for (int64_t m = 0; m < qSize; m += qSplitSize) {
int64_t qBlockSize = std::min(qSplitSize, qSize - m);
// dsum <- rowsum(grad_out * out)
for (const auto row : c10::irange(qBlockSize)) {
*(dsum_data + row) = vec::map2_reduce_all<scalar_t>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM,
out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM,
headSize);
}
// With causal masking, keys past the last query row of this block are
// never visited.
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
// attn <- scale * q @ k.T
// beta = 0: attn_data is overwritten, stored as qBlockSize rows of
// kvBlockSize values.
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
scaling_factor,
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
static_cast<accum_t>(0),
attn_data,
kvBlockSize);
// restore self attention after softmax from logsumexp
// attn <- exp(attn - normalizer)
for (const auto row : c10::irange(qBlockSize)) {
accum_t normalizer = lse_data[i * lStrideB + j * lStrideH + (m + row) * lStrideM];
vec::map<accum_t>(
[normalizer](Vec x) { return (x - Vec(normalizer)).exp(); },
attn_data + row * kvBlockSize,
attn_data + row * kvBlockSize,
kvBlockSize);
}
// Apply causal mask, filled unused with 0
// The condition holds only for the last kv block visited.
if (is_causal && num_keys - n <= kvSplitSize) {
for (const auto row : c10::irange(qBlockSize)) {
// Last visible column of this row, relative to the block start.
int64_t last_col = m + row - n;
accum_t* row_ptr = attn_data + row * kvBlockSize;
// The count may be <= 0, in which case fill_stub does nothing.
fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
}
}
// For reduced types, stage a scalar_t copy of attn for the gemm below.
if (is_reduced_type) {
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
attn_data + row * kvBlockSize,
attn_reduced_data + row * kvBlockSize,
kvBlockSize);
}
}
// grad_v <- grad_v + attn.T @ grad_out
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::Transpose,
headSize,
kvBlockSize,
qBlockSize,
static_cast<accum_t>(1),
grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
m * grad_oStrideM,
grad_oStrideM,
conditional_data_ptr(attn_data, attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_v_data + i * grad_vStrideB + j * grad_vStrideH +
n * grad_vStrideN,
grad_vStrideN);
// grad_attn <- grad_out @ v.T
// beta = 0: grad_attn_data is overwritten.
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
static_cast<accum_t>(1),
v_data + i * vStrideB + j * vStrideH +
n * vStrideN,
vStrideN,
grad_out_data + i * grad_oStrideB + j * grad_oStrideH +
m * grad_oStrideM,
grad_oStrideM,
static_cast<accum_t>(0),
grad_attn_data,
kvBlockSize);
// grad_attn <- attn * (grad_attn - dsum)
for (const auto row : c10::irange(qBlockSize)) {
accum_t d = *(dsum_data + row);
vec::map2<accum_t>(
[d](Vec attn, Vec grad_attn) { return attn * (grad_attn - Vec(d)); },
grad_attn_data + row * kvBlockSize,
attn_data + row * kvBlockSize,
grad_attn_data + row * kvBlockSize,
kvBlockSize);
}
// For reduced types, stage a scalar_t copy of grad_attn for the two gemms below.
if (is_reduced_type) {
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
grad_attn_data + row * kvBlockSize,
grad_attn_reduced_data + row * kvBlockSize,
kvBlockSize);
}
}
// grad_q <- grad_q + scale * grad_attn @ k
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::NoTranspose,
headSize,
qBlockSize,
kvBlockSize,
scaling_factor,
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_q_data + i * grad_qStrideB + j * grad_qStrideH +
m * grad_qStrideM,
grad_qStrideM);
// grad_k <- grad_k + scale * grad_attn.T @ q
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::Transpose,
headSize,
kvBlockSize,
qBlockSize,
scaling_factor,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
conditional_data_ptr(grad_attn_data, grad_attn_reduced_data),
kvBlockSize,
static_cast<accum_t>(1),
grad_k_data + i * grad_kStrideB + j * grad_kStrideH +
n * grad_kStrideN,
grad_kStrideN);
}
}
// Move to the next query
data_index_step(i, batchSize, j, num_head);
}
});
}
// Forward-kernel entry point: dispatches on the element type of `query`
// (float, double or BFloat16) and selects the query block size from the
// query sequence length. The key/value block size is always 512.
//
//   q_seq_len <  192          -> query block 32
//   192 <= q_seq_len < 768    -> query block 64
//   q_seq_len >= 768          -> query block 256
//
// All arguments are forwarded unchanged to cpu_flash_attention.
void flash_attention_kernel_impl(
    const Tensor& output,
    const Tensor& logsumexp,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    int64_t& max_q,
    int64_t& max_k,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& debug_attn_mask,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale) {
  // The sequence length is read from dim 2 of the incoming query tensor.
  const auto seq_len_q = query.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, query.scalar_type(), "flash_attention", [&] {
    if (seq_len_q < 192) {
      cpu_flash_attention<scalar_t, 32, 512>(
          output, logsumexp, cum_seq_q, cum_seq_k,
          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
    } else if (seq_len_q < 768) {
      cpu_flash_attention<scalar_t, 64, 512>(
          output, logsumexp, cum_seq_q, cum_seq_k,
          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
    } else {
      cpu_flash_attention<scalar_t, 256, 512>(
          output, logsumexp, cum_seq_q, cum_seq_k,
          max_q, max_k, philox_seed, philox_offset, debug_attn_mask,
          query, key, value, dropout_p, is_causal, return_debug_mask, scale);
    }
  });
}
// Backward-kernel entry point: dispatches on the element type of `query`
// (float, double or BFloat16) and selects the query block size from the
// query sequence length. The key/value block size is always 512.
//
//   q_seq_len <  192          -> query block 32
//   192 <= q_seq_len < 768    -> query block 64
//   q_seq_len >= 768          -> query block 256
//
// grad_out is replaced by a contiguous copy before the call; every other
// argument is forwarded unchanged to cpu_flash_attention_backward.
void flash_attention_backward_kernel_impl(
    const at::Tensor& grad_q,
    const at::Tensor& grad_k,
    const at::Tensor& grad_v,
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
    c10::optional<double> scale) {
  // grad_out may carry zero strides from broadcasting. It is used as a gemm
  // operand below, and a zero leading dimension would push gemm onto its
  // slow implementation, so materialise a contiguous tensor first.
  auto grad_out_dense = grad_out.contiguous();
  // The sequence length is read from dim 1 of the incoming query tensor.
  const auto seq_len_q = query.size(1);

  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, query.scalar_type(), "flash_attention_backward", [&] {
    if (seq_len_q < 192) {
      cpu_flash_attention_backward<scalar_t, 32, 512>(
          grad_q, grad_k, grad_v, grad_out_dense,
          query, key, value, out, logsumexp,
          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
          is_causal, philox_seed, philox_offset, scale);
    } else if (seq_len_q < 768) {
      cpu_flash_attention_backward<scalar_t, 64, 512>(
          grad_q, grad_k, grad_v, grad_out_dense,
          query, key, value, out, logsumexp,
          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
          is_causal, philox_seed, philox_offset, scale);
    } else {
      cpu_flash_attention_backward<scalar_t, 256, 512>(
          grad_q, grad_k, grad_v, grad_out_dense,
          query, key, value, out, logsumexp,
          cum_seq_q, cum_seq_k, max_q, max_k, dropout_p,
          is_causal, philox_seed, philox_offset, scale);
    }
  });
}
} // anonymous namespace
// Bind the forward and backward implementations defined in this file to the
// flash_attention_kernel and flash_attention_backward_kernel dispatch stubs.
// NOTE(review): going by the macro name, this presumably also registers the
// AVX512 build of this file in addition to the default CPU capabilities;
// confirm against the macro's definition.
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);
} // at::native