# unet.py
import torch
import torch.nn as nn
class UNet3D(nn.Module):
    """3D U-Net: a three-level encoder/decoder with skip connections."""

    def __init__(self, in_channel, n_classes):
        super(UNet3D, self).__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes

        # Encoder: two convolutions per level, roughly doubling the channel count.
        self.ec0 = self.encoder(self.in_channel, 32, bias=False, batchnorm=False)
        self.ec1 = self.encoder(32, 64, bias=False, batchnorm=False)
        self.ec2 = self.encoder(64, 64, bias=False, batchnorm=False)
        self.ec3 = self.encoder(64, 128, bias=False, batchnorm=False)
        self.ec4 = self.encoder(128, 128, bias=False, batchnorm=False)
        self.ec5 = self.encoder(128, 256, bias=False, batchnorm=False)
        self.ec6 = self.encoder(256, 256, bias=False, batchnorm=False)
        self.ec7 = self.encoder(256, 512, bias=False, batchnorm=False)

        self.pool0 = nn.MaxPool3d(2)
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)

        # Decoder: transposed convolutions upsample, then two convolutions process
        # the upsampled features concatenated with the matching skip connection.
        self.dc9 = self.decoder(512, 512, kernel_size=2, stride=2, bias=False)
        self.dc8 = self.decoder(256 + 512, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc7 = self.decoder(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc6 = self.decoder(256, 256, kernel_size=2, stride=2, bias=False)
        self.dc5 = self.decoder(128 + 256, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc4 = self.decoder(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc3 = self.decoder(128, 128, kernel_size=2, stride=2, bias=False)
        self.dc2 = self.decoder(64 + 128, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc1 = self.decoder(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc0 = self.decoder(64, n_classes, kernel_size=1, stride=1, bias=False)
    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=True, batchnorm=False):
        # padding=1 keeps the 3x3x3 convolutions size-preserving, so the encoder
        # feature maps line up with the decoder for the skip concatenations in forward().
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer
    def decoder(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                output_padding=0, bias=True):
        # Single transposed convolution followed by ReLU; with kernel_size=2 and
        # stride=2 this doubles the spatial resolution.
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer
    def forward(self, x):
        # Encoder path; syn0, syn1, syn2 are kept as skip connections.
        e0 = self.ec0(x)
        syn0 = self.ec1(e0)
        e1 = self.pool0(syn0)
        e2 = self.ec2(e1)
        syn1 = self.ec3(e2)
        del e0, e1, e2

        e3 = self.pool1(syn1)
        e4 = self.ec4(e3)
        syn2 = self.ec5(e4)
        del e3, e4

        e5 = self.pool2(syn2)
        e6 = self.ec6(e5)
        e7 = self.ec7(e6)
        del e5, e6

        # Decoder path: upsample, concatenate the skip connection along the channel
        # dimension, then convolve. Intermediates are deleted to reduce peak memory.
        d9 = torch.cat((self.dc9(e7), syn2), dim=1)
        del e7, syn2

        d8 = self.dc8(d9)
        d7 = self.dc7(d8)
        del d9, d8

        d6 = torch.cat((self.dc6(d7), syn1), dim=1)
        del d7, syn1

        d5 = self.dc5(d6)
        d4 = self.dc4(d5)
        del d6, d5

        d3 = torch.cat((self.dc3(d4), syn0), dim=1)
        del d4, syn0

        d2 = self.dc2(d3)
        d1 = self.dc1(d2)
        del d3, d2

        d0 = self.dc0(d1)
        return d0
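

# Minimal smoke test (a sketch, not part of the original file): build the model and
# run a forward pass on a random volume. The channel count and spatial size below are
# illustrative; spatial dimensions should be divisible by 8 so the three pooling and
# upsampling stages line up with the skip connections.
if __name__ == "__main__":
    model = UNet3D(in_channel=1, n_classes=2)
    x = torch.randn(1, 1, 64, 64, 64)  # (batch, channels, depth, height, width)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([1, 2, 64, 64, 64])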