RangeFactoriesKernel.cpp

#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/RangeFactories.h>
#include <cmath>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/AccumulateType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/core/Scalar.h>

namespace at::native {
namespace {

using namespace vec;
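
// arange_kernel: fill the output with start, start + step, start + 2 * step, ...
// The output range is split across threads with at::parallel_for; each chunk is
// produced with cpu_serial_kernel_vec, using Vectorized<scalar_t>::arange where
// possible and the scalar lambda as a fallback.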
static void arange_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_steps, const Scalar& scalar_step) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "arange_cpu", [&]() {
    using accscalar_t = at::acc_type<scalar_t, false>;
    auto start = scalar_start.to<accscalar_t>();
    auto steps = scalar_steps.to<accscalar_t>();
    auto step = scalar_step.to<accscalar_t>();
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      int64_t idx(p_begin);
      TensorIterator it(iter);
      cpu_serial_kernel_vec(
          it,
          [start, step, &idx]() -> scalar_t {
            return start + step * (idx++);
          },
          [start, step, &idx]() -> Vectorized<scalar_t> {
            Vectorized<scalar_t> res;
            res = Vectorized<scalar_t>::arange(start + step * idx, step);
            idx += Vectorized<scalar_t>::size();
            return res;
          }, {p_begin, p_end});
    });
  });
}
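
// linspace_kernel: write `steps` evenly spaced values from start to end.
// The first half of the output is computed forward from start and the second
// half backward from end, which keeps both endpoints exact and reduces
// accumulated floating-point error toward the middle of the range.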
static void linspace_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_end, int64_t steps) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "linspace_cpu", [&]() {
    // The step should be computed in double for all integral types.
    using step_t = std::conditional_t<std::is_integral<scalar_t>::value, double, scalar_t>;
    const scalar_t start = scalar_start.to<scalar_t>();
    const scalar_t end = scalar_end.to<scalar_t>();
    // Cast `end` and `start` to `step_t`, since the range can be larger than scalar_t for integral types.
    const step_t step = (static_cast<step_t>(end) - static_cast<step_t>(start)) / (steps - 1);
    int64_t halfway = steps / 2;
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      int64_t idx(p_begin);
      TensorIterator it(iter);
      // The vectorized implementation was removed because of precision issues when
      // mixing integer and double arithmetic; this does not hurt performance.
      cpu_serial_kernel(
          it,
          [start, end, step, halfway, steps, &idx]() -> scalar_t {
            if (idx < halfway) {
              return start + step * (idx++);
            } else {
              return end - step * (steps - (idx++) - 1);
            }
          }, {p_begin, p_end});
    });
  });
}
} // anonymous namespace
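
// Register the CPU kernels with the dispatch stubs declared in
// ATen/native/RangeFactories.h.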
REGISTER_DISPATCH(arange_stub, &arange_kernel);
REGISTER_DISPATCH(linspace_stub, &linspace_kernel);
} // namespace at::native