forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConvUtils.h
140 lines (119 loc) · 5.28 KB
/
ConvUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#pragma once
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
namespace at { namespace native {
// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------
// Named indices into the size arrays (input / output / weight) used by the
// convolution shape helpers in this header.

// Position of the batch-size entry in an input size array.
constexpr int input_batch_size_dim = 0; // also grad_input
// Position of the channel entry in an input size array.
constexpr int input_channels_dim = 1;
// Position of the batch-size entry in an output size array.
constexpr int output_batch_size_dim = 0; // also grad_output
// Position of the channel entry in an output size array.
constexpr int output_channels_dim = 1;
// Position of the output-channel entry in a weight size array.
constexpr int weight_output_channels_dim = 0;
// Position of the input-channel entry in a weight size array.
constexpr int weight_input_channels_dim = 1;
// Often written as 2 + max_dim (extra dims for batch size and channels)
// NOTE(review): presumably the maximum number of spatial dimensions
// supported; not used in this header -- confirm against callers.
constexpr int max_dim = 3;
// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
// Computes the full output size array of a convolution.
//
// input_size / weight_size hold one entry per tensor dimension (entries 0 and
// 1 are batch/channels, the rest are spatial); padding, stride and dilation
// hold one entry per spatial dimension. An empty `dilation` is treated as a
// dilation of 1 in every spatial dimension.
//
// Returns a vector with the same number of entries as input_size:
//   [0]  = input_size[input_batch_size_dim]
//   [1]  = weight_size[weight_output_channels_dim]
//   [d]  = (input_size[d] + 2 * padding[d-2] - effective_kernel) / stride[d-2] + 1
// where effective_kernel = dilation[d-2] * (weight_size[d] - 1) + 1.
//
// No validation is performed; the caller must pass input_size with more than
// two entries and weight_size of matching length.
static inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  const auto ndim = input_size.size();
  const bool dilated = dilation.size() > 0;

  std::vector<int64_t> result(ndim);
  result[0] = input_size[input_batch_size_dim];
  result[1] = weight_size[weight_output_channels_dim];

  for (const auto axis : c10::irange(2, ndim)) {
    // Per-spatial-dimension arguments are indexed without batch/channels.
    const auto spatial = axis - 2;

    int64_t step = 1;
    if (dilated) {
      step = dilation[spatial];
    }

    // Extent covered by the kernel once dilation gaps are included.
    const int64_t effective_kernel = step * (weight_size[axis] - 1) + 1;
    const int64_t padded_extent = input_size[axis] + (2 * padding[spatial]);
    result[axis] = (padded_extent - effective_kernel) / stride[spatial] + 1;
  }
  return result;
}
// Computes the input size array that corresponds to a given convolution
// output size (the inverse of conv_output_size, with `output_padding`
// resolving the ambiguity introduced by the integer division there).
//
// output_size / weight_size hold one entry per tensor dimension; padding,
// output_padding, stride and dilation hold one entry per spatial dimension.
//
// Returns a vector with the same number of entries as output_size:
//   [0]  = output_size[output_batch_size_dim]
//   [1]  = weight_size[weight_input_channels_dim] * groups
//   [d]  = (output_size[d] - 1) * stride[d-2] - 2 * padding[d-2]
//          + effective_kernel + output_padding[d-2]
// where effective_kernel = dilation[d-2] * (weight_size[d] - 1) + 1.
//
// No validation is performed; the caller must pass output_size with more than
// two entries and weight_size of matching length.
static inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<int64_t> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    // Keep the kernel extent in 64 bits: all operands are int64_t, and storing
    // the result in a plain `int` would silently truncate large extents.
    const int64_t kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
                    kernel + output_padding[d - 2];
  }
  return input_size;
}
// Computes the weight size array implied by a convolution's input and output
// size arrays.
//
// input_size / output_size hold one entry per tensor dimension; padding,
// output_padding, stride and dilation hold one entry per spatial dimension.
//
// Returns a vector with the same number of entries as input_size:
//   [0]  = output_size[1]
//   [1]  = input_size[1] / groups
//   [d]  = (effective_kernel - 1) / dilation[d-2] + 1
// where effective_kernel = input_size[d] - (output_size[d] - 1) * stride[d-2]
//                          + 2 * padding[d-2] - output_padding[d-2].
//
// No validation is performed on the arguments.
static inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<int64_t> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (const auto d : c10::irange(2, dim)) {
    // Keep the kernel extent in 64 bits: all operands are int64_t, and storing
    // the result in a plain `int` would silently truncate large extents.
    const int64_t kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
                           + 2 * padding[d - 2] - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
// Reshapes `bias` to a `dim`-entry shape that is 1 everywhere except
// position 1, which is set to -1 and passed through to Tensor::reshape.
// `dim` must be at least 2; this is not checked.
static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  // Start from all ones, then mark position 1 with -1.
  std::vector<int64_t> target_shape(dim, 1);
  target_shape[1] = -1;
  return bias.reshape(target_shape);
}
// Decides whether a cuDNN convolution on `input` / `weight` should use a
// channels-last memory format.
//
// Returns false when the build has no cuDNN or when either tensor is float64.
// Otherwise returns true if either tensor's suggested memory format is
// ChannelsLast and the cuDNN version is >= 7603, or if either is
// ChannelsLast3d and the cuDNN version is >= 8005.
static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN()) {
    return false;
  }
  if (input.scalar_type() == at::kDouble) {
    return false;
  }
  if (weight.scalar_type() == at::kDouble) {
    return false;
  }

  const long version = at::detail::getCUDAHooks().versionCuDNN();
  const auto input_format = input.suggest_memory_format();
  const auto weight_format = weight.suggest_memory_format();

  const bool any_channels_last_2d =
      (input_format == at::MemoryFormat::ChannelsLast) ||
      (weight_format == at::MemoryFormat::ChannelsLast);
  const bool any_channels_last_3d =
      (input_format == at::MemoryFormat::ChannelsLast3d) ||
      (weight_format == at::MemoryFormat::ChannelsLast3d);

  // NOTE(review): thresholds presumably correspond to cuDNN 7.6.3 and 8.0.5 --
  // confirm against the version encoding returned by versionCuDNN().
  if ((version >= 7603) && any_channels_last_2d) {
    return true;
  }
  return (version >= 8005) && any_channels_last_3d;
}
// Decides whether a MIOpen convolution on `input` / `weight` should use a
// channels-last memory format.
//
// Returns false when the build has no MIOpen or when either tensor is float64.
// When built with USE_ROCM and ROCM_VERSION >= 40300, returns true only if the
// PYTORCH_MIOPEN_SUGGEST_NHWC environment check yields true and either
// tensor's suggested memory format is ChannelsLast. The 3d channels-last case
// is always false here.
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  // disable NHWC for float64 input.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen()) {
    return false;
  }
  if (input.scalar_type() == at::kDouble) {
    return false;
  }
  if (weight.scalar_type() == at::kDouble) {
    return false;
  }

  bool channels_last_2d_ok = false;
#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  // The environment check is evaluated once and cached for later calls.
  static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
  const auto input_format = input.suggest_memory_format();
  const auto weight_format = weight.suggest_memory_format();
  if (PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC) {
    channels_last_2d_ok =
        (input_format == at::MemoryFormat::ChannelsLast) ||
        (weight_format == at::MemoryFormat::ChannelsLast);
  }
#endif
  const bool channels_last_3d_ok = false;
  return channels_last_2d_ok || channels_last_3d_ok;
}
}} // namespace at::native