-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers.py
105 lines (78 loc) · 3.88 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch.nn as nn
import torch
class ComplexConv2D(nn.Module):
    """Complex-valued 2D convolution built from two real ``nn.Conv2d`` layers.

    Implements the complex product (W_r + i W_i) * (x_r + i x_i):
        real part: W_r * x_r - W_i * x_i
        imag part: W_i * x_r + W_r * x_i
    Input and output are complex tensors (e.g. ``torch.complex64``).

    NOTE(review): each real conv carries its own bias, so the effective
    complex bias is (b_r - b_i) + i(b_r + b_i) — a learnable
    reparametrization of an ordinary complex bias, not a defect.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        # Both component convolutions share the same hyper-parameters.
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        self.real_conv = nn.Conv2d(**conv_kwargs)  # holds W_r
        self.im_conv = nn.Conv2d(**conv_kwargs)    # holds W_i
        # Glorot-uniform init on each real weight tensor.
        nn.init.xavier_uniform_(self.real_conv.weight)
        nn.init.xavier_uniform_(self.im_conv.weight)

    def forward(self, x):
        """Convolve a complex input tensor and return a complex tensor."""
        re, im = x.real, x.imag
        out_re = self.real_conv(re) - self.im_conv(im)
        out_im = self.im_conv(re) + self.real_conv(im)
        return torch.complex(out_re, out_im)
class ComplexConvTranspose2D(nn.Module):
    """Complex-valued transposed 2D convolution from two real ``nn.ConvTranspose2d`` layers.

    Applies the complex product (W_r + i W_i) * (x_r + i x_i):
        real part: W_r * x_r - W_i * x_i
        imag part: W_i * x_r + W_r * x_i
    Input and output are complex tensors.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride
        # Both component transposed convolutions share the same hyper-parameters.
        convt_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            output_padding=output_padding,
            padding=padding,
            stride=stride,
        )
        self.real_convt = nn.ConvTranspose2d(**convt_kwargs)  # holds W_r
        self.im_convt = nn.ConvTranspose2d(**convt_kwargs)    # holds W_i
        # Glorot-uniform init on each real weight tensor.
        nn.init.xavier_uniform_(self.real_convt.weight)
        nn.init.xavier_uniform_(self.im_convt.weight)

    def forward(self, x):
        """Transpose-convolve a complex input tensor and return a complex tensor."""
        re, im = x.real, x.imag
        out_re = self.real_convt(re) - self.im_convt(im)
        out_im = self.im_convt(re) + self.real_convt(im)
        return torch.complex(out_re, out_im)
class ComplexBatchNorm2D(nn.Module):
    """Batch normalization for complex tensors, applied per component.

    The real and imaginary parts are normalized independently by two
    separate ``nn.BatchNorm2d`` modules. NOTE(review): this is the
    "naive" complex batch norm — it does not whiten the joint
    real/imaginary covariance.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Both component norms share the same hyper-parameters.
        bn_kwargs = dict(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.real_b = nn.BatchNorm2d(**bn_kwargs)
        self.im_b = nn.BatchNorm2d(**bn_kwargs)

    def forward(self, x):
        """Normalize each component and recombine into a complex tensor."""
        return torch.complex(self.real_b(x.real), self.im_b(x.imag))