forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCTensorMath.cuh
130 lines (107 loc) · 4.25 KB
/
THCTensorMath.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#ifndef THC_TENSORMATH_CUH
#define THC_TENSORMATH_CUH
// Copy the kth diagonal of a matrix B to a vector A.
//
// Diagonal element i (0 <= i < size) is read from b[start + i * strideSum] and
// written to a[i * strideA]. `strideSum` is the step between consecutive
// diagonal elements in b (presumably B's row stride + column stride — confirm
// against the caller). Threads walk i in a grid-stride loop over the x
// dimension, so any launch configuration, even a single thread, covers the
// whole diagonal.
//
// The thread index and loop step are computed in ptrdiff_t. The built-in
// products blockIdx.x * blockDim.x and gridDim.x * blockDim.x are otherwise
// evaluated in 32-bit unsigned int and can wrap for very large grids before
// being widened.
template <typename T>
__global__ void THCTensor_copyFromDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideA) {
  const ptrdiff_t step = static_cast<ptrdiff_t>(gridDim.x) * blockDim.x;
  for (ptrdiff_t linearIndex = static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearIndex < size;
       linearIndex += step) {
    const ptrdiff_t bOffset = start + strideSum * linearIndex;
    a[strideA * linearIndex] = b[bOffset];
  }
}
// Copy vector B to the kth diagonal of a matrix A.
//
// Element i (0 <= i < size) is read from b[i * strideB] and written to
// a[start + i * strideSum]. `strideSum` is the step between consecutive
// diagonal elements in a (presumably A's row stride + column stride — confirm
// against the caller). Threads walk i in a grid-stride loop over the x
// dimension, so any launch configuration, even a single thread, covers the
// whole diagonal.
//
// The thread index and loop step are computed in ptrdiff_t. The built-in
// products blockIdx.x * blockDim.x and gridDim.x * blockDim.x are otherwise
// evaluated in 32-bit unsigned int and can wrap for very large grids before
// being widened.
template <typename T>
__global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideB) {
  const ptrdiff_t step = static_cast<ptrdiff_t>(gridDim.x) * blockDim.x;
  for (ptrdiff_t linearIndex = static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linearIndex < size;
       linearIndex += step) {
    const ptrdiff_t aOffset = start + strideSum * linearIndex;
    a[aOffset] = b[strideB * linearIndex];
  }
}
// NOTE(review): not referenced in this header; presumably the maximum number of
// input tensors batched into one CatArrayBatchedCopy launch — confirm against
// the host-side launcher.
#define CAT_ARRAY_BATCH_SIZE 1024
// Capacity of the size/stride arrays in the OutputTensorSizeStride that
// CatArrayBatchedCopy takes; the kernel's Dims template argument must not
// exceed this.
#define CAT_ARRAY_MAX_INPUT_DIMS 4
// Computes the launch grid for CatArrayBatchedCopy.
//
// grid.x = 2 * (SM count): the x-blocks cooperate on a single input tensor.
// grid.y = nTensors: one row of blocks per input tensor.
// If the current device cannot be determined, returns false and leaves `grid`
// untouched; otherwise fills `grid` and returns true. Without a THCState the
// SM count falls back to 15.
inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
  int deviceId = -1;
  cudaGetDevice(&deviceId);
  // -1 is the sentinel set above; if it survives the query, treat it as "no device".
  if (deviceId == -1) return false;

  int smCount = 15;  // fallback when there is no THCState to query
  if (state != nullptr) {
    smCount = THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
  }

  // Heuristic: two blocks per SM along x for each input tensor.
  grid = dim3(2LL * smCount, static_cast<long long>(nTensors));
  return true;
}
// Maps a linear index over one contiguous input tensor to the element offset of
// that same element inside the (strided) concatenation output.
//
// Decomposition:
// - The index is decomposed innermost dimension first (i = Dims-1 down to 1).
// - Each dimension uses the output's extent, except concatDim, which uses the
//   input's own extent (dimSize).
// - Each coordinate is scaled by the output stride of its dimension.
// - Dimension 0 is peeled off after the loop: the remaining quotient is its
//   coordinate, so no division is needed for it.
//
// The returned offset excludes the input's starting position along concatDim;
// CatArrayBatchedCopy adds that separately.
//
// Declared __host__ __device__ (previously device-only) so it can also be used
// and unit-tested on the host. Existing device callers are unaffected.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  static inline __host__ __device__ IndexType compute(
      const IndexType outputSize[Dims],
      const IndexType outputStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    IndexType offset = 0;

#pragma unroll
    for (int i = Dims - 1; i >= 1; --i) {
      // i >= 1 inside the loop, so the cast cannot change the comparison;
      // it only avoids a signed/unsigned comparison.
      IndexType curDimSize =
          static_cast<unsigned int>(i) == concatDim ? dimSize : outputSize[i];
      IndexType nextDimIndex = linearIndex / curDimSize;
      // Remainder via multiply-subtract, reusing the quotient instead of a second division.
      IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
      offset += curDimIndex * outputStride[i];
      linearIndex = nextDimIndex;
    }

    return offset + linearIndex * outputStride[0];
  }
};
// Per-input metadata read by CatArrayBatchedCopy; the kernel selects the entry
// for its block with inputs[blockIdx.y].
// NOTE(review): the array of these is dereferenced in device code, so host code
// presumably fills it and copies it to the device. Keep field order/layout
// stable unless that code is updated too.
template <typename T, typename IndexType>
struct CatArrInputTensor {
  T* input;             // input data; read as input[k] for k in [0, nElements), i.e. assumed contiguous
  IndexType offset;     // starting position of this input along concatDim in the output (scaled by dimStride in the kernel)
  IndexType dimSize;    // this input's extent along concatDim
  IndexType nElements;  // number of elements copied from this input
};
// Sizes and strides of the concatenation output, passed to CatArrayBatchedCopy
// by value. CatArrIndexToOffset::compute reads only the first Dims entries
// (Dims being the kernel's template argument), so Dims must be <= MaxDims.
template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
  IndexType outputSize[MaxDims];    // output extent per dimension
  IndexType outputStride[MaxDims];  // output stride per dimension, in elements of T
};
/**
 * Concatenates gridDim.y input tensors into a single output tensor.
 *
 * Launch layout: blockIdx.y selects which entry of `inputs` a block copies.
 * Along x, (blockIdx.x, threadIdx.x) form a grid-stride loop over that input's
 * elements, so every element is copied regardless of gridDim.x / blockDim.x.
 *
 * output:    base pointer to the storage of the output tensor
 * inputs:    device-accessible array of per-input metadata, at least gridDim.y entries
 * os:        size/stride vectors of the output tensor (first Dims entries are used)
 * concatDim: dimension along which the inputs are concatenated
 * dimStride: stride of the output tensor at concatDim
 *
 * Key assumption: every input tensor is contiguous (element k is read as input[k]).
 */
template <typename T, typename IndexType, int Dims>
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {
  // Snapshot this block's metadata once, before any writes to output.
  const CatArrInputTensor<T, IndexType> src = inputs[blockIdx.y];
  const T* const srcData = src.input;

  // Where this input's slab begins in the output: its position along concatDim
  // times the output stride of that dimension.
  const IndexType slabBase = src.offset * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < src.nElements;
       idx += step) {
    const IndexType within = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, src.dimSize, concatDim, idx);
    output[slabBase + within] = srcData[idx];
  }
}
#endif