forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BinaryMiscOpsKernels.cu
81 lines (71 loc) · 2.72 KB
/
BinaryMiscOpsKernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at::native {
// Elementwise smooth-L1 loss over the two inputs of `iter`.
//
// For each pair (a, b), with z = |a - b|, the output is
//   0.5 * z * z / beta   if z < beta
//   z - 0.5 * beta       otherwise
//
// `beta` is passed as a double and converted to the dispatched scalar type
// once on the host; the device lambda captures the converted value.
// When beta is 0 the comparison `z < beta` is false for every z, so the
// branch that divides by beta is never evaluated.
//
// Dispatches over the standard floating types plus Half and BFloat16, the
// same reduced-precision set that huber_kernel_cuda and mse_kernel_cuda accept.
void smooth_l1_kernel_cuda(TensorIteratorBase& iter, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "smooth_l1_cuda", [&iter, beta]() {
    scalar_t beta_val(beta);
    gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      // `auto` on purpose: keep whatever type ::abs yields for scalar_t.
      auto z = ::abs(a - b);
      return z < beta_val
          ? scalar_t(0.5) * z * z / beta_val   // quadratic region
          : z - scalar_t(0.5) * beta_val;      // linear region
    });
  });
}
// Elementwise Huber loss over the two inputs of `iter`.
//
// For each pair, with err = |lhs - rhs|, the output is
//   0.5 * err * err                  if err < delta
//   delta * (err - 0.5 * delta)      otherwise
//
// `delta` is converted to the dispatched scalar type on the host before it is
// captured by the device lambda.
void huber_kernel_cuda(TensorIterator& iter, double delta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "huber_cuda", [&iter, delta] {
    scalar_t threshold(delta);
    gpu_kernel(iter, [threshold] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      // Absolute difference; type deliberately left to deduction.
      auto err = ::abs(lhs - rhs);
      if (err < threshold) {
        // Quadratic region.
        return scalar_t(0.5) * err * err;
      }
      // Linear region.
      return threshold * (err - scalar_t(0.5) * threshold);
    });
  });
}
// Elementwise squared error over the two inputs of `iter`:
// each output element is (lhs - rhs) * (lhs - rhs).
// No reduction is performed in this function.
void mse_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "mse_cuda", [&iter]() {
    gpu_kernel(iter, [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      // Compute the difference once, then square it.
      const auto delta = lhs - rhs;
      return delta * delta;
    });
  });
}
// Elementwise x * log(y) with two special cases, checked in this order:
//   1. y is NaN  -> NaN (even when x is 0)
//   2. x == 0    -> 0   (regardless of what log(y) would be)
// Otherwise the result is x * std::log(y).
// Dispatches on iter.common_dtype() and launches through
// gpu_kernel_with_scalars.
void xlogy_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&iter]() {
    gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t x, scalar_t y) -> scalar_t {
      if (!at::_isnan(y)) {
        if (x == 0) {
          // A zero multiplier wins over any non-NaN y.
          return 0;
        }
        return x * std::log(y);
      }
      // A NaN y takes priority over the x == 0 shortcut.
      return NAN;
    });
  });
}
// Elementwise x * log1p(y) with two special cases, checked in this order:
//   1. y is NaN  -> NaN (even when x is 0)
//   2. x == 0    -> 0   (regardless of what log1p(y) would be)
// Otherwise the result is x * std::log1p(y).
// Dispatches on iter.common_dtype() and launches through
// gpu_kernel_with_scalars.
void xlog1py_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlog1py_cuda", [&iter]() {
    gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t x, scalar_t y) -> scalar_t {
      // The NaN check must come first so NaN in y is reported even for x == 0.
      const bool y_is_nan = at::_isnan(y);
      if (y_is_nan) {
        return NAN;
      }
      const bool x_is_zero = (x == 0);
      if (x_is_zero) {
        return 0;
      }
      return x * std::log1p(y);
    });
  });
}
// Register each CUDA kernel defined above with its dispatch stub.
// The *_stub objects are declared outside this file (presumably in
// ATen/native/BinaryOps.h, included above — confirm); each kernel's signature
// must match the function type its stub expects.
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(huber_stub, &huber_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel_cuda);
// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
} // namespace at::native