forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TemporalMaxPooling.cu
86 lines (74 loc) · 3.64 KB
/
TemporalMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <cstdint>

#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCAtomics.cuh>
#include <THC/THCTensor.hpp>
#include <THC/THCStorage.hpp>
// Threads-per-block (x dimension) assumed by the temporal max-pooling kernels
// in this file. The kernels map a thread to an output frame as
//   threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS,
// so whenever a launch uses more than one block along y it must use exactly
// this many threads per block; presumably the host wrapper guarantees that
// (TODO confirm against the generic TemporalMaxPooling.cu).
#define TEMPORAL_MAX_POOLING_THREADS 1024
// Forward pass of temporal (1-D) max pooling.
//
// Launch layout, as used by the indexing below:
//   blockIdx.x                                              -> batch element
//   threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS -> output frame
// Threads whose output frame is >= output_w return without doing anything.
//
// Memory layout: offsets are computed as if input is a contiguous
// [batch][input_w][input_n] array and output/indices are contiguous
// [batch][output_w][input_n] arrays (features innermost). Assumes the host
// wrapper passes contiguous tensors (TODO confirm).
//
// For each feature, output frame f pools input frames [f * dW, f * dW + kW).
// `output` receives the maximum value and `indices` its position inside the
// window (0 .. kW-1). Among equal non-NaN values the earliest frame wins; if
// the window contains a NaN, the result is NaN.
//
// Params:
//   input    - input frames (input_w frames of input_n features per batch element).
//   output   - receives the max value per (batch, output frame, feature).
//   indices  - receives the window-relative position of that max.
//   input_w  - number of input frames per batch element.
//   input_n  - number of features per frame.
//   output_w - number of output frames per batch element.
//   kW       - window size in frames; assumed >= 1 (host-side check, TODO confirm).
//   dW       - stride between consecutive windows in frames.
template <typename Dtype>
__global__ void cunn_TemporalMaxPooling_updateOutputKernel(Dtype *input, Dtype *output, THCIndex_t *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  // All offsets are computed in 64-bit: batch * input_w * input_n can exceed
  // 2^32 for large inputs, which unsigned 32-bit products would silently wrap.
  const int64_t batch = blockIdx.x;
  const int64_t frame =
      threadIdx.x + static_cast<int64_t>(blockIdx.y) * TEMPORAL_MAX_POOLING_THREADS;
  if (frame >= output_w) {
    return;
  }

  const int64_t frame_size = input_n;  // elements per frame
  const Dtype *input_data = input + (batch * input_w + frame * dW) * frame_size;
  Dtype *output_data = output + (batch * output_w + frame) * frame_size;
  THCIndex_t *indices_data = indices + (batch * output_w + frame) * frame_size;

  for (int feat = 0; feat < input_n; ++feat) {
    // Seed from the first frame of the window and reset the index for every
    // feature. A sentinel seed with an index shared across features would leave
    // a stale index (from an earlier feature) whenever no element beats the
    // sentinel, e.g. a window of all -inf or all NaN.
    Dtype max_value = input_data[feat];
    THCIndex_t max_index = 0;
    for (int k = 1; k < kW; ++k) {
      const Dtype value = input_data[k * frame_size + feat];
      // Strict `<` keeps the earliest of equal values; the isnan test makes a
      // NaN anywhere in the window propagate (once held, `NaN < x` is false).
      if (max_value < value || THCNumerics<Dtype>::isnan(value)) {
        max_value = value;
        max_index = k;
      }
    }
    output_data[feat] = max_value;
    indices_data[feat] = max_index;
  }
}
// Backward pass of temporal (1-D) max pooling, non-atomic variant.
//
// Launch layout, as used by the indexing below:
//   blockIdx.x                                              -> batch element
//   threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS -> output frame
// Threads whose output frame is >= output_w return without doing anything.
//
// Offsets treat gradInput as a contiguous [batch][input_w][input_n] array and
// gradOutput/indices as contiguous [batch][output_w][input_n] arrays. Assumes
// contiguous tensors from the host wrapper (TODO confirm).
//
// For each feature, the gradient of output frame f is added into input frame
// f * dW + indices[f][feat], i.e. the window position stored in `indices`.
//
// The accumulation is a plain `+=`. It is race-free only when no two output
// frames can select the same input element, which holds when windows do not
// overlap (kW <= dW). cunn_TemporalMaxPooling_updateGradInputKernelAtomic uses
// atomicAdd instead; presumably the host wrapper chooses between the two on
// that condition (TODO confirm). `kW` is not read here; it is kept for a
// uniform signature.
template <typename Dtype>
__global__ void cunn_TemporalMaxPooling_updateGradInputKernel(Dtype *gradInput, Dtype *gradOutput, THCIndex_t *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  // 64-bit offsets: batch * input_w * input_n can exceed 2^32 for large
  // tensors, which unsigned 32-bit products would silently wrap.
  const int64_t batch = blockIdx.x;
  const int64_t frame =
      threadIdx.x + static_cast<int64_t>(blockIdx.y) * TEMPORAL_MAX_POOLING_THREADS;
  if (frame >= output_w) {
    return;
  }

  const int64_t frame_size = input_n;  // elements per frame
  Dtype *gradInput_data = gradInput + (batch * input_w + frame * dW) * frame_size;
  const Dtype *gradOutput_data = gradOutput + (batch * output_w + frame) * frame_size;
  const THCIndex_t *indices_data = indices + (batch * output_w + frame) * frame_size;

  for (int feat = 0; feat < input_n; ++feat) {
    gradInput_data[indices_data[feat] * frame_size + feat] += gradOutput_data[feat];
  }
}
// Backward pass of temporal (1-D) max pooling, atomic variant.
//
// Launch layout, as used by the indexing below:
//   blockIdx.x                                              -> batch element
//   threadIdx.x + blockIdx.y * TEMPORAL_MAX_POOLING_THREADS -> output frame
// Threads whose output frame is >= output_w return without doing anything.
//
// Offsets treat gradInput as a contiguous [batch][input_w][input_n] array and
// gradOutput/indices as contiguous [batch][output_w][input_n] arrays. Assumes
// contiguous tensors from the host wrapper (TODO confirm).
//
// For each feature, the gradient of output frame f is added into input frame
// f * dW + indices[f][feat], i.e. the window position stored in `indices`.
// atomicAdd is used because, when windows overlap (kW > dW), several output
// frames handled by different threads can select the same input element.
// `kW` is not read here; it is kept for a uniform signature.
template <typename Dtype>
__global__ void cunn_TemporalMaxPooling_updateGradInputKernelAtomic(Dtype *gradInput, Dtype *gradOutput, THCIndex_t *indices, int input_w, int input_n, int output_w, int kW, int dW) {
  // 64-bit offsets: batch * input_w * input_n can exceed 2^32 for large
  // tensors, which unsigned 32-bit products would silently wrap.
  const int64_t batch = blockIdx.x;
  const int64_t frame =
      threadIdx.x + static_cast<int64_t>(blockIdx.y) * TEMPORAL_MAX_POOLING_THREADS;
  if (frame >= output_w) {
    return;
  }

  const int64_t frame_size = input_n;  // elements per frame
  Dtype *gradInput_data = gradInput + (batch * input_w + frame * dW) * frame_size;
  const Dtype *gradOutput_data = gradOutput + (batch * output_w + frame) * frame_size;
  const THCIndex_t *indices_data = indices + (batch * output_w + frame) * frame_size;

  for (int feat = 0; feat < input_n; ++feat) {
    atomicAdd(&gradInput_data[indices_data[feat] * frame_size + feat], gradOutput_data[feat]);
  }
}
#include <THCUNN/generic/TemporalMaxPooling.cu>
#include <THC/THCGenerateFloatTypes.h>