forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy.cu
285 lines (244 loc) · 10.1 KB
/
Copy.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
namespace at::native {
// Forward declarations only; the definitions are not part of this file's
// visible contents. copy_device_to_device() below calls neg_kernel_cuda when
// only the neg bits of source and destination differ, and conj_kernel_cuda
// when only the conj bits differ (presumably they negate / conjugate the
// elements while copying -- confirm against their definitions).
void neg_kernel_cuda(TensorIteratorBase &iter);
void conj_kernel_cuda(TensorIteratorBase &iter);
// Element-wise identity copy driven by `iter`.
//
// The kernel is instantiated for the dtype of operand 0. Quantized integer
// dtypes go through the QINT dispatch macro; everything else goes through the
// all-types + complex dispatch extended with Half, Bool, BFloat16 and
// ComplexHalf. In both cases the per-element functor simply returns its input.
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
  const ScalarType out_dtype = iter.dtype(0);

  // Common case first: every non-quantized dtype.
  if (!isQIntType(out_dtype)) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        kHalf, kBool, kBFloat16, kComplexHalf, out_dtype, "copy_", [&] {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) { return value; });
        });
    return;
  }

  // Quantized integer dtypes need their own dispatch macro.
  AT_DISPATCH_QINT_TYPES(out_dtype, "copy_", [&] {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) { return value; });
  });
}
// Element-wise copy that writes -conj(x) for every input element x.
//
// Dispatches over the complex types only, keyed on iter.common_dtype().
void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
  AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cuda", [&] {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t elem) {
      // Conjugate first, then negate the conjugated value.
      const auto conjugated = std::conj(elem);
      return -conjugated;
    });
  });
}
// The code below refers to CUDAGuard, CUDAStream, CUDAEvent and
// getCurrentCUDAStream without qualification; presumably those names are
// resolved through this using-directive.
using namespace at::cuda;
// Device-to-device copy; also performs dtype conversion when needed.
//
// iter         - operand 0 is the destination, operand 1 is the source.
// non_blocking - accepted for interface symmetry; not consulted in this body.
// p2p_enabled  - whether peer-to-peer access between the two devices is
//                enabled; only affects which memcpy API the raw-copy path uses.
//
// The copy is always issued on the source device's current stream. For
// cross-device copies that stream is fenced against the destination device's
// current stream both before and after the copy.
void copy_device_to_device(TensorIterator& iter,
                           bool non_blocking,
                           bool p2p_enabled) {
  const int64_t element_count = iter.numel();

  // A raw memcpy is only valid when dtype, conj bit and neg bit all agree AND
  // both tensors are contiguous after dimension coalescing and reordering.
  const bool dtypes_match = iter.dtype(0) == iter.dtype(1);
  const bool conj_bits_match =
      iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
  const bool neg_bits_match =
      iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
  const bool can_memcpy =
      dtypes_match && conj_bits_match && neg_bits_match && iter.is_contiguous();

  const Device dst_device = iter.device(0);
  const Device src_device = iter.device(1);
  const bool crosses_devices = src_device != dst_device;

  CUDAGuard device_guard(src_device);

  // The copy runs on the source device's current stream, and completion is
  // synchronized against both the src and dst current streams. This has to be
  // done explicitly for non-contiguous copies, and mimics what a cross-device
  // cudaMemcpyAsync on the default stream would do.
  CUDAStream copy_stream = getCurrentCUDAStream(src_device.index());

  if (crosses_devices) {
    // First half of a two-way barrier: the copy stream waits for everything
    // already queued on dst's current stream, so write-after-write and
    // write-after-read hazards on the destination memory are resolved before
    // the copy touches it. (The copy stream already orders against src work.)
    CUDAEvent dst_ready;
    device_guard.set_device(dst_device);
    dst_ready.record(getCurrentCUDAStream(dst_device.index()));

    device_guard.set_device(src_device);
    dst_ready.block(copy_stream);
  }

  if (!can_memcpy) {
    // Kernel path: choose the functor from which of the conj/neg bits differ.
    if (neg_bits_match && conj_bits_match) {
      direct_copy_kernel_cuda(iter);
    } else if (neg_bits_match) {
      conj_kernel_cuda(iter);
    } else if (conj_bits_match) {
      neg_kernel_cuda(iter);
    } else {
      neg_conj_kernel_cuda(iter);
    }
  } else {
    void* dst_ptr = iter.data_ptr(0);
    void* src_ptr = iter.data_ptr(1);
    const size_t byte_count = element_count * iter.element_size(0);

    // Same pointer on the same device means there is nothing to move.
    const bool is_self_copy = (src_ptr == dst_ptr) && !crosses_devices;
    if (!is_self_copy) {
      // Due to bizarre cuda driver intricacies, copies of cudaMallocAsynced
      // memory between devices that aren't peer-to-peer-capable need
      // "cudaMemcpyPeerAsync".
#ifdef USE_ROCM
      const bool pool_needs_peer_access = false;
#else
      const bool pool_needs_peer_access =
          CUDACachingAllocator::get()->needsPoolSpecificPeerAccess();
#endif
      const bool use_peer_memcpy =
          crosses_devices && pool_needs_peer_access && !p2p_enabled;

      if (use_peer_memcpy) {
        AT_CUDA_CHECK(cudaMemcpyPeerAsync(
            dst_ptr, dst_device.index(),
            src_ptr, src_device.index(),
            byte_count, copy_stream));
      } else {
        AT_CUDA_CHECK(cudaMemcpyAsync(
            dst_ptr, src_ptr, byte_count,
            cudaMemcpyDeviceToDevice,
            copy_stream));
      }
    }
  }

  if (crosses_devices) {
    // Second half of the barrier: dst's current stream waits for the copy.
    // Nothing may operate on dst's copy until the copy is complete.
    // The guard is still on src_device here, so record on the copy stream
    // first and only then switch to dst.
    CUDAEvent src_ready;
    src_ready.record(copy_stream);

    device_guard.set_device(dst_device);
    src_ready.block(getCurrentCUDAStream(dst_device.index()));
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
// Decides whether the copy described by `iter` has to be staged through
// intermediate contiguous tensors.
//
// Returns false when:
//   * source and destination are the same device (asserted to be CUDA), or
//   * both operands share a dtype and the iterator is contiguous, or
//   * both operands are on (different) GPUs and p2p_enabled is true.
// Returns true otherwise, e.g. for non-contiguous copies between CPU and GPU.
static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
  const Device dst_device = iter.device(0);
  const Device src_device = iter.device(1);

  // Copies on the same GPU never need temporaries.
  if (dst_device == src_device) {
    TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
    return false;
  }

  // Contiguous same-dtype copies can always use cudaMemcpyAsync.
  const bool dtypes_match = iter.dtype(0) == iter.dtype(1);
  if (dtypes_match && iter.is_contiguous()) {
    return false;
  }

  // GPU<->GPU copies can run the copy kernel directly when P2P is supported;
  // every other remaining combination must go through temporaries.
  const bool both_on_gpu = dst_device.is_cuda() && src_device.is_cuda();
  return !(both_on_gpu && p2p_enabled);
}
// Returns the result of at::cuda::get_p2p_access(src, dst) for a GPU pair.
// If either endpoint is the CPU the query is skipped and false is returned.
static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
  const bool involves_cpu = dst_device.is_cpu() || src_device.is_cpu();
  return involves_cpu
      ? false
      : at::cuda::get_p2p_access(src_device.index(), dst_device.index());
}
// Entry point for copies in which at least one side is a CUDA tensor.
//
// iter         - two-operand iterator: operand 0 is the destination, operand 1
//                is the source.
// non_blocking - when true, CPU<->GPU transfers are issued with
//                cudaMemcpyAsync on the current stream instead of a
//                synchronizing copy.
//
// Flow:
//   1. If copy_requires_temporaries() says so, stage through contiguous
//      temporaries and recurse via Tensor::copy_.
//   2. GPU<->GPU copies are delegated to copy_device_to_device().
//   3. Otherwise a raw CPU<->GPU memcpy is issued, followed by an in-place
//      conj/neg fix-up on the destination when the conj/neg bits of source
//      and destination differ.
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  AT_ASSERT(iter.ntensors() == 2);

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  // Enable p2p access between devices. (No-op if it involves the CPU)
  bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);

  if (copy_requires_temporaries(iter, p2p_enabled)) {
    // NB: this involves recursive calls to copy. Be careful that those copies
    // don't require temporaries or you will cause an infinite recursion!
    auto& dst = iter.tensor(0);
    Tensor dst_contig;
    Tensor src_contig;

    // If non_blocking is true - type conversions are performed on the GPU
    // for CPU-GPU copies, otherwise type conversions are performed on the CPU.
    // Type conversions are performed on the src device for GPU-GPU copies.
    if (iter.device_type(0) == kCUDA || non_blocking) {
      dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
    } else {
      bool same_type = iter.dtype(0) == iter.dtype(1);
      dst_contig = (dst.is_contiguous() && same_type) ? dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).expand_as(dst).contiguous();
    }

    // propagate the correct conjugate bit
    dst_contig._set_conj(dst.is_conj());
    src_contig._set_conj(iter.tensor(1).is_conj());

    dst_contig._set_neg(dst.is_neg());
    src_contig._set_neg(iter.tensor(1).is_neg());

    // perform a same-dtype copy on contiguous tensors
    TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
    TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
    dst_contig.copy_(src_contig, non_blocking);

    // if necessary, copy back into dst
    if (!dst_contig.is_same(dst)) {
      TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device());
      dst.copy_(dst_contig, non_blocking);
    }
    return;
  }

  // Copy on GPU (or between GPUs)
  if (dst_device.is_cuda() && src_device.is_cuda()) {
    copy_device_to_device(iter, non_blocking, p2p_enabled);
    return;
  }

  // Copy between CPU and GPU. The guard is pointed at whichever side is the
  // GPU so that getCurrentCUDAStream() below returns that device's stream.
  cuda::OptionalCUDAGuard device_guard;
  cudaMemcpyKind kind;
  if (dst_device.is_cuda() && src_device.is_cpu()) {
    device_guard.set_device(dst_device);
    kind = cudaMemcpyHostToDevice;
  } else if (dst_device.is_cpu() && src_device.is_cuda()) {
    device_guard.set_device(src_device);
    kind = cudaMemcpyDeviceToHost;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()");
  }

  void* dst = iter.data_ptr(0);
  void* src = iter.data_ptr(1);
  int64_t nbytes = iter.numel() * iter.element_size(0);
  CUDAStream stream = getCurrentCUDAStream();

  // Whether the destination needs an in-place conj/neg fix-up after the raw
  // byte copy. conj_physical_() and neg_() rewrite data, not the bits, so
  // evaluating both flags up front is equivalent to checking them afterwards.
  const bool needs_conj_fixup =
      iter.tensor(0).is_conj() != iter.tensor(1).is_conj();
  const bool needs_neg_fixup =
      iter.tensor(0).is_neg() != iter.tensor(1).is_neg();

  if (non_blocking) {
    AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
    // we use both the storage context and the tensor data pointer as the key
    // for the caching host allocator. This allows us to better attribute the
    // events to the original tensor allocation correctly. The cases we seek to
    // handle are:

    // 1: a user can pass a pinned memory tensor with an alternative
    // context, for example if allocating memory directly from the pinned memory
    // allocator and constructing a tensor with torch::from_blob.

    // 2: a user can pass a tensor with a different base pointer to the original
    // allocation (via slicing).
    const auto& dst_tensor = iter.tensor(0);
    const auto& src_tensor = iter.tensor(1);
    const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor);
    auto* ptr = (dst_device == kCPU ? dst : src);
    auto* ctx = host_tensor.storage().data_ptr().get_context();
    // TODO: warn on the return value.
    CachingHostAllocator_recordEvent(ptr, ctx, stream);

    // The fix-up below is issued right away. When the destination lives on
    // the host, it would otherwise run on the CPU while the asynchronous
    // device-to-host transfer may still be in flight, operating on bytes that
    // have not arrived yet. Wait for the copy in that (rare) case only; a GPU
    // destination needs no wait because its fix-up is presumably queued on
    // the same current stream, behind the memcpy.
    if (dst_device.is_cpu() && (needs_conj_fixup || needs_neg_fixup)) {
      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    }
  } else {
    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
  }

  if (needs_conj_fixup) {
    iter.tensor(0).conj_physical_();
  }
  if (needs_neg_fixup) {
    iter.tensor(0).neg_();
  }
}
// Register copy_kernel_cuda as the implementation behind copy_stub.
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);
} // namespace at::native