-
Notifications
You must be signed in to change notification settings - Fork 0
/
Unet_encoder.py
103 lines (94 loc) · 3.42 KB
/
Unet_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
"""
Created on Wed May 5 17:25:43 2021
@author: GREEN&LYC
"""
from keras import layers
from keras.layers import BatchNormalization
def Unet_encoder(img_input):
    """Build a VGG16-style convolutional encoder with batch normalization.

    The encoder has five stages. Each stage applies two 3x3 convolutions
    (``padding='same'``, ReLU activation inside the conv layer), and each
    convolution is followed by a ``BatchNormalization`` layer. Stages 1-4
    end with a 2x2 max-pool of stride 2; stage 5 is not pooled. Unlike
    VGG16, stages 3-5 have only two convolutions (there is no
    ``blockN_conv3`` layer).

    Layer names (``blockN_convM`` / ``blockN_pool``) follow the Keras VGG16
    naming scheme. Presumably this lets conv weights be loaded by name from
    a VGG16 checkpoint. NOTE(review): confirm against the training script.

    Args:
        img_input: Keras tensor holding the input image. The original shape
            comments assume 512x512x3.

    Returns:
        Tuple ``(feat1, feat2, feat3, feat4, feat5)`` of the batch-normalized
        output of each stage, taken before that stage's pooling.
        - Channels are 64, 128, 256, 512 and 512.
        - Spatial sizes for a 512x512 input are 512, 256, 128, 64 and 32.
        - The outputs are presumably consumed as U-Net skip connections by
          a decoder defined elsewhere.
    """
    # One entry per stage: (number of conv filters, whether to pool afterwards).
    stage_specs = (
        (64, True),
        (128, True),
        (256, True),
        (512, True),
        (512, False),  # deepest stage: no pooling after it
    )

    skip_features = []
    tensor = img_input
    for stage, (filters, pool_after) in enumerate(stage_specs, start=1):
        # Two conv + BN pairs per stage. Layer creation order is conv1, BN,
        # conv2, BN, so auto-generated BN names follow the same sequence.
        for conv_idx in (1, 2):
            conv = layers.Conv2D(filters, (3, 3),
                                 activation='relu',
                                 padding='same',
                                 name='block{}_conv{}'.format(stage, conv_idx))
            tensor = BatchNormalization()(conv(tensor))

        # Keep the normalized stage output (pre-pooling) as a skip feature.
        skip_features.append(tensor)

        if pool_after:
            # Halve the spatial resolution before the next stage.
            tensor = layers.MaxPooling2D((2, 2), strides=(2, 2),
                                         name='block{}_pool'.format(stage))(tensor)

    return tuple(skip_features)