forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTHCReduceAll.cuh
331 lines (284 loc) · 10.8 KB
/
THCReduceAll.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#ifndef THC_REDUCEALL_INC
#define THC_REDUCEALL_INC
//
// This file contains functions and kernels that reduce an entire tensor
// to a single value. They work on both contiguous and non-contiguous
// tensor arguments with an arbitrary number of dimensions (up to
// MAX_CUTORCH_DIMS) and never copy the input; only the two-pass path
// allocates a small scratch buffer for per-block partial results.
//
#include "THCReduceApplyUtils.cuh"
// Threads per block for the single-pass kernel and for pass 1 of the
// two-pass reduction; getTwoPassBlocks also aims for one pass-1 block per
// this many elements.
#define THC_REDUCE_ALL_BLOCK_SIZE 1024L
// Element-count cutoff: tensors with more elements than this use the
// two-pass (multi-block) reduction instead of the single-block kernel.
#define THC_TWO_PASS_REDUCTION_SIZE 2048L
// Single-block reduction of the whole tensor (the launcher uses a grid of
// one block). Thread t folds logical elements t, t + blockDim.x, ... after
// mapping each through modifyOp. The per-thread partials are then combined
// with reduceBlock, and thread 0 writes the final value to *out.
// The launcher supplies blockDim.x * sizeof(AccT) bytes of dynamic shared
// memory for reduceBlock.
template <typename T,
          typename IndexType,
          typename AccT,
          typename ModifyOp,
          typename ReduceOp,
          int ADims>
__global__ void
kernelReduceAll(TensorInfo<T, IndexType> in,
                IndexType totalElements,
                AccT init,
                ModifyOp modifyOp,
                ReduceOp reduceOp,
                AccT* out) {
  // Per-thread partial over a blockDim.x-strided subset of the elements.
  AccT partial = init;
  const IndexType stride = blockDim.x;
  IndexType linearIdx = threadIdx.x;
  while (linearIdx < totalElements) {
    const AccT elem = scalar_cast<AccT>(
        in.data[IndexToOffset<T, IndexType, ADims>::get(linearIdx, in)]);
    partial = reduceOp(partial, modifyOp(elem));
    linearIdx += stride;
  }

  // Combine the per-thread partials across the block via dynamic shared memory.
  extern __shared__ char smemChar[];
  AccT* const sharedBuf = reinterpret_cast<AccT*>(smemChar);
  const AccT blockResult =
      reduceBlock(sharedBuf, blockDim.x, partial, reduceOp, init);

  if (threadIdx.x != 0) {
    return;
  }
  *out = blockResult;
}
// First logical index owned by this block when [0, totalSize) is split into
// gridDim.x contiguous chunks of ceil(totalSize / gridDim.x) elements.
template <typename IndexType>
__device__ __forceinline__ IndexType getStartIndex(IndexType totalSize) {
  const IndexType numChunks = static_cast<IndexType>(gridDim.x);
  const IndexType chunk = THCCeilDiv(totalSize, numChunks);
  return chunk * blockIdx.x;
}
// One past the last logical index owned by this block, using the same
// chunking as getStartIndex. The result is clamped to totalSize so the
// final chunk may be shorter.
template <typename IndexType>
__device__ __forceinline__ IndexType getEndIndex(IndexType totalSize) {
  const IndexType numChunks = static_cast<IndexType>(gridDim.x);
  const IndexType chunk = THCCeilDiv(totalSize, numChunks);
  const IndexType chunkEnd = static_cast<IndexType>((blockIdx.x + 1) * chunk);
  return (chunkEnd < totalSize) ? chunkEnd : totalSize;
}
// Pass 1 of the two-pass reduction.
// [0, totalElements) is split into gridDim.x contiguous chunks
// (getStartIndex / getEndIndex). Each block folds its chunk with a
// blockDim.x-strided per-thread loop and combines the partials with
// reduceBlock. Thread 0 stores the block's result in
// scratchSpace[blockIdx.x].
// The launcher supplies blockDim.x * sizeof(AccT) bytes of dynamic shared
// memory.
template <typename T,
          typename IndexType,
          typename AccT,
          typename ModifyOp,
          typename ReduceOp,
          int ADims>
__global__ void
kernelReduceAllPass1(TensorInfo<T, IndexType> in,
                     IndexType totalElements,
                     AccT init,
                     ModifyOp modifyOp,
                     ReduceOp reduceOp,
                     AccT* scratchSpace) {
  const IndexType chunkBegin = getStartIndex<IndexType>(totalElements);
  const IndexType chunkEnd = getEndIndex<IndexType>(totalElements);

  // Per-thread partial over this block's chunk.
  AccT partial = init;
  IndexType linearIdx = chunkBegin + threadIdx.x;
  while (linearIdx < chunkEnd) {
    const AccT elem = scalar_cast<AccT>(
        in.data[IndexToOffset<T, IndexType, ADims>::get(linearIdx, in)]);
    partial = reduceOp(partial, modifyOp(elem));
    linearIdx += blockDim.x;
  }

  // Combine the per-thread partials across the block via dynamic shared memory.
  extern __shared__ char smemChar[];
  AccT* const sharedBuf = reinterpret_cast<AccT*>(smemChar);
  const AccT blockResult =
      reduceBlock(sharedBuf, blockDim.x, partial, reduceOp, init);

  if (threadIdx.x == 0) {
    scratchSpace[blockIdx.x] = blockResult;
  }
}
// Pass 2 of the two-pass reduction.
// A single block folds the numPass1Blocks partials that pass 1 left in
// scratchSpace, and thread 0 writes the final value to *out. Thread t loads
// scratchSpace[t] when t < numPass1Blocks and otherwise contributes `init`.
// The launcher uses one thread per pass-1 block and
// blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T, typename ReduceOp>
__global__ void
kernelReduceAllPass2(int numPass1Blocks,
                     T init,
                     ReduceOp reduceOp,
                     T* scratchSpace,
                     T* out) {
  const bool hasPartial = threadIdx.x < numPass1Blocks;
  const T partial = hasPartial ? scratchSpace[threadIdx.x] : init;

  extern __shared__ char smemChar[];
  T* const sharedBuf = reinterpret_cast<T*>(smemChar);
  const T total =
      reduceBlock(sharedBuf, numPass1Blocks, partial, reduceOp, init);

  if (threadIdx.x == 0) {
    *out = total;
  }
}
// True when `elements` exceeds THC_TWO_PASS_REDUCTION_SIZE, i.e. when the
// reduction is large enough to use the multi-block two-pass path.
inline bool isTwoPassReductionSize(ptrdiff_t elements) {
  const bool useTwoPass = THC_TWO_PASS_REDUCTION_SIZE < elements;
  return useTwoPass;
}
// Number of blocks for pass 1 of the two-pass reduction. This is also the
// number of threads used by pass 2.
//
// Ideally there is one block per THC_REDUCE_ALL_BLOCK_SIZE elements, capped by:
//   - how many T partial results fit in the device scratch-space size,
//     since pass 1 stores one T per block there; and
//   - 1024, because pass 2 runs one thread per pass-1 block in one block.
template <typename T>
inline ptrdiff_t getTwoPassBlocks(THCState* state, ptrdiff_t elements) {
  const ptrdiff_t wanted =
      THCCeilDiv(elements, (ptrdiff_t)THC_REDUCE_ALL_BLOCK_SIZE);

  ptrdiff_t cap = THCState_getCurrentDeviceScratchSpaceSize(state) / sizeof(T);
  THAssert(cap > 0);
  cap = (cap > 1024) ? 1024 : cap;

  return (wanted > cap) ? cap : wanted;
}
// Launch shape for kernelReduceAllPass1: getTwoPassBlocks<T>(state, elements)
// blocks, each with THC_REDUCE_ALL_BLOCK_SIZE threads.
template <typename T>
inline void getPass1ReduceBlockGrid(THCState* state, ptrdiff_t elements,
                                    dim3& grid, dim3& block) {
  const ptrdiff_t numBlocks = getTwoPassBlocks<T>(state, elements);
  block = dim3(THC_REDUCE_ALL_BLOCK_SIZE);
  grid = dim3(numBlocks);
}
// Launch shape for kernelReduceAllPass2: a single block with one thread per
// pass-1 block, recomputed with the same getTwoPassBlocks<T> call that
// sized pass 1.
template <typename T>
inline void getPass2ReduceBlockGrid(THCState* state, ptrdiff_t elements,
                                    dim3& grid, dim3& block) {
  const ptrdiff_t pass1Blocks = getTwoPassBlocks<T>(state, elements);
  block = dim3(pass1Blocks);
  grid = dim3(1);
}
// Launch shape for kernelReduceAll: one block of THC_REDUCE_ALL_BLOCK_SIZE
// threads, regardless of `elements` (which is accepted but unused).
inline void getSinglePassReduceBlockGrid(ptrdiff_t elements,
                                         dim3& grid, dim3& block) {
  (void)elements;
  block = dim3(THC_REDUCE_ALL_BLOCK_SIZE);
  grid = dim3(1);
}
// Host-side launcher that reduces all `totalElements` elements of `in` into
// *devOut, which must be a device pointer. All work is enqueued on the
// current THC stream; this function does not synchronize.
//
// Tensors with more than THC_TWO_PASS_REDUCTION_SIZE elements use the
// two-pass path:
//   - pass 1 writes one partial per block into a temporary scratch buffer;
//   - pass 2 reduces those partials in a single block.
// Smaller tensors use the single-block kernel.
//
// Launch errors (for example an invalid launch configuration) are reported
// through THCudaCheck. In the two-pass path, the scratch buffer is released
// before the error is reported, so a failed launch does not leak it.
template <typename T,
          typename IndexType,
          typename AccT,
          typename ModifyOp,
          typename ReduceOp,
          int ADims>
void callReduceAll(THCState* state,
                   const TensorInfo<T, IndexType>& in,
                   ptrdiff_t totalElements,
                   AccT init,
                   const ModifyOp& modifyOp,
                   const ReduceOp& reduceOp,
                   AccT* devOut) {
  dim3 grid;
  dim3 block;
  cudaStream_t stream = THCState_getCurrentStream(state);

  if (isTwoPassReductionSize(totalElements)) {
    // Pass 1 stores one AccT per block here; getTwoPassBlocks bounds the
    // block count by this same size.
    void* scratchSpace =
        THCudaMalloc(state, THCState_getCurrentDeviceScratchSpaceSize(state));

    getPass1ReduceBlockGrid<AccT>(state, totalElements, grid, block);
    size_t smemSize = block.x * sizeof(AccT);
    kernelReduceAllPass1<T, IndexType, AccT, ModifyOp, ReduceOp, ADims>
      <<<grid, block, smemSize, stream>>>(
        in, (IndexType) totalElements, init, modifyOp, reduceOp,
        (AccT*) scratchSpace);
    cudaError_t launchErr = cudaGetLastError();

    if (launchErr == cudaSuccess) {
      const int numPass1Blocks = grid.x;
      getPass2ReduceBlockGrid<AccT>(state, totalElements, grid, block);
      smemSize = block.x * sizeof(AccT);
      kernelReduceAllPass2<AccT, ReduceOp>
        <<<grid, block, smemSize, stream>>>(
          numPass1Blocks, init, reduceOp,
          (AccT*) scratchSpace, devOut);
      launchErr = cudaGetLastError();
    }

    // NOTE(review): the buffer is freed while the kernels may still be in
    // flight (as in the original). This relies on THCudaFree being ordered
    // with respect to the current stream; confirm against the allocator.
    THCudaFree(state, scratchSpace);
    THCudaCheck(launchErr);
  } else {
    getSinglePassReduceBlockGrid(totalElements, grid, block);
    size_t smemSize = block.x * sizeof(AccT);
    kernelReduceAll<T, IndexType, AccT, ModifyOp, ReduceOp, ADims>
      <<<grid, block, smemSize, stream>>>(
        in, (IndexType) totalElements, init, modifyOp, reduceOp, devOut);
    THCudaCheck(cudaGetLastError());
  }
}
// Reduces the entire tensor `in` to one value.
//
// Each element is mapped through modifyOp and folded with reduceOp, starting
// from `init`. The reduction runs on the current THC stream.
//
//   outOnDevice != 0: `out` is a device pointer. The result is written to it
//                     by stream-ordered work, and this call does not wait
//                     for that work.
//   outOnDevice == 0: `out` is host memory. The result goes through a
//                     temporary device buffer, and this call blocks on the
//                     stream until the value has been copied back.
//
// Returns false without doing any work if `in` has more than
// MAX_CUTORCH_DIMS dimensions, and true otherwise. If the tensor's legacy
// dimension count (THCTensor_nDimensionLegacyAll) is 0, the result is `init`.
template <typename ScalarType,
          typename TensorType,
          typename ModifyOp,
          typename ReduceOp,
          typename AccT>
bool THC_reduceAll(THCState* state,
                   TensorType* in,
                   const ModifyOp& modifyOp,
                   const ReduceOp& reduceOp,
                   AccT init,
                   AccT* out,
                   int outOnDevice) {
  ptrdiff_t inElements = THCTensor_nElement(state, in);

  if (THCTensor_nDimensionLegacyAll(state, in) > MAX_CUTORCH_DIMS) {
    return false;
  }

  if (THCTensor_nDimensionLegacyAll(state, in) == 0) {
    // Nothing to reduce, so the result is `init`. When outOnDevice is set,
    // `out` is a device pointer and must not be dereferenced on the host;
    // the value is copied over on the current stream instead.
    if (outOnDevice) {
      THCudaCheck(cudaMemcpyAsync(out, &init, sizeof(AccT),
                                  cudaMemcpyHostToDevice,
                                  THCState_getCurrentStream(state)));
    } else {
      *out = init;
    }
    return true;
  }

  AccT* devOut = out;
  if (!outOnDevice) {
    // The kernels need a device destination. Allocate room for exactly one
    // result value; it is copied back to `out` below.
    devOut = static_cast<AccT*>(THCudaMalloc(state, sizeof(AccT)));
  }

  // The tensor's dimensions can often be collapsed (e.g. a contiguous tensor
  // collapses to one dimension). The index-to-offset translation, whose
  // div/mod is the most expensive part of each element access, is therefore
  // specialized at compile time on the collapsed dimension count: 1 and 2
  // get dedicated instantiations, and everything else uses the generic
  // (-1) path.
#define HANDLE_CASE(TYPE, IN)                                           \
  callReduceAll<ScalarType,                                             \
                TYPE, AccT, ModifyOp, ReduceOp, IN>(                    \
                  state, inInfo, inElements, init, modifyOp,            \
                  reduceOp, devOut);

#define HANDLE_IN_CASE(TYPE, IN)                    \
  {                                                 \
    switch (IN) {                                   \
      case 1:                                       \
        HANDLE_CASE(TYPE, 1);                       \
        break;                                      \
      case 2:                                       \
        HANDLE_CASE(TYPE, 2);                       \
        break;                                      \
      default:                                      \
        HANDLE_CASE(TYPE, -1);                      \
        break;                                      \
    }                                               \
  }

  if (THCTensor_canUse32BitIndexMath(state, in)) {
    TensorInfo<ScalarType, unsigned int> inInfo =
      getTensorInfo<ScalarType, TensorType, unsigned int>(state, in);
    inInfo.collapseDims();

    HANDLE_IN_CASE(unsigned int, inInfo.dims);
  } else {
    TensorInfo<ScalarType,
               uint64_t> inInfo =
      getTensorInfo<ScalarType, TensorType, uint64_t>(state, in);
    inInfo.collapseDims();

    // Only the 1-D special case and the generic n-D fallback are
    // instantiated for large (64-bit indexed) tensors, to reduce
    // compilation time.
    if (inInfo.dims == 1) {
      HANDLE_IN_CASE(uint64_t, 1);
    } else {
      HANDLE_IN_CASE(uint64_t, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_IN_CASE

  if (!outOnDevice) {
    // Copy the value back to the host and wait for it (synchronous!).
    // The temporary buffer is freed before any error is reported, so a
    // failed copy or sync does not leak it.
    cudaStream_t stream = THCState_getCurrentStream(state);
    cudaError_t copyErr = cudaMemcpyAsync(out,
                                          devOut,
                                          sizeof(AccT),
                                          cudaMemcpyDeviceToHost,
                                          stream);
    if (copyErr == cudaSuccess) {
      copyErr = cudaStreamSynchronize(stream);
    }
    THCudaFree(state, devOut);
    THCudaCheck(copyErr);
  }

  return true;
}
// The sizing macros are only used inside this header; undefine them so they
// do not leak into files that include it.
#undef THC_REDUCE_ALL_BLOCK_SIZE
#undef THC_TWO_PASS_REDUCTION_SIZE
#endif // THC_REDUCEALL_INC