ReduceOpsKernel.cpp (forked from pytorch/pytorch)
#include <algorithm>
#include <iterator>
#include <numeric>

#include <ATen/Dispatch.h>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Reduce.h>
#include <c10/util/Optional.h>

namespace at { namespace native { namespace {

using namespace vec256;

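// Sum reduction: dispatches over all dtypes and hands TensorIterator a scalar
// add and a Vec256 add so it can use the vectorized reduction loop.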
static void sum_kernel_impl(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&] {
    binary_kernel_reduce_vec(
        iter,
        [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a + b; });
  });
}

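// Product reduction: same structure as sum, but multiplies and starts from the
// identity element 1.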
static void prod_kernel_impl(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&] {
    binary_kernel_reduce_vec(
        iter,
        [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a * b; },
        /*identity=*/1);
  });
}

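// Rounds a down to the nearest multiple of m.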
static inline int64_t round_down(int64_t a, int64_t m) {
  return a - (a % m);
}

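// Computes the p-norm of a tensor, either over all elements (dim is nullopt)
// or along a single dimension.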
template<typename scalar_t>
struct NormReduction {
  // reduction width in number of scalar elements
  static constexpr int WIDTH = 128 / sizeof(scalar_t);
  using Vec = Vec256<scalar_t>;

  static void apply(
      Tensor& res,
      const Tensor& self,
      Scalar p,
      c10::optional<int64_t> dim) {
    auto out_ = res.data<scalar_t>();
    auto data_ = self.data<scalar_t>();
    auto numel = self.numel();
    float pval = 0.0;
    if (p.isIntegral()) {
      pval = p.to<int64_t>();
    } else if (p.isFloatingPoint()) {
      pval = p.to<float>();
    }
    if (!dim.has_value()) {
      *out_ = reduce_all(data_, numel, pval);
      return;
    }
    int64_t n = self.size(*dim);
    int64_t stride = self.stride(*dim);
    // A contiguous tensor does not need to hold a meaningful stride
    // if the corresponding size is 1
    if (n == 1) {
      stride = 1;
      for (int64_t i = self.ndimension() - 1; i > *dim; i--) {
        stride *= self.size(i);
      }
    }
    int64_t batch = numel / n;
    parallel_for(0, batch, 1, [=](int64_t begin, int64_t end) {
      for (int64_t bi = begin; bi < end; bi++) {
        int64_t b = bi / stride;
        int64_t i = bi % stride;
        const scalar_t* data = &data_[b * n * stride + i];
        out_[bi] = norm_reduce(data, n, stride, pval);
      }
    });
  }

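  // Reduces the whole (flattened) tensor to a single norm value.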
  static scalar_t reduce_all(const scalar_t* data_, int64_t size, float pval) {
    // The max/min norms (p = +/-inf) do not combine additively across chunks,
    // so reduce them in a single pass.
    if (pval == INFINITY || pval == -INFINITY) {
      return norm_reduce(data_, size, 1, pval);
    }
    // For finite p, each chunk contributes sum(|x|^p) (its norm raised back to
    // the p-th power) so partial results combine correctly with addition; the
    // final 1/p root is applied once at the end. p == 0 and p == 1 are already
    // plain sums.
    scalar_t sum = parallel_reduce(
        0,
        size,
        internal::GRAIN_SIZE,
        (scalar_t)0,
        [=](int64_t begin, int64_t end, scalar_t init) {
          const scalar_t* data = &data_[begin];
          int64_t n = end - begin;
          scalar_t result = norm_reduce(data, n, 1, pval);
          if (pval != 0 && pval != 1) {
            result = std::pow(result, pval);
          }
          return init + result;
        },
        std::plus<scalar_t>());
    if (pval != 0 && pval != 1) {
      sum = std::pow(sum, 1.0 / pval);
    }
    return sum;
  }

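  // Norm of n strided elements: contiguous data with p in {1, 2, 3} takes the
  // vectorized 128-byte block path for the bulk and a sequential pass for the tail.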
  static scalar_t norm_reduce(const scalar_t* data, int64_t n, int64_t stride, float pval) {
    scalar_t result = 0.0;
    if (stride == 1 && (pval == 1 || pval == 2 || pval == 3) && n >= WIDTH) {
      int64_t n_rounded = round_down(n, WIDTH);
      scalar_t result1 = norm_reduce128(data, n_rounded, pval);
      scalar_t result2 = norm_reduce_sequential(data + n_rounded, n - n_rounded, stride, pval);
      result = std::pow(std::pow(result1, pval) + std::pow(result2, pval), 1.0/pval);
    } else {
      result = norm_reduce_sequential(data, n, stride, pval);
    }
    return result;
  }

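  // Scalar fallback that handles any p (including 0, +inf and -inf) and any stride.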
  static scalar_t norm_reduce_sequential(const scalar_t* data, int64_t n, int64_t stride, float pval) {
    scalar_t result = 0.0;
    if (pval == 0) {
      for (int64_t k = 0; k < n; k++) {
        result += (data[k * stride] != 0.0);
      }
    } else if (pval == 1) {
      for (int64_t k = 0; k < n; k++) {
        result += std::abs(data[k * stride]);
      }
    } else if (pval == 2) {
      for (int64_t k = 0; k < n; k++) {
        result += data[k * stride] * data[k * stride];
      }
      result = std::sqrt(result);
    } else if (pval == 3) {
      for (int64_t k = 0; k < n; k++) {
        result += std::abs(data[k * stride] * data[k * stride] * data[k * stride]);
      }
      result = std::pow(result, 1.0/3);
    } else if (pval == INFINITY) {
      for (int64_t k = 0; k < n; k++) {
        result = std::abs(data[k * stride]) > result ? std::abs(data[k * stride]) : result;
      }
    } else if (pval == -INFINITY) {
      result = INFINITY;
      for (int64_t k = 0; k < n; k++) {
        result = std::abs(data[k * stride]) < result ? std::abs(data[k * stride]) : result;
      }
    } else {
      for (int64_t k = 0; k < n; k++) {
        result += std::pow(std::abs(data[k * stride]), pval);
      }
      result = std::pow(result, 1.0/pval);
    }
    return result;
  }

  // Vectorized reduction over blocks of WIDTH elements (128 bytes) for p in {1, 2, 3}.
  // n must already be rounded down to a multiple of WIDTH.
  static scalar_t norm_reduce128(const scalar_t* data, int64_t n, float pval) {
    scalar_t result = 0.0;
    Vec acc[4] = {0.0, 0.0, 0.0, 0.0};  // 128 bytes (two cache lines)
    static_assert(sizeof(acc) == 128, "accumulator should be 128 bytes");
    int64_t rows = n / WIDTH;
    if (pval == 1) {
      for (int row = 0; row < rows; row++) {
        for (int j = 0; j != 4; j++) {
          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
          acc[j] = acc[j] + val.abs();
        }
      }
    } else if (pval == 2) {
      for (int row = 0; row < rows; row++) {
        for (int j = 0; j != 4; j++) {
          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
          acc[j] = acc[j] + val * val;
        }
      }
    } else if (pval == 3) {
      for (int row = 0; row < rows; row++) {
        for (int j = 0; j != 4; j++) {
          auto val = Vec::loadu(&data[row * WIDTH + j * Vec::size]);
          acc[j] = acc[j] + (val * val * val).abs();
        }
      }
    }
    // Horizontally sum the four vector accumulators, then take the final 1/p root.
    scalar_t buf[WIDTH] = {0};
    for (int j = 0; j != 4; j++) {
      acc[j].store(&buf[j * Vec::size]);
    }
    for (int i = 0; i < WIDTH; i++) {
      result += buf[i];
    }
    result = std::pow(result, 1.0/pval);
    return result;
  }
};
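
// Dispatch entry point for the norm reduction; registered to norm_kernel below.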
static void norm_kernel_impl(
    Tensor& result,
    const Tensor& self,
    Scalar p,
    c10::optional<int64_t> dim) {
  AT_DISPATCH_FLOATING_TYPES(self.type(), "norm", [&] {
    NormReduction<scalar_t>::apply(result, self, p, dim);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
REGISTER_DISPATCH(norm_kernel, &norm_kernel_impl);

}} // namespace at::native