Feature: residual depth #265

Closed
README.md (4 changes: 1 addition & 3 deletions)
@@ -43,7 +43,6 @@ TCN(
kernel_initializer='he_normal',
use_batch_norm=False,
use_layer_norm=False,
use_weight_norm=False,
go_backwards=False,
return_state=False,
**kwargs
@@ -64,7 +63,6 @@ TCN(
- `kernel_initializer`: Initializer for the kernel weights matrix (Conv1D).
- `use_batch_norm`: Whether to use batch normalization in the residual layers or not.
- `use_layer_norm`: Whether to use layer normalization in the residual layers or not.
- `use_weight_norm`: Whether to use weight normalization in the residual layers or not.
- `go_backwards`: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
- `return_state`: Boolean. Whether to return the last state in addition to the output. Default: False.
- `kwargs`: Any other arguments for configuring the parent class `Layer`. For example, `name=str`, the name of the model. Use unique names when using multiple TCNs.
@@ -96,7 +94,7 @@ Here are some of my notes regarding my experience using TCN:
- `activation`: Leave it to default. I have never changed it.
- `kernel_initializer`: If the training of the TCN gets stuck, it might be worth changing this parameter. For example: `glorot_uniform`.

- `use_batch_norm`, `use_weight_norm`, `use_layer_norm`: Use normalization if your network is big enough and the task contains enough data. I usually prefer using `use_layer_norm`, but you can try them all and see which one works the best.
- `use_batch_norm`, `use_layer_norm`: Use normalization if your network is big enough and the task contains enough data. I usually prefer using `use_layer_norm`, but you can try them both and see which one works the best.


### Receptive field
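The receptive-field formula in `tcn/tcn.py` (changed further down in this diff) now scales with `residual_depth` instead of the previous constant 2, so the default depth of 2 leaves the receptive field unchanged. A quick numeric check with illustrative values:

```python
# Numeric check of the updated receptive-field formula; all values below are
# illustrative, not taken from the README.
kernel_size = 3
nb_stacks = 1
dilations = [1, 2, 4, 8, 16, 32]  # sum(dilations) = 63

def receptive_field(residual_depth):
    return 1 + residual_depth * (kernel_size - 1) * nb_stacks * sum(dilations)

print(receptive_field(2))  # 253 -- identical to the old 1 + 2 * (kernel_size - 1) * nb_stacks * sum(dilations)
print(receptive_field(3))  # 379 -- deeper residual blocks widen the receptive field
```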
setup.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
os.environ['GRPC_PYTHON_BUILD_SYSTEM_OPENSSL'] = '1'
os.environ['GRPC_PYTHON_BUILD_SYSTEM_ZLIB'] = '1'

install_requires = ['numpy', tensorflow, 'tensorflow_addons']
install_requires = ['numpy', tensorflow]

setup(
name='keras-tcn',
tasks/adding_problem/main.py (1 change: 0 additions & 1 deletion)
@@ -26,7 +26,6 @@ def run_task():
nb_stacks=1,
max_len=x_train.shape[1],
use_skip_connections=False,
use_weight_norm=True,
regression=True,
dropout_rate=0
)
tasks/copy_memory/main.py (1 change: 0 additions & 1 deletion)
@@ -30,7 +30,6 @@ def run_task():
use_skip_connections=True,
opt='rmsprop',
lr=5e-4,
use_weight_norm=True,
return_sequences=True)

print(f'x_train.shape = {x_train.shape}')
tasks/mnist_pixel/main.py (1 change: 0 additions & 1 deletion)
@@ -14,7 +14,6 @@ def run_task():
dilations=[2 ** i for i in range(9)],
nb_stacks=1,
max_len=x_train[0:1].shape[1],
use_weight_norm=True,
use_skip_connections=True)

print(f'x_train.shape = {x_train.shape}')
tasks/time_series_forecasting.py (1 change: 0 additions & 1 deletion)
@@ -38,7 +38,6 @@
kernel_size=2,
use_skip_connections=False,
use_batch_norm=False,
use_weight_norm=False,
use_layer_norm=False
),
Dense(1, activation='linear')
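Since only a fragment of the task script appears in this hunk, here is a fuller sketch of the same forecasting model with the trimmed flag set (`use_weight_norm` gone, only the batch/layer normalization switches left). The input shape, optimizer, and loss are assumptions rather than values from the file:

```python
# Hypothetical end-to-end version of the forecasting model above.
# lookback_window, optimizer and loss are assumed placeholders.
from tensorflow.keras import Input, Sequential
from tensorflow.keras.layers import Dense
from tcn import TCN

lookback_window = 12  # assumed; not taken from tasks/time_series_forecasting.py

model = Sequential([
    Input(shape=(lookback_window, 1)),
    TCN(kernel_size=2,
        use_skip_connections=False,
        use_batch_norm=False,
        use_layer_norm=False),
    Dense(1, activation='linear'),
])
model.compile(optimizer='adam', loss='mse')
model.summary()
```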
tcn/tcn.py (42 changes: 17 additions & 25 deletions)
@@ -28,6 +28,7 @@ class ResidualBlock(Layer):

def __init__(self,
dilation_rate: int,
residual_depth: int,
nb_filters: int,
kernel_size: int,
padding: str,
@@ -36,13 +37,13 @@ def __init__(self,
kernel_initializer: str = 'he_normal',
use_batch_norm: bool = False,
use_layer_norm: bool = False,
use_weight_norm: bool = False,
**kwargs):
"""Defines the residual block for the WaveNet TCN
Args:
x: The previous layer in the model
training: boolean indicating whether the layer should behave in training mode or in inference mode
dilation_rate: The dilation power of 2 we are using for this residual block
residual_depth: The number of stacked dilated convolutions to use in this block
nb_filters: The number of convolutional filters to use in this block
kernel_size: The size of the convolutional kernel
padding: The padding used in the convolutional layers, 'same' or 'causal'.
@@ -51,19 +52,18 @@ def __init__(self,
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
use_batch_norm: Whether to use batch normalization in the residual layers or not.
use_layer_norm: Whether to use layer normalization in the residual layers or not.
use_weight_norm: Whether to use weight normalization in the residual layers or not.
kwargs: Any initializers for Layer class.
"""

self.dilation_rate = dilation_rate
self.residual_depth = residual_depth
self.nb_filters = nb_filters
self.kernel_size = kernel_size
self.padding = padding
self.activation = activation
self.dropout_rate = dropout_rate
self.use_batch_norm = use_batch_norm
self.use_layer_norm = use_layer_norm
self.use_weight_norm = use_weight_norm
self.kernel_initializer = kernel_initializer
self.layers = []
self.shape_match_conv = None
@@ -88,7 +88,7 @@ def build(self, input_shape):
self.layers = []
self.res_output_shape = input_shape

for k in range(2): # dilated conv block.
for k in range(self.residual_depth): # dilated conv block.
name = 'conv1D_{}'.format(k)
with K.name_scope(name): # name scope used to make sure weights get unique names
conv = Conv1D(
@@ -99,20 +99,13 @@
name=name,
kernel_initializer=self.kernel_initializer
)
if self.use_weight_norm:
from tensorflow_addons.layers import WeightNormalization
# wrap it. WeightNormalization API is different than BatchNormalization or LayerNormalization.
with K.name_scope('norm_{}'.format(k)):
conv = WeightNormalization(conv)
self._build_layer(conv)

with K.name_scope('norm_{}'.format(k)):
if self.use_batch_norm:
self._build_layer(BatchNormalization())
elif self.use_layer_norm:
self._build_layer(LayerNormalization())
elif self.use_weight_norm:
pass # done above.

with K.name_scope('act_and_dropout_{}'.format(k)):
self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
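The heart of the feature is the loop above: a residual block now stacks `residual_depth` dilated convolutions instead of a hard-coded two. Below is a rough, standalone sketch of what one residual branch expands to, ignoring the 1x1 shape-matching convolution; the `SpatialDropout1D` layer and its name are assumptions based on the `act_and_dropout` scope, since the dropout line is not visible in this hunk:

```python
# Standalone illustration of the sub-layers one residual branch builds under
# the new residual_depth loop. The dropout layer is an assumption; it is not
# shown in this hunk.
from tensorflow.keras.layers import Activation, Conv1D, SpatialDropout1D

def sketch_residual_branch(residual_depth, nb_filters, kernel_size,
                           dilation_rate, activation='relu', dropout_rate=0.0):
    layers = []
    for k in range(residual_depth):  # previously hard-coded as range(2)
        layers.append(Conv1D(filters=nb_filters,
                             kernel_size=kernel_size,
                             dilation_rate=dilation_rate,
                             padding='causal',
                             name='conv1D_{}'.format(k)))
        # BatchNormalization() or LayerNormalization() would be appended here
        layers.append(Activation(activation, name='Act_Conv1D_{}'.format(k)))
        layers.append(SpatialDropout1D(rate=dropout_rate, name='SDropout_{}'.format(k)))
    return layers

print([layer.name for layer in
       sketch_residual_branch(residual_depth=3, nb_filters=64,
                              kernel_size=3, dilation_rate=4)])
```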
@@ -191,6 +184,7 @@ class TCN(Layer):
nb_filters: The number of filters to use in the convolutional layers. Can be a list.
kernel_size: The size of the kernel to use in each convolutional layer.
dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
residual_depth: The number of convolutional layers stacked in each residual block. Default is 2.
nb_stacks : The number of stacks of residual blocks to use.
padding: The padding to use in the convolutional layers, 'causal' or 'same'.
use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
@@ -200,7 +194,6 @@
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
use_batch_norm: Whether to use batch normalization in the residual layers or not.
use_layer_norm: Whether to use layer normalization in the residual layers or not.
use_weight_norm: Whether to use weight normalization in the residual layers or not.
go_backwards: Boolean (default False). If True, process the input sequence backwards and
return the reversed sequence.
return_state: Boolean. Whether to return the last state in addition to the output. Default: False.
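To make the documented arguments concrete, here is a minimal constructor sketch assuming the package as modified by this PR; shapes and hyperparameter values are placeholders:

```python
# Minimal usage sketch of the updated TCN layer: use_weight_norm is gone and
# residual_depth (default 2) reproduces the previous two-convolution blocks.
# Shapes and hyperparameters are illustrative placeholders.
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense
from tcn import TCN

timesteps, input_dim = 100, 8
inp = Input(shape=(timesteps, input_dim))
x = TCN(nb_filters=64,
        kernel_size=3,
        dilations=(1, 2, 4, 8, 16, 32),
        residual_depth=2,        # new argument introduced by this PR
        use_skip_connections=True,
        use_layer_norm=True,     # at most one of use_batch_norm / use_layer_norm
        return_sequences=False)(inp)
out = Dense(1, activation='linear')(x)
model = Model(inp, out)
model.summary()
```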
@@ -213,8 +206,9 @@
def __init__(self,
nb_filters=64,
kernel_size=3,
nb_stacks=1,
dilations=(1, 2, 4, 8, 16, 32),
residual_depth=2,
nb_stacks=1,
padding='causal',
use_skip_connections=True,
dropout_rate=0.0,
Expand All @@ -223,7 +217,6 @@ def __init__(self,
kernel_initializer='he_normal',
use_batch_norm=False,
use_layer_norm=False,
use_weight_norm=False,
go_backwards=False,
return_state=False,
**kwargs):
@@ -232,6 +225,7 @@ def __init__(self,
self.dropout_rate = dropout_rate
self.use_skip_connections = use_skip_connections
self.dilations = dilations
self.residual_depth = residual_depth
self.nb_stacks = nb_stacks
self.kernel_size = kernel_size
self.nb_filters = nb_filters
@@ -240,7 +234,6 @@
self.kernel_initializer = kernel_initializer
self.use_batch_norm = use_batch_norm
self.use_layer_norm = use_layer_norm
self.use_weight_norm = use_weight_norm
self.go_backwards = go_backwards
self.return_state = return_state
self.skip_connections = []
@@ -251,7 +244,10 @@ def __init__(self,
self.output_slice_index = None # in case return_sequence=False
self.padding_same_and_time_dim_unknown = False # edge case if padding='same' and time_dim = None

if self.use_batch_norm + self.use_layer_norm + self.use_weight_norm > 1:
if self.residual_depth < 1:
raise ValueError('Residual depth must be at least 1.')

if self.use_batch_norm + self.use_layer_norm > 1:
raise ValueError('Only one normalization can be specified at once.')

if isinstance(self.nb_filters, list):
@@ -268,7 +264,7 @@

@property
def receptive_field(self):
return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
return 1 + self.residual_depth * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)

def tolist(self, shape):
try:
@@ -291,14 +287,14 @@ def build(self, input_shape):
for i, d in enumerate(self.dilations):
res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
self.residual_blocks.append(ResidualBlock(dilation_rate=d,
residual_depth=self.residual_depth,
nb_filters=res_block_filters,
kernel_size=self.kernel_size,
padding=self.padding,
activation=self.activation_name,
dropout_rate=self.dropout_rate,
use_batch_norm=self.use_batch_norm,
use_layer_norm=self.use_layer_norm,
use_weight_norm=self.use_weight_norm,
kernel_initializer=self.kernel_initializer,
name='residual_block_{}'.format(len(self.residual_blocks))))
# build newest residual block
@@ -355,7 +351,7 @@ def call(self, inputs, training=None, **kwargs):
self.skip_connections.append(skip_out)
self.layers_outputs.append(x)

if self.use_skip_connections:
if self.use_skip_connections and len(self.skip_connections) > 0:
if len(self.skip_connections) > 1:
# Keras: A merge layer should be called on a list of at least 2 inputs. Got 1 input.
x = layers.add(self.skip_connections, name='Add_Skip_Connections')
@@ -388,7 +384,6 @@ def get_config(self):
config['activation'] = self.activation_name
config['use_batch_norm'] = self.use_batch_norm
config['use_layer_norm'] = self.use_layer_norm
config['use_weight_norm'] = self.use_weight_norm
config['kernel_initializer'] = self.kernel_initializer
config['go_backwards'] = self.go_backwards
config['return_state'] = self.return_state
@@ -414,8 +409,7 @@ def compiled_tcn(num_feat, # type: int
opt='adam',
lr=0.002,
use_batch_norm=False,
use_layer_norm=False,
use_weight_norm=False):
use_layer_norm=False):
# type: (...) -> Model
"""Creates a compiled TCN model for a given task (i.e. regression or classification).
Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
@@ -440,7 +434,6 @@ def compiled_tcn(num_feat, # type: int
lr: Learning rate.
use_batch_norm: Whether to use batch normalization in the residual layers or not.
use_layer_norm: Whether to use layer normalization in the residual layers or not.
use_weight_norm: Whether to use weight normalization in the residual layers or not.
Returns:
A compiled keras TCN.
"""
@@ -451,8 +444,7 @@

x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
use_skip_connections, dropout_rate, return_sequences,
activation, kernel_initializer, use_batch_norm, use_layer_norm,
use_weight_norm, name=name)(input_layer)
activation, kernel_initializer, use_batch_norm, use_layer_norm, name=name)(input_layer)

print('x.shape=', x.shape)

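Because the constructor now places `residual_depth` between `dilations` and `nb_stacks`, keyword arguments are the safest way to build a `TCN` or call `compiled_tcn`. A usage sketch with placeholder data; the helper does not expose `residual_depth` in this diff, so the block depth stays at its default of 2:

```python
# Usage sketch for compiled_tcn after this PR. All data shapes, class counts
# and hyperparameters below are placeholders.
import numpy as np
from tcn import compiled_tcn

x_train = np.random.rand(256, 100, 4).astype('float32')  # (batch, time, features)
y_train = np.random.randint(0, 10, size=(256, 1))         # sparse class ids

model = compiled_tcn(return_sequences=False,
                     num_feat=x_train.shape[2],
                     num_classes=10,
                     nb_filters=24,
                     kernel_size=8,
                     dilations=[2 ** i for i in range(6)],
                     nb_stacks=1,
                     max_len=x_train.shape[1],
                     use_skip_connections=True,
                     use_layer_norm=True)
model.fit(x_train, y_train, epochs=1, batch_size=32)
```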