forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConv.cpp
1270 lines (1108 loc) · 46.3 KB
/
Conv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/Exceptions.h>
#if !AT_CUDNN_ENABLED()
namespace at { namespace native {

// See Note [ATen preprocessor philosophy]
//
// This translation unit was built without cuDNN. Every entry point below only
// exists so that the symbols resolve; each one unconditionally raises an error
// that names the entry point the caller actually invoked.

at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_bias(
    const at::Tensor& grad_output) {
  AT_ERROR("cudnn_convolution_backward_bias: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_input(
    const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  // The message names this exact entry point so the error points at the right op.
  AT_ERROR("cudnn_convolution_transpose_backward_input: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}

}}
#else // AT_CUDNN_ENABLED
#include <THC/THC.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/TensorUtils.h>
#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>
// Note [behavior of cudnnFind and cudnnGet]
// You'll notice that by default, in the ConvolutionDescriptor, we do the following:
//
// AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
// if(dataType == CUDNN_DATA_HALF)
// AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
//
// When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it informs
// cudnnGet/cudnnFind to iterate/take into account both tensor core and non-tensor-core algos.
// If you don't call cudnnSetConvolutionMathType before calling cudnnGet/cudnnFind,
// cudnnGet/cudnnFind may not pick tensor core algos.
//
// Once it has run, cudnnGet/cudnnFind comes up with the best pair of algo+mathType
// given all the initial knowledge it has. It then becomes the user's responsibility
// to update the mathType of the convolution descriptor and to make the subsequent cudnn
// calls with the best algo and the updated descriptor. If we don't update the descriptor
// but just run with the best algo, under the hood cudnn will run the slower kernel,
// since it sees the fastest algorithm combined with a suboptimal mathType.
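// Note [blacklist fft algorithms for strided dgrad]
// This is a workaround for a CuDNN bug that gave wrong results in certain strided convolution
// gradient setups; see Issue #16610 for details. The bug is present in CuDNN versions < 7.5,
// so on those versions the FFT and FFT_TILING backward-data (dgrad) algorithms are replaced
// by the default algorithm whenever any stride is not 1.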
namespace at { namespace native {
// TODO: Go through all the checking code again and make sure
// we haven't missed anything.
// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------
// Positions of the batch and channel dimensions in input/output tensors
// (the same indices apply to grad_input/grad_output).
constexpr int input_batch_size_dim{0};
constexpr int input_channels_dim{1};
constexpr int output_batch_size_dim{0};
constexpr int output_channels_dim{1};
// Positions of the output- and input-channel dimensions in the weight tensor.
constexpr int weight_output_channels_dim{0};
constexpr int weight_input_channels_dim{1};
// Maximum number of spatial dimensions. Full tensor rank is therefore at most
// 2 + max_dim (the two extra dims being batch size and channels).
constexpr int max_dim{3};
// NB: conv_output_size and conv_input_size are not bijections. conv_output_size
// loses information (the integer division by the stride truncates), which is
// why conv_input_size takes an extra output_padding argument to resolve the
// ambiguity.
//
// Expects input_size.size() > 2 and input_size.size() == weight_size.size().
// `groups` does not influence the output shape and is ignored.
static std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  const size_t ndim = input_size.size();
  std::vector<int64_t> result(ndim);
  result[0] = input_size[input_batch_size_dim];
  result[1] = weight_size[weight_output_channels_dim];
  for (size_t d = 2; d < ndim; ++d) {
    const size_t s = d - 2;  // index into the per-spatial-dim parameter lists
    // Span covered by the dilated kernel.
    const auto effective_kernel = dilation[s] * (weight_size[d] - 1) + 1;
    result[d] = (input_size[d] + 2 * padding[s] - effective_kernel) / stride[s] + 1;
  }
  return result;
}
// Inverse of conv_output_size: computes the input shape of a convolution
// whose output shape is `output_size`. `output_padding` disambiguates the
// sizes that conv_output_size maps to the same output.
//
// Expects output_size.size() > 2 and output_size.size() == weight_size.size().
std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  const size_t ndim = output_size.size();
  std::vector<int64_t> input_size(ndim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (size_t d = 2; d < ndim; ++d) {
    // Keep the dilated kernel extent in int64_t. Narrowing it to int (as a plain
    // `int` local would) could overflow for large dilation * kernel sizes.
    const int64_t kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
                    kernel + output_padding[d - 2];
  }
  return input_size;
}
// Recovers the weight shape from matching input and output shapes by solving
// the conv_input_size relation for the dilated kernel extent, then undoing
// the dilation.
std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  const size_t ndim = input_size.size();
  std::vector<int64_t> weight_size(ndim);
  weight_size[weight_output_channels_dim] = output_size[output_channels_dim];
  weight_size[weight_input_channels_dim] = input_size[input_channels_dim] / groups;
  for (size_t d = 2; d < ndim; ++d) {
    // Dilated kernel extent, kept in int64_t to avoid narrowing to int.
    const int64_t kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
                           + 2 * padding[d - 2] - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
// TODO: Move this into the standard library, with a better name?
//
// Returns the slice of `t` along `dim` belonging to group `group_idx`, when
// that dimension is split into `groups` parts of size t.size(dim) / groups
// (integer division).
Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
  const int64_t per_group = t.size(dim) / groups;
  const int64_t start = group_idx * per_group;
  return t.narrow(dim, start, per_group);
}
// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------
// Note [Legacy CuDNN grouped convolution support]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// CuDNN earlier than CuDNN 7 does not directly support group
// convolution, so we provide support for it by sequentially
// running a convolution per group with appropriately
// adjusted sizes. https://blog.yani.io/filter-group-tutorial/
// has a fairly good diagram explaining how it works.
// Used on pad, stride and dilation: checks that `args` has exactly
// `expected_size` entries and that none of them is negative. Errors mention
// `arg_name` and the checking context `c`.
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");

  // Compare and print the values as int64_t. Narrowing them to int first can
  // flip the sign of large values (e.g. 2^31 would be reported as negative,
  // -2^32 would be accepted) and prints truncated numbers.
  const bool has_negative =
      std::any_of(args.begin(), args.end(), [](int64_t x) { return x < 0; });
  if (has_negative) {
    std::stringstream ss;
    ss << arg_name << " should be greater than zero but got (";
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
// output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies. The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  // padding, stride and dilation need one entry per spatial dimension (every
  // input dim except batch and channels). Once the padding check has passed,
  // padding.size() equals this count, so the same value serves all three.
  const size_t spatial_dims = input->dim() - 2;
  check_args(c, padding, spatial_dims, "padding");
  check_args(c, stride, spatial_dims, "stride");
  check_args(c, dilation, spatial_dims, "dilation");

  // Input: between 3 and 5 dims in total. Its channel count must equal the
  // weight's dim-1 size times the number of groups.
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight and output must have the same rank as the input.
  checkSameDim(c, input, weight);
  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}
// This POD struct is used to let us easily compute hashes of the
// parameters
// It is the key type of BenchmarkCache (hashed with ParamsHash and compared
// with ParamsEqual). setConvolutionParams memsets it to zero before filling it,
// so array slots beyond the tensor rank are zero. Presumably this also keeps
// padding bytes consistent for hashing; confirm against ParamsHash.
struct ConvolutionParams
{
// cuDNN element type derived from the input tensor.
cudnnDataType_t dataType;
// Sizes/strides of the input and sizes of the weight (batch/channel dims plus
// up to max_dim spatial dims), narrowed to int.
int input_size[2 + max_dim];
int input_stride[2 + max_dim];
int weight_size[2 + max_dim];
// One entry per spatial dimension.
int padding[max_dim];
int stride[max_dim];
int dilation[max_dim];
int64_t groups;
// Part of the key, so deterministic and non-deterministic requests are
// cached separately.
bool deterministic;
// NB: transposed purposely omitted: transposed just swaps
// forward and backward, so you can reuse the benchmark entry.
};
// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need. (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
//
// Fills *params (the benchmark-cache key) from the convolution's tensors and
// hyper-parameters. None of the following is checked here:
//   - weight.dim() == input.dim() <= 2 + max_dim
//   - padding, stride and dilation all have the same length (<= max_dim)
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic) {
  // Resolve the cuDNN data type before clearing *params. If this lookup
  // throws, *params is left untouched.
  const cudnnDataType_t data_type = getCudnnDataType(input);

  // Start from all-zero bytes so that unused slots are well defined in the key.
  memset(params, 0, sizeof(ConvolutionParams));
  params->dataType = data_type;

  const int64_t rank = input.dim();
  for (int64_t dim = 0; dim < rank; ++dim) {
    params->input_size[dim] = static_cast<int>(input.size(dim));
    params->input_stride[dim] = static_cast<int>(input.stride(dim));
    params->weight_size[dim] = static_cast<int>(weight.size(dim));
  }

  const size_t spatial = padding.size();
  for (size_t k = 0; k < spatial; ++k) {
    params->padding[k] = static_cast<int>(padding[k]);
    params->stride[k] = static_cast<int>(stride[k]);
    params->dilation[k] = static_cast<int>(dilation[k]);
  }

  // In principle, we shouldn't parametrize by groups for legacy
  // CuDNN, but it doesn't seem worth the effort to actually do this.
  params->groups = groups;
  params->deterministic = deterministic;
}
// Convenience struct for passing around descriptors and data
// pointers
struct ConvolutionArgs {
  cudnnHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  // Only `input` is held by reference. `output` and `weight` are held by value
  // (copying an at::Tensor copies a handle, not the data). These are declared
  // one per line because a combined declarator list such as
  // `const Tensor& a, b, c;` silently makes only the first member a reference.
  const Tensor& input;
  const Tensor output;
  const Tensor weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight)
      : input(input), output(output), weight(weight) {
  }
};
// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------
// TODO: Use something less heavy duty than a big honking mutex
//
// Mutex-guarded map from ConvolutionParams to a cached result of type T
// (a cuDNN perf struct). Entries are copied in and out under the lock.
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash<ConvolutionParams>, ParamsEqual<ConvolutionParams>> map;

  // On a hit, copies the cached value into *results and returns true.
  // On a miss, leaves *results untouched and returns false.
  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> lock(mutex);
    const auto entry = map.find(params);
    const bool hit = entry != map.end();
    if (hit) {
      *results = entry->second;
    }
    return hit;
  }

  // Stores `results` for `params`, replacing any existing entry.
  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> lock(mutex);
    map[params] = results;
  }
};
// Namespace-scope caches of chosen algorithms, one per convolution direction,
// keyed by ConvolutionParams. They are reached through the matching
// algorithm_search<perf_t>::cache() accessors.
BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
//
// Move-only owner of a `size`-byte buffer obtained from THCudaMalloc and
// released with THCudaFree when the owner is destroyed or reassigned.
struct Workspace {
  Workspace(size_t size) : size(size), data(nullptr) {
    data = THCudaMalloc(globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace& operator=(const Workspace&) = delete;

  // A defaulted move would copy the raw pointer and leave the source still
  // owning it, so both destructors would free the same buffer. Transfer
  // ownership explicitly and leave the source empty.
  Workspace(Workspace&& other) noexcept : size(other.size), data(other.data) {
    other.size = 0;
    other.data = nullptr;
  }
  Workspace& operator=(Workspace&& other) noexcept {
    if (this != &other) {
      release();  // don't leak the buffer we currently own
      size = other.size;
      data = other.data;
      other.size = 0;
      other.data = nullptr;
    }
    return *this;
  }

  ~Workspace() {
    release();
  }

  size_t size;
  void* data;

 private:
  // Frees the owned buffer (if any) and marks this object empty.
  void release() {
    if (data) {
      THCudaFree(globalContext().lazyInitCUDA(), data);
      data = nullptr;
    }
  }
};
// Per-direction algorithm-search policy, selected by cuDNN perf struct type.
// The primary template is empty. The explicit specializations for the forward,
// backward-data and backward-filter perf types provide DEFAULT_ALGO, cache(),
// findAlgorithm() and getWorkspaceSize().
template<typename perf_t>
struct algorithm_search {
};
// Asks cuDNN how many workspace bytes forward algorithm `algo` needs for the
// descriptors in `args`. The result is written to *sz on success. The cuDNN
// status is returned to the caller rather than checked here.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionFwdAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionForwardWorkspaceSize(
      args.handle,
      args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.odesc.desc(),
      algo, sz);
  return status;
}
// Asks cuDNN how many workspace bytes backward-data algorithm `algo` needs for
// the descriptors in `args`. The result is written to *sz on success. The
// cuDNN status is returned to the caller rather than checked here.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdDataAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle,
      args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.idesc.desc(),
      algo, sz);
  return status;
}
// Asks cuDNN how many workspace bytes backward-filter algorithm `algo` needs
// for the descriptors in `args`. The result is written to *sz on success. The
// cuDNN status is returned to the caller rather than checked here.
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz)
{
  const cudnnStatus_t status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle,
      args.idesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.wdesc.desc(),
      algo, sz);
  return status;
}
template<typename algo_t>
size_t getMaxWorkspaceSize(
const ConvolutionArgs& args,
const algo_t *algo, int n_algo)
{
THCState *state = globalContext().lazyInitCUDA();
size_t max_ws_size = 0;
size_t max_block_size = 0;
size_t total_gpu_mem = 0;
size_t free_gpu_mem = 0;
THCudaCheck(THCudaMemGetInfo(state, &free_gpu_mem, &total_gpu_mem, &max_block_size));
for (int i = 0; i < n_algo; i++) {
cudnnStatus_t err;
size_t sz;
err = getWorkspaceSize(args, algo[i], &sz);
if (CUDNN_STATUS_SUCCESS != err || sz == 0
|| sz < max_ws_size || sz > max_block_size) continue;
max_ws_size = sz;
}
return max_ws_size;
}
template<typename perf_t>
perf_t getBestAlgorithm(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
int best_algo_idx;
bool is_deterministic = false;
if (args.params.deterministic) {
// iterate over perf results of all algorithms and find the best deterministic algo
for (int i = 0; i < n_algo; i++) {
// TODO: Shouldn't all returned results be successful?
// Double check documentation for cudnnFindConvolutionForwardAlgorithmEx
if (perfResults[i].status == CUDNN_STATUS_SUCCESS &&
perfResults[i].determinism == CUDNN_DETERMINISTIC) {
best_algo_idx = i;
is_deterministic = true;
break;
}
}
if (!is_deterministic) {
AT_ERROR("no deterministic convolution algorithms available in CuDNN");
}
} else {
best_algo_idx = 0;
}
// See Note [blacklist fft algorithms for strided dgrad]
#if CUDNN_VERSION < 7500
if (std::is_same<decltype(perfResults[best_algo_idx].algo), cudnnConvolutionBwdDataAlgo_t>::value) {
int stride_dim = args.input.dim() - 2;
bool blacklist = std::any_of(std::begin(args.params.stride),
std::begin(args.params.stride) + stride_dim,
[=](int n){return n != 1;});
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[best_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
|| static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[best_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
perfResults[best_algo_idx].algo = algorithm_search<perf_t>::DEFAULT_ALGO;
if (args.params.dataType == CUDNN_DATA_HALF) {
perfResults[best_algo_idx].mathType = CUDNN_TENSOR_OP_MATH;
} else {
perfResults[best_algo_idx].mathType = CUDNN_DEFAULT_MATH;
}
}
}
#endif
return perfResults[best_algo_idx];
}
// Forward-convolution search policy.
template<>
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
using algo_t = cudnnConvolutionFwdAlgo_t;
// Fallback used when no usable result is found, when a deterministic result is
// requested without benchmarking, and after a workspace allocation failure.
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
static BenchmarkCache<perf_t>& cache() { return fwd_algos; }
// Returns the entry chosen by getBestAlgorithm among cuDNN's ranked candidates.
// Without `benchmark`, the candidates come from cudnnGetConvolutionForwardAlgorithm_v7.
// With it, they come from cudnnFindConvolutionForwardAlgorithmEx, which is given
// the real tensor buffers and a scratch workspace sized by getMaxWorkspaceSize.
static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
};
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
// Compile-time guard: fails if cuDNN adds an algorithm this list doesn't cover.
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution forward algorithms");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
if (!benchmark) {
AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
args.handle,
args.idesc.desc(),
args.wdesc.desc(),
args.cdesc.desc(),
args.odesc.desc(),
num_algos,
&perf_count,
perf_results.get()));
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
args.wdesc.desc(), args.weight.data_ptr(),
args.cdesc.desc(),
args.odesc.desc(), args.output.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size));
}
return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
}
// Checked variant of the workspace-size query: raises on a cuDNN error
// instead of returning the status.
static void getWorkspaceSize(
const ConvolutionArgs& args,
algo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
args.handle,
args.idesc.desc(),
args.wdesc.desc(),
args.cdesc.desc(),
args.odesc.desc(),
algo,
workspaceSize));
}
};
// Backward-data (grad wrt input) search policy. Here args.input plays the
// role of grad_input and args.output the role of grad_output.
template<>
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
using algo_t = cudnnConvolutionBwdDataAlgo_t;
// Fallback used when no usable result is found, when a deterministic result is
// requested without benchmarking, after a workspace allocation failure, and
// (on cuDNN < 7.5) in place of blacklisted FFT algorithms.
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }
// Returns the entry chosen by getBestAlgorithm among cuDNN's ranked candidates.
// Without `benchmark`, the candidates come from cudnnGetConvolutionBackwardDataAlgorithm_v7.
// With it, they come from cudnnFindConvolutionBackwardDataAlgorithmEx, which is
// given the real buffers and a scratch workspace sized by getMaxWorkspaceSize.
static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
};
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
// Compile-time guard: fails if cuDNN adds an algorithm this list doesn't cover.
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution backward data algorithms.");
int perf_count;
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
if (!benchmark) {
AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
args.handle,
args.wdesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.idesc.desc(),
num_algos,
&perf_count,
perf_results.get()));
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
args.handle,
args.wdesc.desc(), args.weight.data_ptr(),
args.odesc.desc(), args.output.data_ptr(),
args.cdesc.desc(),
args.idesc.desc(), args.input.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size));
}
return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
}
// Checked variant of the workspace-size query: raises on a cuDNN error
// instead of returning the status.
static void getWorkspaceSize(
const ConvolutionArgs& args,
cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle,
args.wdesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.idesc.desc(),
algo,
workspaceSize));
}
};
// Backward-filter (grad wrt weight) search policy. Here args.weight plays the
// role of grad_weight and args.output the role of grad_output.
template<>
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
using algo_t = cudnnConvolutionBwdFilterAlgo_t;
// Fallback used when no usable result is found, when a deterministic result is
// requested without benchmarking, and after a workspace allocation failure.
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }
// Returns the entry chosen by getBestAlgorithm among cuDNN's ranked candidates.
// Without `benchmark`, the candidates come from cudnnGetConvolutionBackwardFilterAlgorithm_v7.
// With it, they come from cudnnFindConvolutionBackwardFilterAlgorithmEx, which is
// given the real buffers and a scratch workspace sized by getMaxWorkspaceSize.
static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
static const algo_t algos[] = {
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
};
// NOTE: - 1 because ALGO_WINOGRAD is not implemented
// (CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD is deliberately absent from the list above).
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
"Missing cuDNN convolution backward filter algorithms.");
std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
int perf_count;
if (!benchmark) {
AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
args.handle,
args.idesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.wdesc.desc(),
num_algos,
&perf_count,
perf_results.get()));
} else {
size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
Workspace ws(max_ws_size);
AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
args.handle,
args.idesc.desc(), args.input.data_ptr(),
args.odesc.desc(), args.output.data_ptr(),
args.cdesc.desc(),
args.wdesc.desc(), args.weight.data_ptr(),
num_algos,
&perf_count,
perf_results.get(),
ws.data,
ws.size));
}
return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
}
// Checked variant of the workspace-size query: raises on a cuDNN error
// instead of returning the status.
static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize)
{
AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle,
args.idesc.desc(),
args.odesc.desc(),
args.cdesc.desc(),
args.wdesc.desc(),
algo,
workspaceSize));
}
};
// Overwrites the algo, mathType and memory fields of *algoPerf as follows:
//   - algo: the direction's fallback algorithm (algorithm_search<perf_t>::DEFAULT_ALGO);
//   - mathType: tensor-op math for half data, default math otherwise;
//   - memory: the workspace size that algorithm needs for `args`.
// Other fields of *algoPerf are left unchanged.
template<typename perf_t>
static void setDefaultAlgorithmPerf(const ConvolutionArgs& args, perf_t* algoPerf) {
  using search = algorithm_search<perf_t>;
  algoPerf->algo = search::DEFAULT_ALGO;
  if (args.params.dataType == CUDNN_DATA_HALF) {
    algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
  } else {
    algoPerf->mathType = CUDNN_DEFAULT_MATH;
  }
  search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
}

// Fills *algoPerf with the algorithm to use for `args`, trying in order:
//   1. a cached result for args.params;
//   2. the default algorithm when determinism is requested without benchmarking;
//   3. cuDNN's heuristic query or benchmark (see algorithm_search::findAlgorithm).
// Only benchmarked results are inserted into the cache. If the search result did
// not succeed, or is non-deterministic while determinism was requested, the
// default algorithm is used instead (and not cached).
template<typename perf_t>
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
  using search = algorithm_search<perf_t>;
  auto& cache = search::cache();

  if (cache.find(args.params, algoPerf)) {
    return;
  }

  if (args.params.deterministic && !benchmark) {
    setDefaultAlgorithmPerf(args, algoPerf);
    return;
  }

  if (benchmark) {
    if (cache.find(args.params, algoPerf)) {
      // re-check cache since another thread may have benchmarked the algorithm
      // (best effort only: no lock is held across the check and the benchmark).
      return;
    }
  }

  auto perfResults = search::findAlgorithm(args, benchmark);
  // for deterministic algo, look at all the perf results and return the best
  // deterministic algo
  const bool usable = perfResults.status == CUDNN_STATUS_SUCCESS &&
      !(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC);
  if (!usable) {
    setDefaultAlgorithmPerf(args, algoPerf);
    return;
  }

  // if benchmarking, map the original params with the found algo+math type for re-use
  if (benchmark) {
    cache.insert(args.params, perfResults);
    // Free the cached blocks in our caching allocator. They are
    // needed here because the above benchmarking uses a huge amount of memory,
    // e.g. a few GBs.
    c10::cuda::CUDACachingAllocator::emptyCache();
  }
  *algoPerf = perfResults;
}

// Chooses the algorithm for `args` (via findAlgorithm) and allocates its
// workspace. If that allocation throws (presumably out of memory), the steps are:
//   1. clear the CUDA error state;
//   2. switch to the default algorithm;
//   3. record that choice in the cache so later calls don't retry the failing
//      algorithm;
//   4. allocate the default algorithm's workspace instead.
template<typename perf_t>
Workspace chooseAlgorithm(
    const ConvolutionArgs& args,
    bool benchmark,
    perf_t* algoPerf)
{
  findAlgorithm(args, benchmark, algoPerf);
  try {
    return Workspace(algoPerf->memory);
  } catch (const std::exception&) {
    cudaGetLastError(); // clear OOM error
    // switch to default algorithm and record it in the cache to prevent
    // further OOM errors
    setDefaultAlgorithmPerf(args, algoPerf);
    algorithm_search<perf_t>::cache().insert(args.params, *algoPerf);
    return Workspace(algoPerf->memory);
  }
}
// ---------------------------------------------------------------------
//
// Bias addition
//
// ---------------------------------------------------------------------
// In-place!
// Adds `bias` (one value per output channel, checked to have size
// output->size(output_channels_dim)) into `output` using cudnnAddTensor with
// both scaling factors equal to 1. Type and GPU agreement are checked first.
// An empty output is a no-op.
void cudnn_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const TensorArg& bias)
{
  checkAllSameType(c, {output, bias});
  checkAllSameGPU(c, {output, bias});
  checkSize(c, bias, { output->size(output_channels_dim) });

  if (output.tensor.numel() == 0) {
    return;  // nothing to add into
  }

  // See Note [CuDNN broadcast padding]. Handle the left padding
  // ourselves, but use TensorDescriptor's padding argument to do the rest.
  TensorDescriptor bias_desc;
  TensorDescriptor out_desc;
  bias_desc.set(bias->expand({1, bias->size(0)}), output->dim());
  out_desc.set(*output);

  auto cudnn_handle = getCudnnHandle();
  Constant unit(getCudnnDataType(*bias), 1);
  AT_CUDNN_CHECK(cudnnAddTensor(cudnn_handle,
                                &unit, bias_desc.desc(), bias->data_ptr(),
                                &unit, out_desc.desc(), output->data_ptr()));
}
// NOTE [ Convolution design ]
//
// The general strategy:
//
// - cudnn_convolution (Tensor)
// Entry points for clients, takes bias
//
// - cudnn_convolution_forward (TensorArg)
// Entry point, which may be reused between regular
// convolution and transposed convolution. Does NOT take bias.
//
// - raw_cudnn_convolution_forward_out (Tensor)
// Low level function which invokes CuDNN, and takes an output
// tensor which is directly written to (thus _out).
//
// Where does argument checking happen? Here's the division of
// responsibility:
// - Things that happen in at::Tensor
// - TensorArg allocation
// - setCuDNNStreamToCurrent
// - Things that happen in TensorArg
// - Check arguments (type, GPU, shape)
//
// TODO: Consider renaming zero-indexed arguments to "self"
// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------
// The raw API directly invokes cuDNN and does not emulate support
// for group convolution on old versions of cuDNN: `groups` is passed
// straight through to the convolution descriptor.
//
// There are a few reasons this should never be directly exposed
// via ATen:
//
//   - It takes output as a parameter (this should be computed!)
//   - It doesn't do input checking
//   - It doesn't resize output (it is assumed to be correctly sized)
//
// Parameters:
//   output         - preallocated destination. It is overwritten, not
//                    accumulated into: beta = 0 is passed to
//                    cudnnConvolutionForward.
//   input, weight  - convolution operands. Descriptors are built from them
//                    as given; no contiguity or layout changes happen here.
//   padding, stride, dilation, groups
//                  - forwarded to setConvolutionParams. The normalized copies
//                    stored in args.params configure the convolution
//                    descriptor.
//   benchmark      - forwarded to chooseAlgorithm.
//   deterministic  - forwarded to setConvolutionParams.
// Errors from cuDNN calls surface through AT_CUDNN_CHECK.
void raw_cudnn_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
// cuDNN data type derived from `input`. It is passed to the convolution
// descriptor and used for the alpha/beta scaling constants below.
auto dataType = getCudnnDataType(input);
ConvolutionArgs args{ input, output, weight };
args.handle = getCudnnHandle();
setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(weight);
args.odesc.set(output);
// `input.dim() - 2` is the number of convolved dimensions.
// NOTE(review): assumes the two leading input dims are batch and channel — confirm.
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
// TODO: when we do legacy group convolution support, we'll repeatedly
// reinitialize the workspace for each convolution we do. This is
// wasteful; we'd rather reuse the workspace. OTOH, legacy group
// convolution support is already pretty slow, so this might not
// matter. (This applies to raw_cudnn_convolution_backward_input_out as well.)
// chooseAlgorithm fills fwdAlgPerf; its .algo and .mathType are used below.
// It also returns the workspace buffer (data/size) passed to the forward call.
cudnnConvolutionFwdAlgoPerf_t fwdAlgPerf;
Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlgPerf);
// Update the convolution descriptor's math type after the algorithm has been
// chosen: cuDNN 7.4+ requires both algo and mathType to decide whether to use
// Tensor Core kernels.
// See Note [behavior of cudnnFind and cudnnGet]
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));
// y = alpha * conv(x, w) + beta * y with alpha = 1, beta = 0,
// i.e. the previous contents of `output` are discarded.
Constant one(dataType, 1);
Constant zero(dataType, 0);
AT_CUDNN_CHECK(cudnnConvolutionForward(
args.handle,
&one, args.idesc.desc(), input.data_ptr(),
args.wdesc.desc(), weight.data_ptr(),
args.cdesc.desc(), fwdAlgPerf.algo, workspace.data, workspace.size,
&zero, args.odesc.desc(), output.data_ptr()));
}
// Allocates the result of convolving `input` with `weight` and fills it via
// raw_cudnn_convolution_forward_out. Shared by regular convolution and by
// the backward pass of transposed convolution. Bias is NOT handled here.
//
// The type and GPU checks always run. If the computed output is empty, it is
// returned right after allocation, before the shape check and without
// invoking cuDNN.
Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  Tensor result = at::empty(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation, groups),
      input->options());

  if (result.numel() == 0) {
    return result;
  }

  // Labelled "result" rather than "output": when this routine is used as a
  // backward pass, "output" would be ambiguous in error messages.
  TensorArg result_arg{ result, "result", 0 };
  convolution_shape_check(c, input, weight, result_arg, padding, stride, dilation, groups);

  // The weight handed to cuDNN is made contiguous first (see #4500).
  const Tensor contiguous_weight = weight->contiguous();

  raw_cudnn_convolution_forward_out(
      result, *input, contiguous_weight,
      padding, stride, dilation, groups, benchmark, deterministic);

  return result;
}
// Client-facing cuDNN convolution: convolves input_t with weight_t and, when
// bias_t is defined, adds it in place to the result.
// Calls setCuDNNStreamToCurrent() before any cuDNN work is done.
Tensor cudnn_convolution(
    const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg input{ input_t, "input", 1 };
  TensorArg weight{ weight_t, "weight", 2 };
  TensorArg bias{ bias_t, "bias", 3 };
  setCuDNNStreamToCurrent();

  CheckedFrom c = "cudnn_convolution";
  Tensor output_t = cudnn_convolution_forward(
      c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);

  // An undefined bias means "no bias": hand back the raw convolution result.
  if (!bias->defined()) {
    return output_t;
  }

  cudnn_convolution_add_bias_(c, TensorArg{ output_t, "result", 0 }, bias);
  return output_t;
}
// Gradient with respect to the input of a transposed convolution. It is
// computed as a regular forward convolution of grad_output with weight, so
// this simply delegates to cudnn_convolution_forward.
//
// NB: output_padding is not needed here, as there is no ambiguity to
// resolve.
Tensor cudnn_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  const TensorArg grad_output_arg{ grad_output_t, "grad_output", 1 };
  const TensorArg weight_arg{ weight_t, "weight", 2 };
  setCuDNNStreamToCurrent();

  CheckedFrom c = "cudnn_convolution_transpose_backward_input";
  return cudnn_convolution_forward(
      c, grad_output_arg, weight_arg,
      padding, stride, dilation, groups, benchmark, deterministic);
}
// Backward pass of transposed convolution.
//
// Returns (grad_input, grad_weight, grad_bias). Entry i is computed only when
// output_mask[i] is true; otherwise it is an undefined Tensor. All three
// gradients are computed from a contiguous copy of grad_output_t.
// output_padding is accepted but not used by this function.
std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  const Tensor grad_output = grad_output_t.contiguous();

  const bool want_input = output_mask[0];
  const bool want_weight = output_mask[1];
  const bool want_bias = output_mask[2];

  const Tensor grad_input = want_input
      ? at::cudnn_convolution_transpose_backward_input(
            grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic)
      : Tensor();

  const Tensor grad_weight = want_weight
      ? at::cudnn_convolution_transpose_backward_weight(
            weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic)
      : Tensor();

  const Tensor grad_bias = want_bias
      ? at::cudnn_convolution_backward_bias(grad_output)
      : Tensor();

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------
void raw_cudnn_convolution_backward_input_out(
const at::Tensor& grad_input,
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
auto dataType = getCudnnDataType(grad_output);
ConvolutionArgs args{ grad_input, grad_output, weight };
args.handle = getCudnnHandle();
setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic);