-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfl_module.py
47 lines (43 loc) · 1.38 KB
/
fl_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
from gdn import GDN
class FL1(nn.Module):
    """Convolution block: Conv2d -> GDN -> PReLU.

    Args:
        F: kernel size of the convolution.
        D_in: number of input channels of the convolution.
        D_out: number of output channels of the convolution; also the channel
            count passed to ``GDN``.
        S: stride of the convolution.
        P: padding of the convolution.
        device: forwarded as the second positional argument of ``GDN``.
            Defaults to ``'cpu'``, which is what this class always passed
            before (the old ``torch.cuda.is_available()`` check used ``'cpu'``
            in both branches, so it never changed anything).
    """

    def __init__(self, F, D_in, D_out, S, P, *, device='cpu'):
        super().__init__()
        # The conv is kept inside nn.Sequential so existing state_dict keys
        # ("module.0.weight", "module.0.bias") stay unchanged.
        self.module = nn.Sequential(
            nn.Conv2d(
                in_channels=D_in,
                out_channels=D_out,
                kernel_size=F,
                stride=S,
                padding=P,
            )
        )
        # GDN comes from the project-local ``gdn`` module; its internals are
        # not visible here (presumably generalized divisive normalization).
        # NOTE(review): if GDN allocates device-bound tensors from `device`,
        # a model moved to CUDA may need device='cuda' -- confirm in gdn.py.
        self.gdn = GDN(D_out, device)
        # nn.PReLU() with no arguments: a single learnable slope shared by
        # all channels.
        self.last = nn.PReLU()

    def forward(self, x):
        """Apply conv -> GDN -> PReLU to ``x`` and return the result.

        ``x`` presumably has shape (N, D_in, H, W), as nn.Conv2d expects --
        TODO confirm with callers.
        """
        x = self.module(x)
        # NOTE(review): this calls GDN.forward directly, which bypasses
        # nn.Module hooks. If GDN is an nn.Module, self.gdn(x) is the
        # idiomatic call.
        x = self.gdn.forward(x)
        return self.last(x)
class FL2(nn.Module):
    """Transposed-convolution block: ConvTranspose2d -> inverse GDN -> PReLU.

    Args:
        F: kernel size of the transposed convolution.
        D_in: number of input channels of the transposed convolution.
        D_out: number of output channels of the transposed convolution; also
            the channel count passed to ``GDN``.
        S: stride of the transposed convolution.
        P: padding of the transposed convolution.
        device: forwarded as the second positional argument of ``GDN``.
            Defaults to ``'cpu'``, which is what this class always passed
            before (the old ``torch.cuda.is_available()`` check used ``'cpu'``
            in both branches, so it never changed anything).
    """

    def __init__(self, F, D_in, D_out, S, P, *, device='cpu'):
        super().__init__()
        # The transposed conv is kept inside nn.Sequential so existing
        # state_dict keys ("module.0.weight", "module.0.bias") stay unchanged.
        self.module = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=D_in,
                out_channels=D_out,
                kernel_size=F,
                stride=S,
                padding=P,
            )
        )
        # GDN comes from the project-local ``gdn`` module and is built with
        # inverse=True; its internals are not visible here (presumably the
        # inverse of generalized divisive normalization).
        # NOTE(review): if GDN allocates device-bound tensors from `device`,
        # a model moved to CUDA may need device='cuda' -- confirm in gdn.py.
        self.gdn = GDN(D_out, device, inverse=True)
        # nn.PReLU() with no arguments: a single learnable slope shared by
        # all channels.
        self.last = nn.PReLU()

    def forward(self, x):
        """Apply transposed conv -> inverse GDN -> PReLU to ``x``.

        ``x`` presumably has shape (N, D_in, H, W), as nn.ConvTranspose2d
        expects -- TODO confirm with callers.
        """
        x = self.module(x)
        # NOTE(review): this calls GDN.forward directly, which bypasses
        # nn.Module hooks. If GDN is an nn.Module, self.gdn(x) is the
        # idiomatic call.
        x = self.gdn.forward(x)
        return self.last(x)