From 1688e8a31c61d6fef8b2253a62473952206e6c81 Mon Sep 17 00:00:00 2001 From: Christian Steinmeyer Date: Tue, 20 Jun 2023 09:14:08 +0200 Subject: [PATCH] add implementations for Conv2DTranspose and SeparableConv2D see https://github.com/BenWhetton/keras-surgeon/pull/55 and https://github.com/BenWhetton/keras-surgeon/pull/27 --- src/kerassurgeon/surgeon.py | 68 +++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/src/kerassurgeon/surgeon.py b/src/kerassurgeon/surgeon.py index 2aa0641..ba42cfc 100644 --- a/src/kerassurgeon/surgeon.py +++ b/src/kerassurgeon/surgeon.py @@ -437,6 +437,60 @@ def _apply_delete_mask(self, node, inbound_masks): new_layer = type(layer).from_config(config) outbound_mask = None + elif layer_class in ('Conv2DTranspose'): + if np.all(inbound_masks): + new_layer = layer + else: + if data_format == 'channels_first': + inbound_masks = np.swapaxes(inbound_masks, 0, -1) + # Conv layer: trim down inbound_masks to filter shape + k_size = layer.kernel_size + index = [slice(None, 1, None) for _ in k_size] + inbound_masks = inbound_masks[tuple(index + [slice(None)])] + weights = layer.get_weights() + # Delete unused weights to obtain new_weights + # Each deleted channel was connected to all of the channels + # in layer; therefore, the mask must be repeated for each + # channel. + # `delete_mask`'s size: size(weights[0]) + delete_mask = np.tile( + inbound_masks[..., np.newaxis], list(k_size) + [1, weights[0].shape[-2]] + ).transpose(0, 1, 3, 2) + new_shape = list(weights[0].shape) + new_shape[-1] = -1 # Input size channels + weights[0] = np.reshape(weights[0][delete_mask], new_shape) + # Instantiate new layer with new_weights + config = layer.get_config() + config['weights'] = weights + new_layer = type(layer).from_config(config) + + outbound_mask = None + + elif layer_class in ('SeparableConv2D'): + if np.all(inbound_masks): + new_layer = layer + else: + if data_format == 'channels_first': + inbound_masks = np.swapaxes(inbound_masks, 0, -1) + + # Conv layer: trim down inbound_masks to filter shape + k_size = layer.kernel_size + index = [slice(None, dim_size, None) for dim_size in k_size] + delete_mask = inbound_masks[tuple(index + [slice(None)])] + # Delete unused weights to obtain new_weights + weights = layer.get_weights() + + new_shape = list(weights[0].shape) + new_shape[-2] = -1 # Weights always have channels_last + weights[0] = np.reshape(weights[0][delete_mask], new_shape) + weights[1] = weights[1][:, :, delete_mask[0][0], :] + + # Instantiate new layer with new_weights + config = layer.get_config() + config['weights'] = weights + new_layer = type(layer).from_config(config) + outbound_mask = None + elif layer_class in ( 'Cropping1D', 'Cropping2D', @@ -598,8 +652,6 @@ def _apply_delete_mask(self, node, inbound_masks): else: # Not implemented: # - Lambda - # - SeparableConv2D - # - Conv2DTranspose # - LocallyConnected1D # - LocallyConnected2D # - TimeDistributed @@ -660,6 +712,18 @@ def _delete_channel_weights(self, layer, channel_indices): channel_indices_lstm = [layer.units * m + i for m in range(4) for i in channel_indices] weights = [np.delete(w, channel_indices_lstm, axis=-1) for w in layer.get_weights()] weights[1] = np.delete(weights[1], channel_indices, axis=0) + elif layer.__class__.__name__ == 'Conv2DTranspose': + weights = [] + allweights = layer.get_weights() + w = np.delete(allweights[0], channel_indices, axis=2) + weights.append(w) + if len(allweights) == 2: + b = np.delete(allweights[1], channel_indices, axis=-1) + weights.append(b) + elif layer.__class__.__name__ == 'SeparableConv2D': + weights = layer.get_weights() + for i in range(1, len(weights)): + weights[i] = np.delete(weights[i], channel_indices, axis=-1) else: weights = [np.delete(w, channel_indices, axis=-1) for w in layer.get_weights()] layer_config['weights'] = weights