Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change fixed class-name strings to instance checks for Liskov substitution principle compliance #72

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 66 additions & 62 deletions src/kerassurgeon/surgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import BatchNormalization, InputLayer, Dense, Flatten, Conv1D, Conv2D, Conv3D, Cropping1D,\
Cropping2D, Cropping3D, UpSampling1D, UpSampling2D, UpSampling3D, ZeroPadding1D, ZeroPadding2D, ZeroPadding3D, \
GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, Dropout, Activation, \
ActivityRegularization, Masking, LeakyReLU, ELU, ThresholdedReLU, GaussianNoise, GaussianDropout, AlphaDropout, \
ReLU, Reshape, Permute, RepeatVector, Embedding, Add, Multiply, Average, Maximum, Concatenate, SimpleRNN, GRU, LSTM
from tensorflow.keras.layers.pooling.base_pooling1d import Pooling1D
from tensorflow.keras.layers.pooling.base_pooling2d import Pooling2D
from tensorflow.keras.layers.pooling.base_pooling3d import Pooling3D
from tensorflow.keras.models import Model

from kerassurgeon import utils
Expand Down Expand Up @@ -32,6 +39,7 @@ class Surgeon:
copy: If True, the model will be copied before and after any operations
This keeps the layers in the original model and the new model separate.
"""

def __init__(self, model, copy=None):
if copy:
self.model = utils.clean_copy(model)
Expand Down Expand Up @@ -125,7 +133,7 @@ def add_job(self, job, layer, *,

else:
raise ValueError(job + ' is not a recognised job. Valid jobs '
'are:\n-', '\n- '.join(self.valid_jobs))
'are:\n-', '\n- '.join(self.valid_jobs))

# Get nodes to be operated on for this job
job_nodes = []
Expand Down Expand Up @@ -391,11 +399,11 @@ def _apply_delete_mask(self, node, inbound_masks):
data_format = getattr(layer, 'data_format', 'channels_last')
inbound_masks = utils.single_element(inbound_masks)
# otherwise, delete_mask.shape should be: layer.input_shape[1:]
layer_class = layer.__class__.__name__
if layer_class == 'InputLayer':
# layer_class = layer.__class__.__name__
if isinstance(layer, InputLayer):
raise RuntimeError('This should never get here!')

elif layer_class == 'Dense':
elif isinstance(layer, Dense):
if np.all(inbound_masks):
new_layer = layer
else:
Expand All @@ -406,11 +414,11 @@ def _apply_delete_mask(self, node, inbound_masks):
new_layer = type(layer).from_config(config)
outbound_mask = None

elif layer_class == 'Flatten':
elif isinstance(layer, Flatten):
outbound_mask = np.reshape(inbound_masks, [-1, ])
new_layer = layer

elif layer_class in ('Conv1D', 'Conv2D', 'Conv3D'):
elif isinstance(layer, (Conv1D, Conv2D, Conv3D)):
if np.all(inbound_masks):
new_layer = layer
else:
Expand All @@ -436,98 +444,93 @@ def _apply_delete_mask(self, node, inbound_masks):
new_layer = type(layer).from_config(config)
outbound_mask = None

elif layer_class in ('Cropping1D', 'Cropping2D', 'Cropping3D',
'MaxPooling1D', 'MaxPooling2D',
'MaxPooling3D',
'AveragePooling1D', 'AveragePooling2D',
'AveragePooling3D'):
index = [slice(None, x, None) for x in output_shape[1:]]
elif isinstance(layer, (GlobalMaxPooling1D,
GlobalMaxPooling2D,
GlobalAveragePooling1D,
GlobalAveragePooling2D)):
# Get slice of mask with all singleton dimensions except
# channels dimension
index = [0] * (len(input_shape) - 1)
if data_format == 'channels_first':
index[0] = slice(None)
elif data_format == 'channels_last':
index[-1] = slice(None)
else:
raise ValueError('Invalid data format')
outbound_mask = inbound_masks[tuple(index)]
channels_vector = inbound_masks[tuple(index)]
# Tile this slice to create the outbound mask
outbound_mask = channels_vector
new_layer = layer

elif layer_class in ('UpSampling1D',
'UpSampling2D',
'UpSampling3D',
'ZeroPadding1D',
'ZeroPadding2D',
'ZeroPadding3D'):

# Get slice of mask with all singleton dimensions except
# channels dimension
index = [slice(1)] * (len(input_shape) - 1)
tile_shape = list(output_shape[1:])
elif isinstance(layer, (Cropping1D, Cropping2D, Cropping3D,
Pooling1D, Pooling2D,
Pooling3D)):
index = [slice(None, x, None) for x in output_shape[1:]]
if data_format == 'channels_first':
index[0] = slice(None)
tile_shape[0] = 1
elif data_format == 'channels_last':
index[-1] = slice(None)
tile_shape[-1] = 1
else:
raise ValueError('Invalid data format')
channels_vector = inbound_masks[tuple(index)]
# Tile this slice to create the outbound mask
outbound_mask = np.tile(channels_vector, tile_shape)
outbound_mask = inbound_masks[tuple(index)]
new_layer = layer

elif layer_class in ('GlobalMaxPooling1D',
'GlobalMaxPooling2D',
'GlobalAveragePooling1D',
'GlobalAveragePooling2D'):
elif isinstance(layer, (UpSampling1D,
UpSampling2D,
UpSampling3D,
ZeroPadding1D,
ZeroPadding2D,
ZeroPadding3D)):

# Get slice of mask with all singleton dimensions except
# channels dimension
index = [0] * (len(input_shape) - 1)
index = [slice(1)] * (len(input_shape) - 1)
tile_shape = list(output_shape[1:])
if data_format == 'channels_first':
index[0] = slice(None)
tile_shape[0] = 1
elif data_format == 'channels_last':
index[-1] = slice(None)
tile_shape[-1] = 1
else:
raise ValueError('Invalid data format')
channels_vector = inbound_masks[tuple(index)]
# Tile this slice to create the outbound mask
outbound_mask = channels_vector
outbound_mask = np.tile(channels_vector, tile_shape)
new_layer = layer

elif layer_class in ('Dropout',
'Activation',
'SpatialDropout1D',
'SpatialDropout2D',
'SpatialDropout3D',
'ActivityRegularization',
'Masking',
'LeakyReLU',
'ELU',
'ThresholdedReLU',
'GaussianNoise',
'GaussianDropout',
'AlphaDropout',
'ReLU'):
elif isinstance(layer, (Dropout,
Activation,
ActivityRegularization,
Masking,
LeakyReLU,
ELU,
ThresholdedReLU,
GaussianNoise,
GaussianDropout,
AlphaDropout,
ReLU)):
# Pass-through layers
outbound_mask = inbound_masks
new_layer = layer

elif layer_class == 'Reshape':
elif isinstance(layer, Reshape):
outbound_mask = np.reshape(inbound_masks, layer.target_shape)
new_layer = layer

elif layer_class == 'Permute':
elif isinstance(layer, Permute):
outbound_mask = np.transpose(inbound_masks,
[x-1 for x in layer.dims])
[x - 1 for x in layer.dims])
new_layer = layer

elif layer_class == 'RepeatVector':
elif isinstance(layer, RepeatVector):
outbound_mask = np.repeat(
np.expand_dims(inbound_masks, 0),
layer.n,
axis=0)
new_layer = layer

elif layer_class == 'Embedding':
elif isinstance(layer, Embedding):
# Embedding will always be the first layer so it doesn't need
# to consider the inbound_delete_mask
if inbound_masks is not None:
Expand All @@ -537,25 +540,25 @@ def _apply_delete_mask(self, node, inbound_masks):
outbound_mask = None
new_layer = layer

elif layer_class in ('Add', 'Multiply', 'Average', 'Maximum'):
elif isinstance(layer, (Add, Multiply, Average, Maximum)):
# The inputs must be the same size
if not utils.all_equal(inbound_masks):
ValueError(
'{0} layers must have the same size inputs. All '
'inbound nodes must have the same channels deleted'
.format(layer_class))
.format(layer.__class__.__name__))
outbound_mask = inbound_masks[1]
new_layer = layer

elif layer_class == 'Concatenate':
elif isinstance(layer, Concatenate):
axis = layer.axis
if layer.axis < 0:
axis = axis % len(layer.input_shape[0])
# Below: axis=axis-1 because the mask excludes the batch dimension
outbound_mask = np.concatenate(inbound_masks, axis=axis-1)
outbound_mask = np.concatenate(inbound_masks, axis=axis - 1)
new_layer = layer

elif layer_class in ('SimpleRNN', 'GRU', 'LSTM'):
elif isinstance(layer, (SimpleRNN, GRU, LSTM)):
if np.all(inbound_masks):
new_layer = layer
else:
Expand All @@ -566,7 +569,7 @@ def _apply_delete_mask(self, node, inbound_masks):
new_layer = type(layer).from_config(config)
outbound_mask = None

elif layer_class == 'BatchNormalization':
elif isinstance(layer, BatchNormalization):
outbound_mask = inbound_masks
# Get slice of mask with all singleton dimensions except
# channels dimension
Expand Down Expand Up @@ -598,8 +601,9 @@ def _apply_delete_mask(self, node, inbound_masks):
# - Dot
# - PReLU
# Warning/error needed for Reshape if channels axis is split
print(type(layer))
raise ValueError('"{0}" layers are currently '
'unsupported.'.format(layer_class))
'unsupported.'.format(type(layer)))

if len(layer.inbound_nodes) > 1 and new_layer != layer:
self._replace_layers_map[layer] = (new_layer, outbound_mask)
Expand Down