Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed transfer learning architecture #6

Merged
merged 2 commits into from
Feb 28, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 83 additions & 6 deletions T3D_keras.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import keras
from keras.models import Sequential, Model
from keras.layers import Input, BatchNormalization, Activation, Conv3D, Dropout, Concatenate, AveragePooling3D, MaxPooling3D, Dense, Flatten, GlobalAveragePooling2D
from keras.layers import Input, BatchNormalization, Activation, Conv3D, Dropout, Concatenate, AveragePooling3D, MaxPooling3D, Dense, Flatten, GlobalAveragePooling2D, GlobalAveragePooling3D
from keras.activations import linear, softmax
from keras.applications import densenet
from keras.layers import TimeDistributed

__all__ = ['DenseNet', 'densenet121', 'densenet161'] # with DropOut

Expand Down Expand Up @@ -82,9 +83,15 @@ def DenseNet3D(input_shape, growth_rate=32, block_config=(6, 12, 24, 16),
"""
#-----------------------------------------------------------------
inp_2d = (Input(shape=(224,224,3), name='2d_input'))
pretrained_densenet = densenet.DenseNet169(include_top=False, input_shape=(224,224,3), input_tensor=inp_2d, weights='imagenet')
for layer in pretrained_densenet.layers:
batch_densenet = densenet.DenseNet169(include_top=False, input_shape=(224,224,3), input_tensor=inp_2d, weights='imagenet')

for layer in batch_densenet.layers:
layer.trainable = False

# Configure the 2D CNN to take batches of pictures
inp_2d_batch = (Input(shape=input_shape, name='2d_input_batch'))
batch_densenet = TimeDistributed(batch_densenet)(inp_2d_batch)
batch_densenet = Model(inputs=inp_2d_batch, outputs=batch_densenet)
#-----------------------------------------------------------------

# First convolution-----------------------
Expand Down Expand Up @@ -127,8 +134,8 @@ def DenseNet3D(input_shape, growth_rate=32, block_config=(6, 12, 24, 16),
x = AveragePooling3D(pool_size=(1, 7, 7))(x)
x = Flatten(name='flatten_3d')(x)
x = Dense(1024, activation='relu')(x)
#--------------fron 2d densenet model-----------------
y = GlobalAveragePooling2D(name='avg_pool_densnet2d')(pretrained_densenet.output)
#--------------from 2d densenet model-----------------
y = GlobalAveragePooling3D(name='avg_pool_densnet3d')(batch_densenet.output)
y = Dense(1024, activation='relu')(y)

#-----------------------------------------------------
Expand All @@ -140,7 +147,77 @@ def DenseNet3D(input_shape, growth_rate=32, block_config=(6, 12, 24, 16),
x = Dropout(0.35)(x)
out = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=[inp_2d, inp_3d], outputs=[out])
model = Model(inputs=[inp_2d_batch, inp_3d], outputs=[out])
model.summary()

return model


# The T3D CNN standalone (3D branch only, no 2D transfer-learning branch)
def T3D(input_shape, growth_rate=32, block_config=(6, 12, 24, 16),
        num_init_features=64, bn_size=4, drop_rate=0, num_classes=5):
    r"""Build the standalone T3D network, a 3D DenseNet-BC variant based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        input_shape (tuple) - shape of the 3D input tensor fed to the network
            (presumably (frames, height, width, channels) — confirm against caller)
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    Returns:
        A compiled-graph ``keras.Model`` mapping the 3D input to softmax
        class probabilities (also prints a summary as a side effect).
    """
    # First convolution -----------------------
    inp_3d = Input(shape=input_shape, name='3d_input')

    # NOTE(review): padding='same' assumed correct here — verify against the paper
    x = Conv3D(num_init_features, kernel_size=(3, 7, 7),
               strides=2, padding='same', use_bias=False)(inp_3d)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # NOTE(review): padding='valid' assumed correct here — verify against the paper
    x = MaxPooling3D(pool_size=(3, 3, 3),
                     strides=(2, 2, 2), padding='valid')(x)

    # Dense blocks: each block grows the feature count by num_layers * growth_rate;
    # every block except the last is followed by a Temporal Transition Layer (TTL)
    # and a transition layer.
    num_features = num_init_features
    for i, num_layers in enumerate(block_config):
        x = _DenseBlock(x, num_layers=num_layers,
                        bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
        num_features = num_features + num_layers * growth_rate

        if i != len(block_config) - 1:
            x = _TTL(x)
            # Transition keeps the feature count unchanged (no compression).
            x = _Transition(x, num_output_features=num_features)

    # Final batch norm + classifier head
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = AveragePooling3D(pool_size=(1, 7, 7))(x)
    x = Flatten(name='flatten_3d')(x)
    x = Dense(1024, activation='relu')(x)

    # Progressively narrower fully-connected head with decreasing dropout
    x = Dropout(0.65)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.35)(x)
    out = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=[inp_3d], outputs=[out])
    model.summary()

    return model
Expand Down