forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Conv_v7.cpp
983 lines (879 loc) · 37.4 KB
/
Conv_v7.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif
#include <limits>
#include <vector>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>
// Note [behavior of cudnnFind and cudnnGet]
// You'll notice that by default, in the ConvolutionDescriptor, we do the following:
//
// AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
// if(dataType == CUDNN_DATA_HALF)
// AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
//
// Update: AT_CUDNN_CHECK is updated with AT_CUDNN_CHECK_WITH_SHAPES, which
// automatically prints tensor shapes and convolution parameters if there is
// a cuDNN exception thrown.
//
// When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it informs
// cudnnGet/cudnnFind to iterate/take into account both tensor core and non-tensor-core algos.
// If you don't call cudnnSetConvolutionMathType before calling cudnnGet/cudnnFind,
// cudnnGet/cudnnFind may not pick tensor core algos.
//
// After it runs, cudnnGet/cudnnFind comes up with the best pair of algo+mathType
// given all the initial knowledge it has. It then becomes the user's responsibility
// to update the mathType of the convolution descriptor and make the subsequent cudnn calls with
// the best algo and the updated descriptor. If we don't update the descriptor but just run
// with the best algo, under the hood cudnn will run the slower kernel,
// since it sees the fastest algorithm combined with a suboptimal mathType.
// Note [blocklist fft algorithms for strided dgrad]
// This is a workaround for a cuDNN bug that gave wrong results in certain strided convolution
// gradient setups. See Issue #16610 for bug details. The bug is present in cuDNN versions < 7.5.
constexpr size_t operator "" _TiB(unsigned long long n) {
return size_t(n) * 1024 * 1024 * 1024 * 1024;
}
namespace at { namespace native {
// Convenience struct for passing around descriptors and data
// pointers.
//
// input/output/weight are borrowed references to the caller's tensors, so a
// ConvolutionArgs must not outlive the tensors it was constructed from.
struct ConvolutionArgs {
  cudnnHandle_t handle{};
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  // One declarator per line: in `const Tensor& a, b, c;` the `&` binds only
  // to `a`, which would silently make `b` and `c` by-value copies.
  const Tensor& input;
  const Tensor& output;
  const Tensor& weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight)
      : input(input), output(output), weight(weight) {}
};
// Writes a human-readable dump of `args` for error reports: a repro snippet,
// the convolution parameters, the three descriptors, and the raw data pointers
// of input/output/weight.
std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) {
  // repro_from_args(), params and each descriptor already end with a newline.
  out << repro_from_args(args.params);
  out << args.params;
  out << "input: " << args.idesc;
  out << "output: " << args.odesc;
  out << "weight: " << args.wdesc;

  out << "Pointer addresses: " << "\n";
  out << " input: " << args.input.data_ptr() << "\n";
  out << " output: " << args.output.data_ptr() << "\n";
  out << " weight: " << args.weight.data_ptr() << "\n";
  return out;
}
// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------
// TODO: Use something less heavy duty than a big honking mutex
//
// Map from ConvolutionParams to a cached perf record of type T. Every access
// is serialized by `mutex`.
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash<ConvolutionParams>, ParamsEqual<ConvolutionParams>> map;

  // On a hit, copies the stored entry into *results and returns true.
  // On a miss, returns false and leaves *results untouched.
  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> lock(mutex);
    const auto entry = map.find(params);
    const bool hit = (entry != map.end());
    if (hit) {
      *results = entry->second;
    }
    return hit;
  }

  // Stores `results` under `params`, overwriting any previous entry.
  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> lock(mutex);
    T& slot = map[params];
    slot = results;
  }
};

// One cache per convolution direction; each is exposed through the matching
// algorithm_search<perf_t>::cache().
BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
//
// Owns a raw buffer of `size` bytes obtained from
// CUDACachingAllocator::raw_alloc and returns it with raw_delete on
// destruction. Move-only: moving transfers ownership and leaves the source
// empty (data == nullptr, size == 0).
struct Workspace {
  Workspace(size_t size) : size(size), data(nullptr) {
    // Sometimes cuDNN returns a workspace size > 2^63, which could make the
    // allocation of the workspace fail with a 64-bit indexing error instead
    // of an OOM error. In that case, we manually fail with OOM.
    TORCH_CHECK_WITH(OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
    data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
  }

  Workspace(const Workspace&) = delete;
  Workspace& operator=(const Workspace&) = delete;

  // Defaulted moves would copy `data` and leave the source still owning it,
  // so both destructors would raw_delete the same pointer. Transfer instead.
  Workspace(Workspace&& other) noexcept : size(other.size), data(other.data) {
    other.size = 0;
    other.data = nullptr;
  }

  Workspace& operator=(Workspace&& other) noexcept {
    if (this != &other) {
      release();
      size = other.size;
      data = other.data;
      other.size = 0;
      other.data = nullptr;
    }
    return *this;
  }

  ~Workspace() {
    release();
  }

  size_t size;
  void* data;

 private:
  // Returns the owned buffer (if any) to the caching allocator.
  void release() noexcept {
    if (data) {
      c10::cuda::CUDACachingAllocator::raw_delete(data);
      data = nullptr;
    }
  }
};
// Intentionally empty primary template. The forward, backward-data and
// backward-filter cuDNN perf types each get an explicit specialization that
// supplies the algorithm search and workspace queries.
template <typename perf_t>
struct algorithm_search {};
// Asks cuDNN how many workspace bytes forward algorithm `algo` needs for the
// descriptors in `args` and writes the result to *sz. The cudnnStatus_t is
// returned unchecked; getMaxWorkspaceSize skips algorithms whose query fails.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionFwdAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionForwardWorkspaceSize(
      args.handle,
      /*xDesc=*/args.idesc.desc(),
      /*wDesc=*/args.wdesc.desc(),
      /*convDesc=*/args.cdesc.desc(),
      /*yDesc=*/args.odesc.desc(),
      algo,
      sz);
  return status;
}
// Asks cuDNN how many workspace bytes backward-data algorithm `algo` needs
// for the descriptors in `args` and writes the result to *sz. The status is
// returned unchecked so the caller decides how to react to a failure.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdDataAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle,
      /*wDesc=*/args.wdesc.desc(),
      /*dyDesc=*/args.odesc.desc(),
      /*convDesc=*/args.cdesc.desc(),
      /*dxDesc=*/args.idesc.desc(),
      algo,
      sz);
  return status;
}
// Asks cuDNN how many workspace bytes backward-filter algorithm `algo` needs
// for the descriptors in `args` and writes the result to *sz. The status is
// returned unchecked so the caller decides how to react to a failure.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle,
      /*xDesc=*/args.idesc.desc(),
      /*dyDesc=*/args.odesc.desc(),
      /*convDesc=*/args.cdesc.desc(),
      /*dwDesc=*/args.wdesc.desc(),
      algo,
      sz);
  return status;
}
template<typename algo_t>
size_t getMaxWorkspaceSize(
const ConvolutionArgs& args,
const algo_t *algo, int n_algo)
{
size_t max_ws_size = 0;
size_t max_block_size = 0;
const auto device = c10::cuda::current_device();
// For the native allocator, retrieves the size of the largest unused block.
// For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for details.
c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
for (const auto i : c10::irange(n_algo)) {
cudnnStatus_t err;
size_t sz;
err = getWorkspaceSize(args, algo[i], &sz);
if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > max_block_size)
continue;
max_ws_size = sz;
}
return max_ws_size;
}
// Filters the first `n_algo` entries of `perfResults` down to those that
// reported CUDNN_STATUS_SUCCESS and, when args.params.deterministic is set,
// are CUDNN_DETERMINISTIC. Order is preserved. Throws if nothing survives.
template<typename perf_t>
std::vector<perf_t> getValidAlgorithms(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
  // See Note [blocklist fft algorithms for strided dgrad]
#if CUDNN_VERSION < 7500
  bool blocklist = std::is_same<decltype(perfResults[0].algo), cudnnConvolutionBwdDataAlgo_t>::value;
  int stride_dim = args.input.dim() - 2;
  blocklist &= std::any_of(std::begin(args.params.stride),
                           std::begin(args.params.stride) + stride_dim,
                           [=](int n){return n != 1;});
#endif

  std::vector<perf_t> result;
  result.reserve(n_algo);
  for (const auto idx : c10::irange(n_algo)) {
    const perf_t& candidate = perfResults[idx];
    // TODO: Shouldn't all returned results be successful?
    // Double check documentation for cudnnFindConvolutionForwardAlgorithmEx
    if (candidate.status != CUDNN_STATUS_SUCCESS) {
      continue;
    }
    if (args.params.deterministic && candidate.determinism != CUDNN_DETERMINISTIC) {
      continue;
    }
#if CUDNN_VERSION < 7500
    // See Note [blocklist fft algorithms for strided dgrad]
    const auto bwd_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(candidate.algo);
    const bool is_fft = bwd_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
                        bwd_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
    if (blocklist && is_fft) {
      continue;
    }
#endif
    result.push_back(candidate);
  }
  TORCH_CHECK(result.size() > 0, "no valid convolution algorithms available in CuDNN");
  return result;
}
template<>
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
using algo_t = cudnnConvolutionFwdAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
static BenchmarkCache<perf_t>& cache() { return fwd_algos; }
static std::vector<perf_t> findAlgorithms(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
};
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution forward algorithms");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
if (!benchmark) {
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionForwardAlgorithm_v7(
args.handle,
args.idesc.desc(),
args.wdesc.desc(),
args.cdesc.desc(),
args.odesc.desc(),
num_algos,
&perf_count,
perf_results.get()), args);
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
args.wdesc.desc(), args.weight.data_ptr(),
args.cdesc.desc(),
args.odesc.desc(), args.output.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size), args);
// Free the cached blocks in our caching allocator. They are
// needed here because the above benchmarking uses a huge amount of memory,
// e.g. a few GBs.
c10::cuda::CUDACachingAllocator::emptyCache();
}
return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
}
static void getWorkspaceSize(
const ConvolutionArgs& args,
algo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionForwardWorkspaceSize(
args.handle,
args.idesc.desc(),
args.wdesc.desc(),
args.cdesc.desc(),
args.odesc.desc(),
algo,
workspaceSize), args);
}
};
template<>
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
using algo_t = cudnnConvolutionBwdDataAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }
static std::vector<perf_t> findAlgorithms(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
};
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution backward data algorithms.");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
if (!benchmark) {
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardDataAlgorithm_v7(
args.handle,
args.wdesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.idesc.desc(),
num_algos,
&perf_count,
perf_results.get()), args);
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardDataAlgorithmEx(
args.handle,
args.wdesc.desc(), args.weight.data_ptr(),
args.odesc.desc(), args.output.data_ptr(),
args.cdesc.desc(),
args.idesc.desc(), args.input.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size), args);
// Free the cached blocks in our caching allocator. They are
// needed here because the above benchmarking uses a huge amount of memory,
// e.g. a few GBs.
c10::cuda::CUDACachingAllocator::emptyCache();
}
return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
}
static void getWorkspaceSize(
const ConvolutionArgs& args,
cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle,
args.wdesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.idesc.desc(),
algo,
workspaceSize), args);
}
};
template<>
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
using algo_t = cudnnConvolutionBwdFilterAlgo_t;
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }
static std::vector<perf_t> findAlgorithms(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
};
// NOTE: - 1 because ALGO_WINOGRAD is not implemented
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution backward filter algorithms.");
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
int perf_count;
if (!benchmark) {
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
args.handle,
args.idesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.wdesc.desc(),
num_algos,
&perf_count,
perf_results.get()), args);
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardFilterAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
args.odesc.desc(), args.output.data_ptr(),
args.cdesc.desc(),
args.wdesc.desc(), args.weight.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size), args);
// Free the cached blocks in our caching allocator. They are
// needed here because the above benchmarking uses a huge amount of memory,
// e.g. a few GBs.
c10::cuda::CUDACachingAllocator::emptyCache();
}
return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
}
static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle,
args.idesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.wdesc.desc(),
algo,
workspaceSize), args);
}
};
template<typename perf_t>
class AlgoIterator {
using search = algorithm_search<perf_t>;
const ConvolutionArgs &args;
bool benchmark;
public:
AlgoIterator(const ConvolutionArgs &args, bool benchmark): args(args), benchmark(benchmark) {}
static std::vector<perf_t> onlyDefaultAlgorithm(const ConvolutionArgs &args) {
std::vector<perf_t> perfResults(1);
perfResults[0].algo = search::DEFAULT_ALGO;
if (args.params.dataType == CUDNN_DATA_HALF) {
perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
} else {
perfResults[0].mathType = CUDNN_DEFAULT_MATH;
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
perfResults[0].mathType = CUDNN_FMA_MATH;
}
#endif
}
search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory));
return perfResults;
}
void try_all(std::function<void (const perf_t &perf)> f) {
bool only_use_default = args.params.deterministic && !benchmark;
auto& cache = search::cache();
perf_t algoPerf;
if (!only_use_default && cache.find(args.params, &algoPerf)) {
try {
f(algoPerf);
return;
} catch (c10::OutOfMemoryError &e) {
cudaGetLastError(); // clear CUDA error
}
}
auto perfResults = only_use_default ? onlyDefaultAlgorithm(args) : search::findAlgorithms(args, benchmark);
for (auto &algoPerf : perfResults) {
try {
f(algoPerf);
cache.insert(args.params, algoPerf);
return;
} catch (c10::OutOfMemoryError &e) {
cudaGetLastError(); // clear CUDA error
} catch (c10::CuDNNError &e) {
cudaGetLastError(); // clear CUDA error
}
}
TORCH_CHECK(false, "Unable to find a valid cuDNN algorithm to run convolution");
}
};
// Allocates a 1-D byte tensor with `size` elements, using the options of
// `other` (with dtype replaced by kByte), to serve as cuDNN scratch space.
inline Tensor allocate_workspace(size_t size, const Tensor &other) {
  // cuDNN occasionally reports workspace sizes above 2^63; allocating that
  // would fail with a 64-bit indexing error rather than an OOM. Anything of
  // 1 TiB or more is therefore rejected up front as an OutOfMemoryError.
  TORCH_CHECK_WITH(OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
  const auto byte_options = other.options().dtype(kByte);
  const int64_t numel = static_cast<int64_t>(size);
  return at::empty({numel}, byte_options);
}
// NOTE [ raw_cudnn_convolution_forward_out ]
//
// - raw_cudnn_convolution_forward_out (Tensor)
// Function that handles tensors that are too large to use 32bit indexing.
// It just splits the tensor and dispatches to `raw_cudnn_convolution_forward_out_32bit`.
//
// - raw_cudnn_convolution_forward_out_32bit (Tensor)
// Low level function which invokes CuDNN, and takes an output
// tensor which is directly written to (thus _out).
//
// ---------------------------------------------------------------------
//
// Splitting to 32bit
//
// ---------------------------------------------------------------------
// Invokes `func_32bit` so that each call only sees tensors with at most
// INT_MAX elements:
//  - if both `input` and `output` already fit, it is called once as-is;
//  - otherwise both are narrowed along dim 0 into chunks of at most
//    max(max_worksize / per-sample elements, 1) samples, with one call per
//    chunk (`weight` is passed through unchanged);
//  - if even one such chunk does not fit, an internal assert fires.
template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32,
    int64_t max_worksize, func_t func_32bit) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t input_numel = input.numel();
  const int64_t output_numel = output.numel();

  // Shape is assumed to be (N, C, D1, D2, ...). If N * C * D1 * ... fits for
  // both tensors, no split is needed at all.
  if (input_numel <= int_max && output_numel <= int_max) {
    func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    return;
  }

  // Otherwise, if C * D1 * D2 * ... fits, splitting along N is enough. The
  // chunk size is a simple heuristic: we deliberately stay well below the
  // 2^31 limit because a chunk that large would very likely OOM.
  const int64_t batch_size = output.size(0);
  const int64_t per_sample_numel = std::max<int64_t>(input_numel, output_numel) / batch_size;
  const int64_t chunk_size = std::max<int64_t>(max_worksize / per_sample_numel, 1L);

  if (chunk_size * per_sample_numel >= int_max) {
    // Even splitting N is not enough. Splitting further would require
    // deciding between NCHW and NHWC handling, choosing which of N/C (input
    // or output) to split, treating boundaries carefully when splitting
    // spatial dims, and possibly copying non-contiguous memory. Given that
    // complexity, cuDNN should not be used for this case.
    TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
  }

  const int64_t num_chunks = (batch_size + chunk_size - 1) / chunk_size;
  for (const auto chunk_idx : c10::irange(num_chunks)) {
    const int64_t begin = chunk_size * chunk_idx;
    const int64_t length = std::min<int64_t>(chunk_size, batch_size - begin);
    const Tensor input_chunk = input.narrow(0, begin, length);
    const Tensor output_chunk = output.narrow(0, begin, length);
    func_32bit(output_chunk, input_chunk, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
  }
}
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
// For FP32 convolutions with TF32 disallowed, asserts that the selected math
// type is CUDNN_FMA_MATH. Expects a ConvolutionArgs named `args` in the
// calling scope. Wrapped in do { } while (0) so the expansion is a single
// statement (safe inside an unbraced if/else), and `math_type` is
// parenthesized so an expression argument cannot change precedence.
#define ASSERT_CORRECT_PRECISION(math_type)                                            \
  do {                                                                                 \
    if (args.params.dataType == CUDNN_DATA_FLOAT) {                                    \
      TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || (math_type) == CUDNN_FMA_MATH);  \
    }                                                                                  \
  } while (0)
#else
#define ASSERT_CORRECT_PRECISION(math_type) do { } while (0)
#endif // CUDNN_VERSION >= 8000
// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------
// Runs one cuDNN forward convolution, writing into `output`. It is called by
// split_batch_dim_to_32bit_out with tensors whose element counts fit in
// 32-bit indexing. Candidate algorithms are tried through AlgoIterator until
// one succeeds.
void raw_cudnn_convolution_forward_out_32bit(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, output, weight };
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);

  // TODO: when we do legacy group convolution support, we'll repeatedly
  // reinitialize the workspace for each convolution we do. This is
  // wasteful; we'd rather reuse the workspace. OTOH, legacy group
  // convolution support is already pretty slow, so this might not
  // matter. (This applies to raw_cudnn_convolution_backward_input as well.)
  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark).try_all(
    [&](const cudnnConvolutionFwdAlgoPerf_t &fwdAlgPerf){
      Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

      // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out
      // whether to use Tensor core kernels or not
      // See Note [behavior of cudnnFind and cudnnGet]
      ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
      AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType), args);

      Constant one(dataType, 1);
      Constant zero(dataType, 0);

      // `output` is the tensor cuDNN writes to, so request a mutable pointer
      // (consistent with the backward-data path's grad_input).
      AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionForward(
          args.handle,
          &one, args.idesc.desc(), input.data_ptr(),
          args.wdesc.desc(), weight.data_ptr(),
          args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory,
          &zero, args.odesc.desc(), output.mutable_data_ptr()),
        args, "Forward algorithm: ", static_cast<int>(fwdAlgPerf.algo), "\n");
    }
  );
}
// Entry point for the cuDNN v7 forward convolution. It defers to
// split_batch_dim_to_32bit_out so that each cuDNN call stays within 32-bit
// indexing. Without cuDNN v8 this is the public name; with v8 it is the
// `_v7` variant.
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_forward_out(
#else
void raw_cudnn_convolution_forward_out_v7(
#endif
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  // Element budget divided by per-sample size to pick the batch chunk size.
  constexpr int64_t kMaxWorksize = 1024 * 1024 * 256;
  split_batch_dim_to_32bit_out(
      output, input, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32,
      kMaxWorksize, raw_cudnn_convolution_forward_out_32bit);
}
// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------
// Computes grad_input with cudnnConvolutionBackwardData for tensors whose
// element counts fit in 32-bit indexing. Candidate algorithms are tried
// through AlgoIterator until one succeeds.
void raw_cudnn_convolution_backward_input_out_32bit(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  const auto dataType = getCudnnDataType(grad_output);

  // Descriptor mapping: idesc <- grad_input, odesc <- grad_output, wdesc <- weight.
  ConvolutionArgs args{ grad_input, grad_output, weight };
  args.handle = getCudnnHandle();
  const at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(grad_input, weight);
  setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format);
  args.idesc.set(grad_input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  const auto spatial_dims = grad_output.dim() - 2;
  args.cdesc.set(dataType, spatial_dims, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);

  auto run_backward_data = [&](const cudnnConvolutionBwdDataAlgoPerf_t &perf) {
    Tensor workspace = allocate_workspace(perf.memory, grad_output);

    // cuDNN 7.4+ picks Tensor Core kernels from algo + mathType together, so
    // the descriptor must carry the mathType that was chosen with this algo.
    // See Note [behavior of cudnnFind and cudnnGet]
    ASSERT_CORRECT_PRECISION(perf.mathType);
    AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), perf.mathType), args);

    Constant one(dataType, 1);
    Constant zero(dataType, 0);

    AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionBackwardData(
        args.handle,
        &one, args.wdesc.desc(), weight.data_ptr(),
        args.odesc.desc(), grad_output.data_ptr(),
        args.cdesc.desc(), perf.algo, workspace.data_ptr(), perf.memory,
        &zero, args.idesc.desc(), grad_input.mutable_data_ptr()),
      args,
      "Additional pointer addresses: \n",
      " grad_output: ", grad_output.data_ptr(), "\n",
      " grad_input: ", grad_input.mutable_data_ptr(), "\n",
      "Backward data algorithm: ", static_cast<int>(perf.algo), "\n");
  };

  AlgoIterator<cudnnConvolutionBwdDataAlgoPerf_t>(args, benchmark).try_all(run_backward_data);
}
// Entry point for the cuDNN v7 backward-data (grad_input) convolution. It
// defers to split_batch_dim_to_32bit_out so that each cuDNN call stays within
// 32-bit indexing. Without cuDNN v8 this is the public name; with v8 it is
// the `_v7` variant.
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_backward_input_out(
#else
void raw_cudnn_convolution_backward_input_out_v7(
#endif
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  // Element budget divided by per-sample size to pick the batch chunk size.
  constexpr int64_t kMaxWorksize = 1024 * 1024 * 128;
  split_batch_dim_to_32bit_out(
      grad_input, grad_output, weight,
      padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32,
      kMaxWorksize, raw_cudnn_convolution_backward_input_out_32bit);
}
// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------
// Computes grad_weight with cudnnConvolutionBackwardFilter for tensors whose
// element counts fit in 32-bit indexing. Candidate algorithms are tried
// through AlgoIterator until one succeeds. `grad_weight` is overwritten
// (beta = 0).
void raw_cudnn_convolution_backward_weight_out_32bit(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, grad_output, grad_weight };
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, grad_weight);
  setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(grad_weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);

  AlgoIterator<cudnnConvolutionBwdFilterAlgoPerf_t>(args, benchmark).try_all(
    [&](const cudnnConvolutionBwdFilterAlgoPerf_t &bwdFilterAlgPerf){
      Tensor workspace = allocate_workspace(bwdFilterAlgPerf.memory, input);

      // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out
      // whether to use Tensor core kernels or not
      // See Note [behavior of cudnnFind and cudnnGet]
      ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType);
      AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType), args);

      Constant one(dataType, 1);
      Constant zero(dataType, 0);

      // `grad_weight` is the tensor cuDNN writes to, so request a mutable
      // pointer (consistent with the backward-data path's grad_input).
      AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionBackwardFilter(
          args.handle,
          &one, args.idesc.desc(), input.data_ptr(),
          args.odesc.desc(), grad_output.data_ptr(),
          args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data_ptr(), bwdFilterAlgPerf.memory,
          &zero, args.wdesc.desc(), grad_weight.mutable_data_ptr()),
        args,
        "Additional pointer addresses: \n",
        " grad_output: ", grad_output.data_ptr(), "\n",
        " grad_weight: ", grad_weight.mutable_data_ptr(), "\n",
        "Backward filter algorithm: ", static_cast<int>(bwdFilterAlgPerf.algo), "\n");
    }
  );
}
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_backward_weight_out(
#else
void raw_cudnn_convolution_backward_weight_out_v7(
#endif
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  // Computes the convolution weight gradient into `grad_weight`.
  // raw_cudnn_convolution_backward_weight_out_32bit requires the element counts
  // of `input` and `grad_output` to fit in a 32-bit int. When they do not, the
  // batch (dim 0) is split into chunks that do, and the per-chunk gradients
  // are summed.
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = grad_output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    raw_cudnn_convolution_backward_weight_out_32bit(grad_weight, grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
    return;
  }
  // else, if C * D1 * D2 * ... <= int_max, then we just need to split across the N dimension
  //
  // Here we use a simple heuristic to determine the size of each split.
  // We don't max out the 2^31 address space because this number is super
  // large and very likely to get an OOM.
  const int64_t n = grad_output.size(0);
  // n > 0 here: with n == 0 both numels are 0 and the early return above fires.
  const int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  const int64_t split_size = std::max<int64_t>(1024 * 1024 * 512 / max_inner_size, int64_t{1});
  const int64_t num_splits = (n + split_size - 1) / split_size;
  // Use the same `<=` bound as the 32-bit eligibility check above; a chunk of
  // exactly int_max elements is accepted by the _32bit path.
  if (split_size * max_inner_size <= int_max) {
    // Sum the per-chunk gradients in float for half/bfloat16 weights.
    const auto kAccType = (grad_weight.scalar_type() == kHalf || grad_weight.scalar_type() == kBFloat16)
                          ? kFloat : grad_weight.scalar_type();
    Tensor grad_weight_accumulator = at::zeros(grad_weight.sizes(), grad_weight.options().dtype(kAccType));
    // Scratch buffer for one chunk's gradient. It is allocated once and reused:
    // the _32bit helper writes it with beta = 0, so every call fully overwrites it.
    Tensor grad_weight_ = at::empty_like(grad_weight);
    for (const auto i : c10::irange(num_splits)) {
      const int64_t start = split_size * i;
      const int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor grad_output_ = grad_output.narrow(0, start, split_size_);
      raw_cudnn_convolution_backward_weight_out_32bit(grad_weight_, grad_output_, input_, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
      grad_weight_accumulator.add_(grad_weight_);
    }
    grad_weight.copy_(grad_weight_accumulator);
    return;
  }
  // If control flow reaches here, even splitting N is not enough, and things start to become complicated.
  // For example, for conv2d, the following questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H, and we need to be very careful
  //   to make sure that the boundary is handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need to copy the memory?
  // Considering the complexity of this issue, it is better not to use cuDNN for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_add_relu_out(
#else
void raw_cudnn_convolution_add_relu_out_v7(
#endif
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& z,
float alpha,
const Tensor& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
int64_t groups,
bool benchmark,
bool deterministic,
bool allow_tf32) {
auto dataType = getCudnnDataType(input);
ConvolutionArgs args{input, output, weight};
args.handle = getCudnnHandle();
at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(input, weight);
setConvolutionParams(
&args.params,
input,
weight,
padding,
stride,
dilation,
groups,
deterministic,
allow_tf32,
memory_format);
args.idesc.set(input, memory_format);
args.wdesc.set(weight, memory_format, 0);
args.odesc.set(output, memory_format);
args.cdesc.set(
dataType,
input.dim() - 2,
args.params.padding,
args.params.stride,
args.params.dilation,
args.params.groups,
args.params.allow_tf32);
TensorDescriptor zdesc;
zdesc.set(z, memory_format);
TensorDescriptor bdesc;
bdesc.set(bias.expand({1, bias.size(0)}), memory_format, output.dim());
ActivationDescriptor adesc;
adesc.set(CUDNN_ACTIVATION_RELU);
AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
.try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);
// update convDesc mathType since cudnn 7.4+ now requires both algo +
// mathType to figure out whether to use Tensor core kernels or not See
// Note [behavior of cudnnFind and cudnnGet]
ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
AT_CUDNN_CHECK_WITH_SHAPES(
cudnnSetConvolutionMathType(
args.cdesc.mut_desc(), fwdAlgPerf.mathType),
args);
Constant one(dataType, 1);
Constant alpha_(dataType, alpha);
AT_CUDNN_CHECK_WITH_SHAPES(
cudnnConvolutionBiasActivationForward(
args.handle,
&one,
args.idesc.desc(),
input.data_ptr(),
args.wdesc.desc(),
weight.data_ptr(),
args.cdesc.desc(),
fwdAlgPerf.algo,
workspace.data_ptr(),
fwdAlgPerf.memory,
&alpha_,
zdesc.desc(),
z.data_ptr(),
bdesc.desc(),
bias.data_ptr(),
adesc.desc(),
args.odesc.desc(),
output.data_ptr()),
args,
"zdesc: ", zdesc,
"bdesc: ", bdesc,
"cudnnConvolutionBiasActivationForward: ",
static_cast<int>(fwdAlgPerf.algo),
"\n");
});
}
void raw_cudnn_convolution_add_relu_fallback_out(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& z,
float alpha,
const Tensor& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
int64_t groups,
bool benchmark,
bool deterministic,
bool allow_tf32) {
// cuDNN Conv-Bias-Activation:
// y = act ( alpha1 * conv(x) + alpha2 * z + bias )
// In pytorch function `raw_cudnn_convolution_add_relu_out`: alpha1 is 1, alpha 2 is `float alpha`
raw_cudnn_convolution_forward_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input.dim(), bias).add(z, alpha);
output.add_(alpha_mul_z_add_bias);
output.relu_();
}
}} // namespace at::native
#endif