diff --git a/src/kerassurgeon/surgeon.py b/src/kerassurgeon/surgeon.py index 2a6d7f4..b66c683 100644 --- a/src/kerassurgeon/surgeon.py +++ b/src/kerassurgeon/surgeon.py @@ -2,7 +2,14 @@ import numpy as np import tensorflow as tf -from tensorflow.keras.layers import BatchNormalization +from tensorflow.keras.layers import BatchNormalization, InputLayer, Dense, Flatten, Conv1D, Conv2D, Conv3D, Cropping1D,\ + Cropping2D, Cropping3D, UpSampling1D, UpSampling2D, UpSampling3D, ZeroPadding1D, ZeroPadding2D, ZeroPadding3D, \ + GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, Dropout, Activation, \ + ActivityRegularization, Masking, LeakyReLU, ELU, ThresholdedReLU, GaussianNoise, GaussianDropout, AlphaDropout, \ + ReLU, Reshape, Permute, RepeatVector, Embedding, Add, Multiply, Average, Maximum, Concatenate, SimpleRNN, GRU, LSTM +from tensorflow.keras.layers.pooling.base_pooling1d import Pooling1D +from tensorflow.keras.layers.pooling.base_pooling2d import Pooling2D +from tensorflow.keras.layers.pooling.base_pooling3d import Pooling3D from tensorflow.keras.models import Model from kerassurgeon import utils @@ -32,6 +39,7 @@ class Surgeon: copy: If True, the model will be copied before and after any operations This keeps the layers in the original model and the new model separate. """ + def __init__(self, model, copy=None): if copy: self.model = utils.clean_copy(model) @@ -125,7 +133,7 @@ def add_job(self, job, layer, *, else: raise ValueError(job + ' is not a recognised job. Valid jobs ' - 'are:\n-', '\n- '.join(self.valid_jobs)) + 'are:\n-', '\n- '.join(self.valid_jobs)) # Get nodes to be operated on for this job job_nodes = [] @@ -391,11 +399,11 @@ def _apply_delete_mask(self, node, inbound_masks): data_format = getattr(layer, 'data_format', 'channels_last') inbound_masks = utils.single_element(inbound_masks) # otherwise, delete_mask.shape should be: layer.input_shape[1:] - layer_class = layer.__class__.__name__ - if layer_class == 'InputLayer': + # layer_class = layer.__class__.__name__ + if isinstance(layer, InputLayer): raise RuntimeError('This should never get here!') - elif layer_class == 'Dense': + elif isinstance(layer, Dense): if np.all(inbound_masks): new_layer = layer else: @@ -406,11 +414,11 @@ def _apply_delete_mask(self, node, inbound_masks): new_layer = type(layer).from_config(config) outbound_mask = None - elif layer_class == 'Flatten': + elif isinstance(layer, Flatten): outbound_mask = np.reshape(inbound_masks, [-1, ]) new_layer = layer - elif layer_class in ('Conv1D', 'Conv2D', 'Conv3D'): + elif isinstance(layer, (Conv1D, Conv2D, Conv3D)): if np.all(inbound_masks): new_layer = layer else: @@ -436,98 +444,93 @@ def _apply_delete_mask(self, node, inbound_masks): new_layer = type(layer).from_config(config) outbound_mask = None - elif layer_class in ('Cropping1D', 'Cropping2D', 'Cropping3D', - 'MaxPooling1D', 'MaxPooling2D', - 'MaxPooling3D', - 'AveragePooling1D', 'AveragePooling2D', - 'AveragePooling3D'): - index = [slice(None, x, None) for x in output_shape[1:]] + elif isinstance(layer, (GlobalMaxPooling1D, + GlobalMaxPooling2D, + GlobalAveragePooling1D, + GlobalAveragePooling2D)): + # Get slice of mask with all singleton dimensions except + # channels dimension + index = [0] * (len(input_shape) - 1) if data_format == 'channels_first': index[0] = slice(None) elif data_format == 'channels_last': index[-1] = slice(None) else: raise ValueError('Invalid data format') - outbound_mask = inbound_masks[tuple(index)] + channels_vector = inbound_masks[tuple(index)] + # Tile this slice to create the outbound mask + outbound_mask = channels_vector new_layer = layer - elif layer_class in ('UpSampling1D', - 'UpSampling2D', - 'UpSampling3D', - 'ZeroPadding1D', - 'ZeroPadding2D', - 'ZeroPadding3D'): - - # Get slice of mask with all singleton dimensions except - # channels dimension - index = [slice(1)] * (len(input_shape) - 1) - tile_shape = list(output_shape[1:]) + elif isinstance(layer, (Cropping1D, Cropping2D, Cropping3D, + Pooling1D, Pooling2D, + Pooling3D)): + index = [slice(None, x, None) for x in output_shape[1:]] if data_format == 'channels_first': index[0] = slice(None) - tile_shape[0] = 1 elif data_format == 'channels_last': index[-1] = slice(None) - tile_shape[-1] = 1 else: raise ValueError('Invalid data format') - channels_vector = inbound_masks[tuple(index)] - # Tile this slice to create the outbound mask - outbound_mask = np.tile(channels_vector, tile_shape) + outbound_mask = inbound_masks[tuple(index)] new_layer = layer - elif layer_class in ('GlobalMaxPooling1D', - 'GlobalMaxPooling2D', - 'GlobalAveragePooling1D', - 'GlobalAveragePooling2D'): + elif isinstance(layer, (UpSampling1D, + UpSampling2D, + UpSampling3D, + ZeroPadding1D, + ZeroPadding2D, + ZeroPadding3D)): + # Get slice of mask with all singleton dimensions except # channels dimension - index = [0] * (len(input_shape) - 1) + index = [slice(1)] * (len(input_shape) - 1) + tile_shape = list(output_shape[1:]) if data_format == 'channels_first': index[0] = slice(None) + tile_shape[0] = 1 elif data_format == 'channels_last': index[-1] = slice(None) + tile_shape[-1] = 1 else: raise ValueError('Invalid data format') channels_vector = inbound_masks[tuple(index)] # Tile this slice to create the outbound mask - outbound_mask = channels_vector + outbound_mask = np.tile(channels_vector, tile_shape) new_layer = layer - elif layer_class in ('Dropout', - 'Activation', - 'SpatialDropout1D', - 'SpatialDropout2D', - 'SpatialDropout3D', - 'ActivityRegularization', - 'Masking', - 'LeakyReLU', - 'ELU', - 'ThresholdedReLU', - 'GaussianNoise', - 'GaussianDropout', - 'AlphaDropout', - 'ReLU'): + elif isinstance(layer, (Dropout, + Activation, + ActivityRegularization, + Masking, + LeakyReLU, + ELU, + ThresholdedReLU, + GaussianNoise, + GaussianDropout, + AlphaDropout, + ReLU)): # Pass-through layers outbound_mask = inbound_masks new_layer = layer - elif layer_class == 'Reshape': + elif isinstance(layer, Reshape): outbound_mask = np.reshape(inbound_masks, layer.target_shape) new_layer = layer - elif layer_class == 'Permute': + elif isinstance(layer, Permute): outbound_mask = np.transpose(inbound_masks, - [x-1 for x in layer.dims]) + [x - 1 for x in layer.dims]) new_layer = layer - elif layer_class == 'RepeatVector': + elif isinstance(layer, RepeatVector): outbound_mask = np.repeat( np.expand_dims(inbound_masks, 0), layer.n, axis=0) new_layer = layer - elif layer_class == 'Embedding': + elif isinstance(layer, Embedding): # Embedding will always be the first layer so it doesn't need # to consider the inbound_delete_mask if inbound_masks is not None: @@ -537,25 +540,25 @@ def _apply_delete_mask(self, node, inbound_masks): outbound_mask = None new_layer = layer - elif layer_class in ('Add', 'Multiply', 'Average', 'Maximum'): + elif isinstance(layer, (Add, Multiply, Average, Maximum)): # The inputs must be the same size if not utils.all_equal(inbound_masks): ValueError( '{0} layers must have the same size inputs. All ' 'inbound nodes must have the same channels deleted' - .format(layer_class)) + .format(layer.__class__.__name__)) outbound_mask = inbound_masks[1] new_layer = layer - elif layer_class == 'Concatenate': + elif isinstance(layer, Concatenate): axis = layer.axis if layer.axis < 0: axis = axis % len(layer.input_shape[0]) # Below: axis=axis-1 because the mask excludes the batch dimension - outbound_mask = np.concatenate(inbound_masks, axis=axis-1) + outbound_mask = np.concatenate(inbound_masks, axis=axis - 1) new_layer = layer - elif layer_class in ('SimpleRNN', 'GRU', 'LSTM'): + elif isinstance(layer, (SimpleRNN, GRU, LSTM)): if np.all(inbound_masks): new_layer = layer else: @@ -566,7 +569,7 @@ def _apply_delete_mask(self, node, inbound_masks): new_layer = type(layer).from_config(config) outbound_mask = None - elif layer_class == 'BatchNormalization': + elif isinstance(layer, BatchNormalization): outbound_mask = inbound_masks # Get slice of mask with all singleton dimensions except # channels dimension @@ -598,8 +601,9 @@ def _apply_delete_mask(self, node, inbound_masks): # - Dot # - PReLU # Warning/error needed for Reshape if channels axis is split + print(type(layer)) raise ValueError('"{0}" layers are currently ' - 'unsupported.'.format(layer_class)) + 'unsupported.'.format(type(layer))) if len(layer.inbound_nodes) > 1 and new_layer != layer: self._replace_layers_map[layer] = (new_layer, outbound_mask)