forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBinaryOpsKernel.cu
188 lines (171 loc) · 6.11 KB
/
BinaryOpsKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <THC/THCNumerics.cuh>
#include <limits>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at { namespace native {
// Element-wise `a + alpha * b` on the GPU for the operands held by `iter`.
//
// The element type is selected from iter.dtype(); the dispatch covers the
// standard ATen types plus Half and Bool. `alpha_scalar` is converted to the
// dispatched scalar_t once on the host and captured by value in the device
// lambda. Subtraction reuses this kernel, which is why the dispatch name
// mentions both operations.
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "add_cuda/sub_cuda", [&]() {
    scalar_t scale = alpha_scalar.to<scalar_t>();
    auto scaled_add = [scale]GPU_LAMBDA(scalar_t lhs, scalar_t rhs) -> scalar_t {
      return lhs + scale * rhs;
    };
    gpu_kernel_with_scalars(iter, scaled_add);
  });
}
// Subtraction is expressed through the add kernel: forwarding the negated
// alpha makes it evaluate `a + (-alpha) * b`.
static void sub_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
  Scalar negated_alpha = -alpha_scalar;
  add_kernel_cuda(iter, negated_alpha);
}
// Element-wise division `a / b` on the GPU for the operands held by `iter`.
//
// Two paths:
//  * Reciprocal path - taken when iter.dtype() is not an integral type (Bool
//    is not counted as integral here) and operand 2 is a CPU scalar. The
//    scalar is inverted once on the host, operand 2 is removed from the
//    iterator, and a unary kernel computes `a * (1 / b)`. This may lose one
//    bit of precision compared to computing the division directly.
//  * Generic path - a binary kernel computing `a / b` for all standard types
//    plus Half.
void div_kernel_cuda(TensorIterator& iter) {
  const bool use_reciprocal =
      !isIntegralType(iter.dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2);

  if (!use_reciprocal) {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "div_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t numer, scalar_t denom) -> scalar_t {
        return numer / denom;
      });
    });
    return;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() {
    // Read the scalar before dropping operand 2 from the iterator.
    auto reciprocal = scalar_t(1.0 / iter.scalar_value<scalar_t>(2));
    iter.remove_operand(2);
    gpu_kernel(iter, [reciprocal]GPU_LAMBDA(scalar_t numer) -> scalar_t {
      return numer * reciprocal;
    });
  });
}
// Element-wise multiplication on the GPU for the operands held by `iter`.
//
// Non-Bool dtypes dispatch over the standard types plus Half and compute
// `a * b`. Bool is handled separately with a logical AND, because writing
// `*` on bools triggers: '*' in boolean context, suggest '&&' instead
// [-Werror=int-in-bool-context].
void mul_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() != ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "mul_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t lhs, scalar_t rhs) -> scalar_t {
        return lhs * rhs;
      });
    });
    return;
  }

  // Bool product is expressed as a conjunction.
  gpu_kernel_with_scalars(iter, []GPU_LAMBDA(bool lhs, bool rhs) -> bool {
    return lhs && rhs;
  });
}
// Element-wise atan2(a, b) on the GPU for the operands held by `iter`.
// Dispatches on iter.dtype() over the floating-point types plus Half and
// delegates the math to THCNumerics<scalar_t>::atan2.
void atan2_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "atan2_cuda", [&]() {
    auto atan2_op = []GPU_LAMBDA(scalar_t y, scalar_t x) -> scalar_t {
      return THCNumerics<scalar_t>::atan2(y, x);
    };
    gpu_kernel_with_scalars(iter, atan2_op);
  });
}
// Element-wise logical XOR on the GPU for the operands held by `iter`.
// Both inputs and the result are handled as bool; two bools are "xor-true"
// exactly when they differ.
void logical_xor_kernel_cuda(TensorIterator& iter) {
  auto differs = []GPU_LAMBDA(bool lhs, bool rhs) -> bool {
    return lhs != rhs;
  };
  gpu_kernel(iter, differs);
}
// Element-wise `a < b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "lt_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void lt_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a < b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "lt_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a < b;
      });
    });
  }
}
// Element-wise `a <= b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "le_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void le_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "le_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a <= b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "le_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a <= b;
      });
    });
  }
}
// Element-wise `a > b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "gt_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void gt_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "gt_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a > b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "gt_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a > b;
      });
    });
  }
}
// Element-wise `a >= b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "ge_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void ge_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "ge_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a >= b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "ge_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a >= b;
      });
    });
  }
}
// Element-wise `a == b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "eq_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void eq_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "eq_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a == b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "eq_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a == b;
      });
    });
  }
}
// Element-wise `a != b` on the GPU for the operands held by `iter`.
//
// * If iter.dtype() is Bool, the element type is dispatched on
//   iter.input_dtype() (standard types plus Half and Bool) and the result is
//   produced as bool.
// * Otherwise the element type is dispatched on iter.dtype() (standard types
//   plus Half) and the boolean result is converted to that same scalar_t.
//
// The dispatch name is "ne_cuda" so that unsupported-dtype errors identify the
// CUDA kernel, consistent with the other kernels in this file.
void ne_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "ne_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a != b;
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "ne_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a != b;
      });
    });
  }
}
// Register the CUDA implementations above with their dispatch stubs.
// Each entry passes the address of the corresponding *_kernel_cuda function.
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda);
REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);
REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda);
REGISTER_DISPATCH(ne_stub, &ne_kernel_cuda);
}} // namespace at::native