forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AvgPoolKernel.cpp
552 lines (484 loc) · 18.8 KB
/
AvgPoolKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
namespace at::native {
namespace {
// 2-D average pooling for contiguous CHW / NCHW tensors.
//
// Batch and channel dimensions are folded into a single "plane" index and
// at::parallel_for splits the flattened (plane, out_y, out_x) output space
// across tasks; each output element is computed independently.
// For every output element the pooling window is clipped first to the padded
// input and then to the real input. The in-bounds values are summed in
// accscalar_t and divided by:
//   * divisor_override, when provided;
//   * the padded-window area, when count_include_pad is true;
//   * otherwise, the number of in-bounds elements.
// A window with no in-bounds element leaves the output at 0.
// The work is done on contiguous copies; if output_ is not contiguous the
// result is copied back into it at the end.
template <typename scalar_t, typename accscalar_t>
void cpu_avg_pool(
    const Tensor& output_,
    const Tensor& input_,
    int64_t kW, int64_t kH,
    int64_t dW, int64_t dH,
    int64_t padW, int64_t padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  auto input = input_.contiguous();
  auto output = output_.contiguous();

  scalar_t* const src = input.data_ptr<scalar_t>();
  scalar_t* const dst = output.data_ptr<scalar_t>();

  // A 3-D input is (C, H, W); a 4-D input is (N, C, H, W).
  const int64_t planes = (input.ndimension() == 3)
      ? input.size(0)
      : input.size(0) * input.size(1);
  const int64_t in_h = input.size(-2);
  const int64_t in_w = input.size(-1);
  const int64_t out_h = output.size(-2);
  const int64_t out_w = output.size(-1);
  const int64_t total = output.numel();

  at::parallel_for(0, total, 0, [&](int64_t begin, int64_t end) {
    // (p, y, x) are the coordinates of the flat output index `idx`.
    int64_t p = 0;
    int64_t y = 0;
    int64_t x = 0;
    data_index_init(begin, p, planes, y, out_h, x, out_w);

    for (int64_t idx = begin; idx < end; ++idx) {
      dst[idx] = static_cast<scalar_t>(0);

      // Window bounds clipped to the padded input...
      const int64_t y_start = y * dH - padH;
      const int64_t x_start = x * dW - padW;
      const int64_t y_stop = std::min(y_start + kH, in_h + padH);
      const int64_t x_stop = std::min(x_start + kW, in_w + padW);
      const int64_t padded_area = (y_stop - y_start) * (x_stop - x_start);

      // ...and then to the real input.
      const int64_t y_lo = std::max(y_start, int64_t{0});
      const int64_t x_lo = std::max(x_start, int64_t{0});
      const int64_t y_hi = std::min(y_stop, in_h);
      const int64_t x_hi = std::min(x_stop, in_w);

      const bool has_overlap = (y_lo < y_hi) && (x_lo < x_hi);
      if (has_overlap) {
        int64_t divisor;
        if (divisor_override.has_value()) {
          divisor = divisor_override.value();
        } else if (count_include_pad) {
          divisor = padded_area;
        } else {
          divisor = (y_hi - y_lo) * (x_hi - x_lo);
        }

        const scalar_t* plane = src + p * (in_h * in_w);
        accscalar_t acc = 0;
        for (int64_t r = y_lo; r < y_hi; ++r) {
          const scalar_t* row = plane + r * in_w;
          for (int64_t col = x_lo; col < x_hi; ++col) {
            acc += row[col];
          }
        }
        dst[idx] += scalar_t(acc / divisor);
      }

      // Advance (p, y, x) to the next output index.
      data_index_step(p, planes, y, out_h, x, out_w);
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}
// 2-D average pooling for 4-D tensors in ChannelsLast memory format.
//
// The input is made ChannelsLast-contiguous, so element (n, c, h, w) is read at
// n*H*W*C + h*W*C + w*C + c (see the `in` pointer below). The output row for
// flat spatial index i = (n, oh, ow) starts at output_data + i * channels.
// at::parallel_for splits the (N, OH, OW) space; for each output position
// the kernel runs three vectorized passes over the channel dimension:
//   I   zero the output lane,
//   II  accumulate every in-bounds input pixel of the window into it,
//   III divide by the window divisor.
// Accumulation happens directly in the output buffer in scalar_t.
// The divisor is divisor_override if provided; otherwise it is the padded-window
// area (count_include_pad) or the count of in-bounds elements.
// If output_ is not ChannelsLast-contiguous, the result is copied back at the end.
template <typename scalar_t>
void cpu_avg_pool_channels_last(
    const Tensor& output_,
    const Tensor& input_,
    int64_t kW, int64_t kH,
    int64_t dW, int64_t dH,
    int64_t padW, int64_t padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  TORCH_CHECK(input_.ndimension() == 4,
              "average pooling with channels last format supports tensors with 4 dims");
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output.size(2);
  int64_t output_width = output.size(3);

  using Vec = vec::Vectorized<scalar_t>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    // Channels [0, len) are processed in full Vec-width chunks; the tail
    // [len, size) is handled element by element.
    int64_t size = channels;
    int64_t len = size - (size % Vec::size());
    for (const auto i : c10::irange(begin, end)) {
      // Window bounds: first clipped to the padded input (used for
      // pool_size), then clipped to the real input.
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t ih1 = std::min(ih0 + kH, input_height + padH);
      int64_t iw1 = std::min(iw0 + kW, input_width + padW);
      int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
      ih0 = std::max(ih0, (int64_t) 0);
      iw0 = std::max(iw0, (int64_t) 0);
      ih1 = std::min(ih1, input_height);
      iw1 = std::min(iw1, input_width);

      int64_t divide_factor;
      if (divisor_override.has_value()) {
        divide_factor = divisor_override.value();
      } else {
        if(count_include_pad) {
          divide_factor = pool_size;
        } else {
          divide_factor = (ih1 - ih0) * (iw1 - iw0);
        }
      }

      scalar_t* out = output_data + i * channels;

      // Pass I: zero the out lane
      // (done before the empty-window check so an empty window yields 0)
      int64_t d1 = 0;
      for (; d1 < len; d1 += Vec::size()) {
        Vec out_vec = Vec(scalar_t(0));
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        out[d1] = scalar_t(0);
      }

      if (ih0 >= ih1 || iw0 >= iw1) {
        // move on to next output index
        data_index_step(n, nbatch, oh, output_height, ow, output_width);
        continue;
      }

      // Pass II: compute local sum
      // (accumulated in place in the scalar_t output buffer)
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < len; d2 += Vec::size()) {
            Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
            out_vec.store(out + d2);
          }
          for (; d2 < size; d2++) {
            out[d2] += in[d2];
          }
        }
      }

      // Pass III: compute local average
      int64_t d3 = 0;
      for (; d3 < len; d3 += Vec::size()) {
        Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(divide_factor));
        out_vec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = out[d3] / divide_factor;
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  // Write back when the caller's output could not be used directly.
  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
// BFloat16 specialization of ChannelsLast 2-D average pooling.
//
// Same structure as the generic channels-last kernel: it parallelizes over
// (N, OH, OW) and, per output position, runs passes over the channel dimension.
// The difference is that sums are accumulated in a per-task float buffer
// (`sum`), because the BFloat16 output cannot serve as the accumulator.
// In Pass II, BFloat16 input vectors are widened to two float vectors
// (convert_bfloat16_float). In Pass III, the float averages are narrowed back
// (convert_float_bfloat16).
// If output_ is not ChannelsLast-contiguous, the result is copied back at the end.
template <>
void cpu_avg_pool_channels_last<BFloat16>(
    const Tensor& output_,
    const Tensor& input_,
    int64_t kW, int64_t kH,
    int64_t dW, int64_t dH,
    int64_t padW, int64_t padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  TORCH_CHECK(input_.ndimension() == 4,
              "average pooling with channels last format supports tensors with 4 dims");
  auto memory_format = at::MemoryFormat::ChannelsLast;
  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.data_ptr<BFloat16>();
  auto output_data = output.data_ptr<BFloat16>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  int64_t output_height = output.size(2);
  int64_t output_width = output.size(3);

  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;
  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    // temp buffer for sum, use float as accumulation type
    // can't reuse output buffer to store sum since it is BFloat16
    // (allocated once per parallel task and reused for every output position)
    auto sum_arr = std::make_unique<float []>(channels);
    float* sum = sum_arr.get();

    int64_t size = channels;
    for (const auto i : c10::irange(begin, end)) {
      // Window bounds: first clipped to the padded input (used for
      // pool_size), then clipped to the real input.
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t ih1 = std::min(ih0 + kH, input_height + padH);
      int64_t iw1 = std::min(iw0 + kW, input_width + padW);
      int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
      ih0 = std::max(ih0, (int64_t) 0);
      iw0 = std::max(iw0, (int64_t) 0);
      ih1 = std::min(ih1, input_height);
      iw1 = std::min(iw1, input_width);

      int64_t divide_factor;
      if (divisor_override.has_value()) {
        divide_factor = divisor_override.value();
      } else {
        if(count_include_pad) {
          divide_factor = pool_size;
        } else {
          divide_factor = (ih1 - ih0) * (iw1 - iw0);
        }
      }

      BFloat16* out = output_data + i * channels;

      // Pass I: zero the out lane
      // (here: zero the float accumulator, in fVec-width chunks plus a scalar tail)
      int64_t d1 = 0;
      for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
        fVec sum_fvec = fVec(float(0));
        sum_fvec.store(sum + d1);
      }
      for (; d1 < size; d1++) {
        sum[d1] = float(0);
      }

      if (ih0 >= ih1 || iw0 >= iw1) {
        // since we are not directly using output as the accumulation buffer,
        // in case the kernel window is out of range, need to zero the output buffer here.
        for (int64_t k = 0; k < size; k++) {
          out[k] = 0;
        }
        // move on to next output index
        data_index_step(n, nbatch, oh, output_height, ow, output_width);
        continue;
      }

      // Pass II: compute local sum
      // Each bVec load yields two fVec halves, so the vector loop steps by
      // bVec::size() and updates sum at d2 and d2 + fVec::size().
      for (const auto ih : c10::irange(ih0, ih1)) {
        for (const auto iw : c10::irange(iw0, iw1)) {
          BFloat16* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
            bVec data_bvec = bVec::loadu(in + d2);
            fVec data_fvec0, data_fvec1;
            std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);

            fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
            fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
            sum_fvec0.store(sum + d2);
            sum_fvec1.store(sum + d2 + fVec::size());
          }
          for (; d2 < size; d2++) {
            sum[d2] += float(in[d2]);
          }
        }
      }

      // Pass III: compute local average
      // (divide in float, then narrow two fVec halves into one bVec)
      int64_t d3 = 0;
      for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
        fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(divide_factor));
        fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(divide_factor));

        bVec out_bvec = convert_float_bfloat16(out_fvec0, out_fvec1);
        out_bvec.store(out + d3);
      }
      for (; d3 < size; d3++) {
        out[d3] = BFloat16(sum[d3] / divide_factor);
      }

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

  // Write back when the caller's output could not be used directly.
  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}
// Backward of 2-D average pooling for contiguous CHW / NCHW tensors.
//
// For every output position, the gradient is divided by that window's divisor
// and added (`+=`) to every in-bounds grad_input element the window covers.
// The divisor is divisor_override if provided; otherwise it is the
// padded-window area (count_include_pad) or the count of in-bounds elements.
// Because results are accumulated, grad_input is expected to hold zeros on entry.
// NOTE(review): this kernel does not zero it itself; confirm the caller does.
//
// Work is split over the folded (N*C) plane dimension. Each task only writes
// to its own grad_input planes, so overlapping windows never race.
// If grad_input_ is not contiguous, the result is copied back at the end.
template <typename scalar_t>
void cpu_avg_pool_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    int kW, int kH,
    int dW, int dH,
    int padW, int padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width;
      scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        for (const auto ow : c10::irange(output_width)) {
          // Window bounds: first clipped to the padded input (used for
          // pool_size), then clipped to the real input.
          int64_t ih0 = oh * dH - padH;
          int64_t iw0 = ow * dW - padW;
          int64_t ih1 = std::min(ih0 + kH, input_height + padH);
          int64_t iw1 = std::min(iw0 + kW, input_width + padW);
          int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
          ih0 = std::max(ih0, (int64_t) 0);
          iw0 = std::max(iw0, (int64_t) 0);
          ih1 = std::min(ih1, input_height);
          iw1 = std::min(iw1, input_width);

          // A window with no in-bounds element contributes nothing, so skip it
          // (the forward kernels do the same). This must come before the
          // division below: with count_include_pad == false such a window has
          // a zero divisor, and integer division by zero (e.g. scalar_t =
          // int64_t) is undefined behavior.
          if (ih0 >= ih1 || iw0 >= iw1) {
            continue;
          }

          int64_t divide_factor;
          if (divisor_override.has_value()) {
            divide_factor = divisor_override.value();
          } else if (count_include_pad) {
            divide_factor = pool_size;
          } else {
            divide_factor = (ih1 - ih0) * (iw1 - iw0);
          }

          scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / divide_factor;
          for (const auto ih : c10::irange(ih0, ih1)) {
            for (const auto iw : c10::irange(iw0, iw1)) {
              grad_input_ptr[ih * input_width + iw] += grad_delta;
            }
          }
        }
      }
    }
  });

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}
template <typename scalar_t>
void cpu_avg_pool_backward_channels_last(
const Tensor& grad_input_,
const Tensor& grad_output_,
int kW, int kH,
int dW, int dH,
int padW, int padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
auto memory_format = at::MemoryFormat::ChannelsLast;
auto grad_input = grad_input_.contiguous(memory_format);
auto grad_output = grad_output_.contiguous(memory_format);
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
int64_t nbatch = grad_input.size(0);
int64_t channels = grad_input.size(1);
int64_t input_height = grad_input.size(2);
int64_t input_width = grad_input.size(3);
int64_t output_height = grad_output.size(2);
int64_t output_width = grad_output.size(3);
using Vec = vec::Vectorized<scalar_t>;
// parallel on dim N
at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
for (const auto n : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels;
scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels;
for (const auto oh : c10::irange(output_height)) {
for (const auto ow : c10::irange(output_width)) {
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (ih1 - ih0) * (iw1 - iw0);
}
}
scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels;
int64_t size = channels;
int64_t len = size - (size % Vec::size());
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels;
int64_t d = 0;
for (; d < len; d += Vec::size()) {
Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(divide_factor));
gin_vec.store(gin + d);
}
for (; d < size; d++) {
gin[d] += gout[d] / divide_factor;
}
}
}
}
}
}
});
if (!grad_input_.is_contiguous(memory_format)) {
grad_input_.copy_(grad_input);
}
}
// CPU entry point for the avg_pool2d forward stub.
//
// Picks a kernel from the input's suggested memory format and dispatches on
// dtype (floating types plus Long and BFloat16):
//   * Contiguous  -> cpu_avg_pool. BFloat16 accumulates in float; every other
//                    dtype accumulates in its own type.
//   * ChannelsLast -> cpu_avg_pool_channels_last.
// Any other memory format is rejected with an error.
void avg_pool2d_kernel_impl(
    const Tensor& output,
    const Tensor& input,
    int64_t kW, int64_t kH,
    int64_t dW, int64_t dH,
    int64_t padW, int64_t padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  const auto memory_format = input.suggest_memory_format();

  if (memory_format == at::MemoryFormat::Contiguous) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Long, ScalarType::BFloat16, input.scalar_type(), "avg_pool2d", [&] {
      const bool is_bfloat16 = input.scalar_type() == ScalarType::BFloat16;
      if (!is_bfloat16) {
        cpu_avg_pool<scalar_t, scalar_t>(
            output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
      } else {
        cpu_avg_pool<BFloat16, /*accscalar_t*/float>(
            output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
      }
    });
  } else if (memory_format == at::MemoryFormat::ChannelsLast) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Long, ScalarType::BFloat16, input.scalar_type(), "avg_pool2d_channels_last", [&] {
      cpu_avg_pool_channels_last<scalar_t>(
          output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
    });
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}
// CPU entry point for the avg_pool2d backward stub.
//
// Picks a kernel from grad_output's suggested memory format and dispatches on
// dtype (floating types plus Long and BFloat16):
//   * ChannelsLast -> cpu_avg_pool_backward_channels_last.
//   * Contiguous   -> cpu_avg_pool_backward.
// Any other memory format is rejected with an error.
void avg_pool2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    int kW, int kH,
    int dW, int dH,
    int padW, int padH,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override) {
  const auto memory_format = grad_output.suggest_memory_format();

  if (memory_format == at::MemoryFormat::ChannelsLast) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Long, ScalarType::BFloat16, grad_output.scalar_type(), "avg_pool2d_backward_channels_last", [&] {
      cpu_avg_pool_backward_channels_last<scalar_t>(
          grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
    });
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Long, ScalarType::BFloat16, grad_output.scalar_type(), "avg_pool2d_backward", [&] {
      cpu_avg_pool_backward<scalar_t>(
          grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
    });
  } else {
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}
} // anonymous namespace
// Register the CPU implementations above with the avg_pool2d forward and
// backward dispatch stubs.
// NOTE(review): the stubs are declared outside this file (presumably in the
// included ATen/native/Pool.h); confirm.
REGISTER_DISPATCH(avg_pool2d_kernel, &avg_pool2d_kernel_impl);
REGISTER_DISPATCH(avg_pool2d_backward_kernel, &avg_pool2d_backward_kernel_impl);
} // at::native