forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCAtomics.cuh
142 lines (121 loc) · 4.36 KB
/
THCAtomics.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#ifndef THC_ATOMICS_INC
#define THC_ATOMICS_INC
#include "THC.h"
#include "TH/THHalf.h"
#include "THCNumerics.cuh"
#include "ATen/ATen.h"
// Primary template for the integer atomic-add functors. It is only declared
// here; the behavior comes from the partial specializations on `n`.
//   T: integer element type being added.
//   n: size selector; the atomicAdd overloads in this file pass sizeof(T).
template <typename T, size_t n>
struct AtomicAddIntegerImpl;
// Atomic add for 1-byte integer types, emulated with a 32-bit atomicCAS on the
// 4-byte-aligned word that contains the target byte.
//
//   address: pointer to the byte to update.
//   val:     addend. The stored result is (old + val) modulo 2^8, which is the
//            two's-complement wrap-around for both signed and unsigned T.
//
// The other three bytes of the word are never modified by this functor.
template<typename T>
struct AtomicAddIntegerImpl<T, 1> {
  inline __device__ void operator()(T *address, T val) {
    // Byte position of the target inside its enclosing aligned word.
    size_t offset = (size_t)address & 3;
    uint32_t * address_as_ui = (uint32_t *) ((char *)address - offset);
    // The shift treats the lowest-addressed byte as the least-significant one
    // (little-endian layout).
    uint32_t shift = offset * 8;
    // Unsigned literal: shifting a signed 0xff by 24 would reach the sign bit.
    uint32_t byte_mask = 0xffu << shift;
    uint32_t old = *address_as_ui;
    uint32_t sum;
    uint32_t newval;
    uint32_t assumed;

    do {
      assumed = old;
      // Add in unsigned 32-bit arithmetic and keep only the low 8 bits.
      // Without the truncation a carry (e.g. 200 + 100) or the sign extension
      // of a negative result would leak into the neighbouring bytes.
      sum = (((old >> shift) & 0xffu) + static_cast<uint32_t>(val)) & 0xffu;
      newval = (old & ~byte_mask) | (sum << shift);
      old = atomicCAS(address_as_ui, assumed, newval);
      // Retry when another thread changed the word between the read and the CAS.
    } while (assumed != old);
  }
};
// Atomic add for 2-byte integer types, emulated with a 32-bit atomicCAS on the
// 4-byte-aligned word that contains the target half-word.
//
//   address: pointer to the 2-byte value to update.
//   val:     addend. The stored result is (old + val) modulo 2^16, which is the
//            two's-complement wrap-around for both signed and unsigned T.
//
// The other half of the word is never modified by this functor.
template<typename T>
struct AtomicAddIntegerImpl<T, 2> {
  inline __device__ void operator()(T *address, T val) {
    // 0 when the target is the low half of the word, 2 when it is the high half.
    size_t offset = (size_t)address & 2;
    uint32_t * address_as_ui = (uint32_t *) ((char *)address - offset);
    bool is_upper_half = (offset != 0);
    uint32_t old = *address_as_ui;
    uint32_t old_bits;
    uint32_t sum;
    uint32_t newval;
    uint32_t assumed;

    do {
      assumed = old;
      // Select the half first, then add. Writing this as
      // `val + (size_t)address & 2 ? a : b` would parse as
      // `((val + (size_t)address) & 2) ? a : b` and never add val.
      old_bits = is_upper_half ? (old >> 16) : (old & 0xffffu);
      // Truncate to 16 bits so that neither a carry nor sign extension of a
      // negative value can spill into the other half of the word.
      sum = (old_bits + static_cast<uint32_t>(val)) & 0xffffu;
      newval = is_upper_half ? ((old & 0xffffu) | (sum << 16))
                             : ((old & 0xffff0000u) | sum);
      old = atomicCAS(address_as_ui, assumed, newval);
      // Retry when another thread changed the word between the read and the CAS.
    } while (assumed != old);
  }
};
// Atomic add for 4-byte integer types, built from a compare-and-swap loop on
// the value reinterpreted as a uint32_t word.
//   address: location to update.
//   val:     addend; the sum is computed as `val + (T)current`.
template<typename T>
struct AtomicAddIntegerImpl<T, 4> {
  inline __device__ void operator()(T *address, T val) {
    uint32_t * const word = (uint32_t *) (address);
    uint32_t observed = *word;
    for (;;) {
      // Value we believe is currently stored; the CAS only succeeds if it still is.
      const uint32_t expected = observed;
      const uint32_t desired = val + (T)expected;
      observed = atomicCAS(word, expected, desired);
      if (observed == expected) {
        break;
      }
    }
  }
};
// Atomic add for 8-byte integer types, built from a compare-and-swap loop on
// the value reinterpreted as an unsigned long long word.
//   address: location to update.
//   val:     addend; the sum is computed as `val + (T)current`.
template<typename T>
struct AtomicAddIntegerImpl<T, 8> {
  inline __device__ void operator()(T *address, T val) {
    unsigned long long * const word = (unsigned long long *) (address);
    unsigned long long observed = *word;
    for (;;) {
      // Value we believe is currently stored; the CAS only succeeds if it still is.
      const unsigned long long expected = observed;
      const unsigned long long desired = val + (T)expected;
      observed = atomicCAS(word, expected, desired);
      if (observed == expected) {
        break;
      }
    }
  }
};
// atomicAdd overload for uint8_t: forwards to the AtomicAddIntegerImpl
// specialization selected by sizeof(uint8_t).
static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)> add_op;
  add_op(address, val);
}
// atomicAdd overload for int8_t: forwards to the AtomicAddIntegerImpl
// specialization selected by sizeof(int8_t).
static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)> add_op;
  add_op(address, val);
}
// atomicAdd overload for int16_t: forwards to the AtomicAddIntegerImpl
// specialization selected by sizeof(int16_t).
static inline __device__ void atomicAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)> add_op;
  add_op(address, val);
}
// atomicAdd overload for int64_t: forwards to the AtomicAddIntegerImpl
// specialization selected by sizeof(int64_t).
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  AtomicAddIntegerImpl<int64_t, sizeof(int64_t)> add_op;
  add_op(address, val);
}
// atomicAdd overload for at::Half.
//
// When CUDA_VERSION < 10000, or the device pass targets __CUDA_ARCH__ < 700,
// the add is emulated: the 32-bit aligned word containing the half is updated
// through an atomicCAS loop, with the sum computed by THCNumerics<at::Half>::add.
// Otherwise the call is forwarded to the atomicAdd overload taking __half*.
static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
#if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  // True when the half occupies the upper 16 bits of its enclosing word.
  const bool in_upper_half = ((size_t)address & 2) != 0;
  unsigned int * const word =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
  unsigned int observed = *word;
  unsigned int expected;

  do {
    expected = observed;

    // Extract the current half bits, add val, then splice the result back.
    at::Half hsum;
    if (in_upper_half) {
      hsum.x = (observed >> 16);
    } else {
      hsum.x = (observed & 0xffff);
    }
    hsum = THCNumerics<at::Half>::add(hsum, val);

    unsigned int desired;
    if (in_upper_half) {
      desired = (observed & 0xffff) | (hsum.x << 16);
    } else {
      desired = (observed & 0xffff0000) | hsum.x;
    }

    observed = atomicCAS(word, expected, desired);
    // Retry when another thread changed the word between the read and the CAS.
  } while (expected != observed);
#else
  atomicAdd(reinterpret_cast<__half*>(address), val);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// Adapted from the CUDA C Programming Guide.
// atomicAdd overload for double, emulated with a 64-bit atomicCAS loop on the
// bit pattern of the value.
//   address: location to update.
//   val:     addend.
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long int* const word = (unsigned long long int*)address;
  unsigned long long int observed = *word;

  for (;;) {
    const unsigned long long int expected = observed;
    const double updated = val + __longlong_as_double(expected);
    observed = atomicCAS(word, expected, __double_as_longlong(updated));
    // Compare the raw bit patterns, not the doubles: NaN != NaN would make a
    // floating-point comparison loop forever.
    if (observed == expected) {
      break;
    }
  }
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
#if defined(__HIP_PLATFORM_HCC__) && __hcc_workweek__ < 18312
// This needs to be defined for the host side pass.
// Intentionally empty stub: it is compiled only when __HIP_PLATFORM_HCC__ is
// defined and __hcc_workweek__ < 18312, and performs no addition.
static inline __device__ void atomicAdd(double *address, double val) { }
#endif
#endif
#endif // THC_ATOMICS_INC