FunctionOfAMatrixUtilsKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/FunctionOfAMatrixUtils.h>

#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>

// RESTRICT expands to the compiler-specific spelling of the non-aliasing
// pointer qualifier (MSVC uses __restrict, GCC/Clang use __restrict__).
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

namespace at::native {
namespace {
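
// CPU kernel for _compute_linear_combination: for each element produced by
// the TensorIterator it accumulates
//   out += sum_{i < num_summations} in[i * in_stride] * coeff[i * coeff_stride],
// i.e. a linear combination of input slices weighted by the coefficients.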
void _compute_linear_combination_cpu_kernel(
    TensorIterator& iter,
    int64_t in_stride,
    int64_t coeff_stride,
    int64_t num_summations
) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "_compute_linear_combination_cpu", [&] {
      auto loop = [&](char** data, const int64_t* strides, int64_t n) {
        // data[0] is the output, data[1] the input, data[2] the coefficients;
        // strides[k] is the per-element byte stride of operand k.
        auto* RESTRICT out_ptr = data[0];
        auto* RESTRICT in_ptr = data[1];
        auto* RESTRICT coeff_ptr = data[2];

        for (const auto elem C10_UNUSED : c10::irange(n)) {
          auto* RESTRICT out_data = reinterpret_cast<scalar_t*>(out_ptr);
          auto* RESTRICT in_data = reinterpret_cast<scalar_t*>(in_ptr);
          // Coefficients use the real-valued "primitive" type even when
          // scalar_t is complex.
          using primitive_t = typename scalar_value_type<scalar_t>::type;
          auto* RESTRICT coeff_data = reinterpret_cast<primitive_t*>(coeff_ptr);

          // perform summation
          for (const auto i : c10::irange(num_summations)) {
            *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
          }

          out_ptr += strides[0];
          in_ptr += strides[1];
          coeff_ptr += strides[2];
        }
      };
      iter.for_each(loop);
    });
}
} // anonymous namespace

REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cpu_kernel);

} // namespace at::native