-
Notifications
You must be signed in to change notification settings - Fork 33
/
doconv_pytorch.py
224 lines (200 loc) · 9.85 KB
/
doconv_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# coding=utf-8
import math
import torch
import numpy as np
from torch.nn import functional as F
from torch._jit_internal import Optional
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch import nn
from torch.nn import init
class DOConv2d(Module):
    """
    DOConv2d can be used as an alternative for torch.nn.Conv2d.

    The interface is similar to that of Conv2d, with one exception:
        1. D_mul: the depth multiplier for the over-parameterization.
    Note that the groups parameter switches between DO-Conv (groups=1),
    DO-DConv (groups=in_channels), DO-GConv (otherwise).

    The convolution kernel is not stored directly; ``forward`` composes it
    from two trainable tensors on every call:

    * ``W`` of shape ``(out_channels, in_channels // groups, D_mul)``;
    * ``D`` of shape ``(in_channels, M * N, D_mul)`` (only when the kernel has
      more than one tap), added to the frozen ``D_diag`` of the same shape.

    ``D`` starts at zero and ``D_diag`` holds identity matrices stacked along
    its last axis (zero-padded when ``D_mul`` is not a multiple of ``M * N``).
    So with ``groups=1`` and the default ``D_mul == M * N``, the composed
    kernel initially equals ``W`` reshaped to ``(out_channels, in_channels, M, N)``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or ``(M, N)`` pair. Default 3.
        D_mul: depth multiplier. Defaults to ``M * N`` and is forced to
            ``M * N`` for a single-tap kernel. NOTE(review): values below
            ``M * N`` leave ``D_diag`` all zeros; the DO-Conv formulation
            presumably expects ``D_mul >= M * N``.
        stride, padding, dilation: int or pair, as in ``nn.Conv2d``.
        groups: number of blocked connections from input to output channels.
        bias: if True, add a learnable bias of shape ``(out_channels,)``.
        padding_mode: one of 'zeros', 'reflect', 'replicate', 'circular'.
        simam: if True, pass the first ``ceil(M / 2)`` rows of the composed
            kernel through ``simam_module``. Requires a kernel height >= 2.

    Raises:
        ValueError: if channels are not divisible by ``groups``,
            ``padding_mode`` is unknown, a pair argument does not have
            length 2, or ``simam=True`` with a kernel height below 2.
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    @staticmethod
    def _to_pair(value, name):
        """Return ``value`` as a 2-tuple, repeating a scalar as nn.Conv2d does."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('{} must be an int or a pair, but got {}'.format(name, value))
            return tuple(value)
        return (value, value)

    def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
        super(DOConv2d, self).__init__()
        kernel_size = self._to_pair(kernel_size, 'kernel_size')
        stride = self._to_pair(stride, 'stride')
        padding = self._to_pair(padding, 'padding')
        dilation = self._to_pair(dilation, 'dilation')
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        if simam and kernel_size[0] < 2:
            # forward() splits the kernel in two along its height (dim 2); a
            # single-row kernel yields only one chunk and cannot be split.
            raise ValueError('simam=True requires a kernel height of at least 2, '
                             'but got kernel_size={}'.format(kernel_size))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        # F.pad lists padding from the last dimension backwards:
        # (left, right, top, bottom) == (pad_w, pad_w, pad_h, pad_h).
        self._padding_repeated_twice = tuple(p for p in reversed(self.padding) for _ in range(2))
        self.simam = simam

        # ---------------------------- Initialization of D & W ----------------------------
        M, N = self.kernel_size
        taps = M * N
        self.D_mul = taps if D_mul is None or taps <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))
        if taps > 1:
            # Trainable part of D starts at zero ...
            self.D = Parameter(torch.zeros(in_channels, taps, self.D_mul))
            # ... and the frozen D_diag makes D + D_diag start as stacked identities.
            eye = torch.reshape(torch.eye(taps), (1, taps, taps))
            D_diag = eye.repeat((in_channels, 1, self.D_mul // taps))
            remainder = self.D_mul % taps
            if remainder != 0:  # D_mul is not a multiple of M * N: pad with zero columns
                zeros = torch.zeros([in_channels, taps, remainder])
                D_diag = torch.cat([D_diag, zeros], dim=2)
            self.D_diag = Parameter(D_diag, requires_grad=False)
        # ----------------------------------------------------------------------------------

        if simam:
            self.simam_block = simam_module()
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        # Older pickles may predate the padding_mode attribute.
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        """Run F.conv2d, applying non-zero padding modes explicitly via F.pad."""
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Compose the kernel DoW from D and W, then convolve ``input`` with it."""
        M, N = self.kernel_size
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            # D: (in_channels, M * N, D_mul)
            D = self.D + self.D_diag
            # W viewed as (out_channels // groups, in_channels, D_mul); for
            # groups == 1 this is W's own layout.
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))
            # einsum outputs (out_channels // groups, in_channels, M * N),
            # which is reshaped to (out_channels, in_channels // groups, M, N).
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
        else:
            DoW = torch.reshape(self.W, DoW_shape)
        if self.simam:
            # Reweight the first ceil(M / 2) kernel rows; keep the rest as-is.
            DoW_h1, DoW_h2 = torch.chunk(DoW, 2, dim=2)
            DoW = torch.cat([self.simam_block(DoW_h1), DoW_h2], dim=2)
        return self._conv_forward(input, DoW)
class DOConv2d_eval(Module):
    """
    Plain-kernel counterpart of DOConv2d with the same constructor interface.

    Unlike DOConv2d, the kernel is stored directly as ``W`` of shape
    ``(out_channels, in_channels // groups, M, N)``; there is no ``D`` or
    ``D_diag`` and no per-call kernel composition.
    Note that the groups parameter switches between DO-Conv (groups=1),
    DO-DConv (groups=in_channels), DO-GConv (otherwise).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or ``(M, N)`` pair. Default 3.
        D_mul: accepted for interface parity with DOConv2d; unused here.
        stride, padding, dilation: int or pair, as in ``nn.Conv2d``.
        groups: number of blocked connections from input to output channels.
        bias: if True, add a learnable bias of shape ``(out_channels,)``.
        padding_mode: one of 'zeros', 'reflect', 'replicate', 'circular'.
        simam: stored on the module but not applied by ``forward``.

    Raises:
        ValueError: if channels are not divisible by ``groups``,
            ``padding_mode`` is unknown, or a pair argument does not have
            length 2.
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    @staticmethod
    def _to_pair(value, name):
        """Return ``value`` as a 2-tuple, repeating a scalar as nn.Conv2d does."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('{} must be an int or a pair, but got {}'.format(name, value))
            return tuple(value)
        return (value, value)

    def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
        super(DOConv2d_eval, self).__init__()
        kernel_size = self._to_pair(kernel_size, 'kernel_size')
        stride = self._to_pair(stride, 'stride')
        padding = self._to_pair(padding, 'padding')
        dilation = self._to_pair(dilation, 'dilation')
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        # F.pad lists padding from the last dimension backwards:
        # (left, right, top, bottom) == (pad_w, pad_w, pad_h, pad_h).
        self._padding_repeated_twice = tuple(p for p in reversed(self.padding) for _ in range(2))
        self.simam = simam  # kept for interface parity; forward() does not use it

        # The kernel is stored directly in conv2d weight layout.
        M, N = self.kernel_size
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, M, N))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))
        # Honor `bias` the same way DOConv2d does (previously it was always None).
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        # Must name this class: super(DOConv2d, self) raised TypeError here
        # (on deepcopy/unpickle) because DOConv2d_eval is not a DOConv2d subclass.
        super(DOConv2d_eval, self).__setstate__(state)
        # Older pickles may predate the padding_mode attribute.
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        """Run F.conv2d, applying non-zero padding modes explicitly via F.pad."""
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Convolve ``input`` with the stored kernel ``W``."""
        return self._conv_forward(input, self.W)
class simam_module(torch.nn.Module):
    """
    Parameter-free SimAM-style reweighting of a 4-D tensor.

    For each ``(dim0, dim1)`` slice, every element is scaled by
    ``sigmoid(y)`` where ``y = (x - mean)^2 / (4 * (sum((x - mean)^2) / n + e_lambda)) + 0.5``,
    with the mean and sum taken over the last two dimensions and
    ``n = h * w - 1`` (clamped to at least 1).

    Args:
        e_lambda: small constant added to the denominator. Default 1e-4.
    """

    def __init__(self, e_lambda=1e-4):
        super(simam_module, self).__init__()
        # Attribute name (sic) kept as-is so existing references keep working.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def forward(self, x):
        """Return ``x`` reweighted element-wise; ``x`` must be 4-D."""
        _, _, h, w = x.size()
        # Clamp to 1: for a 1x1 plane h * w - 1 == 0 and the squared deviation
        # is also 0, so the unclamped form computes 0 / 0 and returns NaN.
        n = max(w * h - 1, 1)
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activaton(y)