Skip to content

Commit

Permalink
add implementations for Conv2DTranspose and SeparableConv2D
Browse files Browse the repository at this point in the history
  • Loading branch information
christian-steinmeyer committed Jul 6, 2023
1 parent 40b9328 commit 1688e8a
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions src/kerassurgeon/surgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,60 @@ def _apply_delete_mask(self, node, inbound_masks):
new_layer = type(layer).from_config(config)
outbound_mask = None

elif layer_class in ('Conv2DTranspose'):
if np.all(inbound_masks):
new_layer = layer
else:
if data_format == 'channels_first':
inbound_masks = np.swapaxes(inbound_masks, 0, -1)
# Conv layer: trim down inbound_masks to filter shape
k_size = layer.kernel_size
index = [slice(None, 1, None) for _ in k_size]
inbound_masks = inbound_masks[tuple(index + [slice(None)])]
weights = layer.get_weights()
# Delete unused weights to obtain new_weights
# Each deleted channel was connected to all of the channels
# in layer; therefore, the mask must be repeated for each
# channel.
# `delete_mask`'s size: size(weights[0])
delete_mask = np.tile(
inbound_masks[..., np.newaxis], list(k_size) + [1, weights[0].shape[-2]]
).transpose(0, 1, 3, 2)
new_shape = list(weights[0].shape)
new_shape[-1] = -1 # Input size channels
weights[0] = np.reshape(weights[0][delete_mask], new_shape)
# Instantiate new layer with new_weights
config = layer.get_config()
config['weights'] = weights
new_layer = type(layer).from_config(config)

outbound_mask = None

elif layer_class in ('SeparableConv2D'):
if np.all(inbound_masks):
new_layer = layer
else:
if data_format == 'channels_first':
inbound_masks = np.swapaxes(inbound_masks, 0, -1)

# Conv layer: trim down inbound_masks to filter shape
k_size = layer.kernel_size
index = [slice(None, dim_size, None) for dim_size in k_size]
delete_mask = inbound_masks[tuple(index + [slice(None)])]
# Delete unused weights to obtain new_weights
weights = layer.get_weights()

new_shape = list(weights[0].shape)
new_shape[-2] = -1 # Weights always have channels_last
weights[0] = np.reshape(weights[0][delete_mask], new_shape)
weights[1] = weights[1][:, :, delete_mask[0][0], :]

# Instantiate new layer with new_weights
config = layer.get_config()
config['weights'] = weights
new_layer = type(layer).from_config(config)
outbound_mask = None

elif layer_class in (
'Cropping1D',
'Cropping2D',
Expand Down Expand Up @@ -598,8 +652,6 @@ def _apply_delete_mask(self, node, inbound_masks):
else:
# Not implemented:
# - Lambda
# - SeparableConv2D
# - Conv2DTranspose
# - LocallyConnected1D
# - LocallyConnected2D
# - TimeDistributed
Expand Down Expand Up @@ -660,6 +712,18 @@ def _delete_channel_weights(self, layer, channel_indices):
channel_indices_lstm = [layer.units * m + i for m in range(4) for i in channel_indices]
weights = [np.delete(w, channel_indices_lstm, axis=-1) for w in layer.get_weights()]
weights[1] = np.delete(weights[1], channel_indices, axis=0)
elif layer.__class__.__name__ == 'Conv2DTranspose':
weights = []
allweights = layer.get_weights()
w = np.delete(allweights[0], channel_indices, axis=2)
weights.append(w)
if len(allweights) == 2:
b = np.delete(allweights[1], channel_indices, axis=-1)
weights.append(b)
elif layer.__class__.__name__ == 'SeparableConv2D':
weights = layer.get_weights()
for i in range(1, len(weights)):
weights[i] = np.delete(weights[i], channel_indices, axis=-1)
else:
weights = [np.delete(w, channel_indices, axis=-1) for w in layer.get_weights()]
layer_config['weights'] = weights
Expand Down

0 comments on commit 1688e8a

Please sign in to comment.