diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b95ba7e..2e0f36d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,4 +1,4 @@
-name: Simple Keras Attention CI
+name: Keras Attention Layer CI
 
 on: [ push, pull_request ]
 
@@ -9,19 +9,19 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: [ 3.8 ]
+        python-version: [ "3.8", "3.9", "3.10" ]
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies and package
         run: |
           python -m pip install --upgrade pip
           pip install flake8 pylint tox
-      - name: Static Analysis with Flake8
+      - name: Static Analysis
         run: |
           flake8 . --ignore E402 --count --max-complexity 10 --max-line-length 127 --select=E9,F63,F7,F82 --show-source --statistics
       - name: Run example
diff --git a/attention/attention.py b/attention/attention.py
index 50b9451..099d3a4 100644
--- a/attention/attention.py
+++ b/attention/attention.py
@@ -1,33 +1,64 @@
 import os
 
 from tensorflow.keras import backend as K
-from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer
+from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer, RepeatVector, Add
 
 # KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
 # In debug mode, the class Attention is no longer a Keras layer.
 # What it means in practice is that we can have access to the internal values
 # of each tensor. If we don't use debug, Keras treats the object
-# as a layer and we can only get the final output.
+# as a layer, and we can only get the final output.
 debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))
 
+# References:
+# - https://arxiv.org/pdf/1508.04025.pdf (Luong).
+# - https://arxiv.org/pdf/1409.0473.pdf (Bahdanau).
+# - https://machinelearningmastery.com/the-bahdanau-attention-mechanism/ (Some more explanation).
+
 
 class Attention(object if debug_flag else Layer):
+    SCORE_LUONG = 'luong'
+    SCORE_BAHDANAU = 'bahdanau'
 
-    def __init__(self, units=128, **kwargs):
+    def __init__(self, units: int = 128, score: str = 'luong', **kwargs):
         super(Attention, self).__init__(**kwargs)
+        if score not in {self.SCORE_LUONG, self.SCORE_BAHDANAU}:
+            raise ValueError(f'Possible values for score are: [{self.SCORE_LUONG}] and [{self.SCORE_BAHDANAU}].')
         self.units = units
+        self.score = score
 
     # noinspection PyAttributeOutsideInit
     def build(self, input_shape):
        input_dim = int(input_shape[-1])
         with K.name_scope(self.name if not debug_flag else 'attention'):
-            self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
+            # W in W*h_s (Luong).
+            if self.score == self.SCORE_LUONG:
+                self.luong_w = Dense(input_dim, use_bias=False, name='luong_w')
+                # dot: last hidden state h_t with every source hidden state h_s.
+                self.luong_dot = Dot(axes=[1, 2], name='attention_score')
+            else:
+                # Dense implements the operation: output = activation(dot(input, kernel) + bias).
+                self.bahdanau_v = Dense(1, use_bias=False, name='bahdanau_v')
+                self.bahdanau_w1 = Dense(input_dim, use_bias=False, name='bahdanau_w1')
+                self.bahdanau_w2 = Dense(input_dim, use_bias=False, name='bahdanau_w2')
+                self.bahdanau_repeat = RepeatVector(input_shape[1])
+                self.bahdanau_tanh = Activation('tanh', name='bahdanau_tanh')
+                self.bahdanau_add = Add()
+
             self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
-            self.attention_score = Dot(axes=[1, 2], name='attention_score')
-            self.attention_weight = Activation('softmax', name='attention_weight')
-            self.context_vector = Dot(axes=[1, 1], name='context_vector')
-            self.attention_output = Concatenate(name='attention_output')
-            self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
+
+            # exp / sum(exp) -> softmax.
+            self.softmax_normalizer = Activation('softmax', name='attention_weight')
+
+            # dot: attention weights * every source hidden state h_s.
+            # dot product. SUM(v1*v2). h_s = every source hidden state.
+            self.dot_context = Dot(axes=[1, 1], name='context_vector')
+
+            # [c_t; h_t]
+            self.concat_c_h = Concatenate(name='attention_output')
+
+            # x -> tanh(W_c(x))
+            self.w_c = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
         if not debug_flag:
             # debug: the call to build() is done in call().
             super(Attention, self).build(input_shape)
@@ -44,35 +75,37 @@ def __call__(self, inputs, training=None, **kwargs):
     # noinspection PyUnusedLocal
     def call(self, inputs, training=None, **kwargs):
         """
-        Many-to-one attention mechanism for Keras.
+        Many-to-one attention mechanism for Keras. Supports:
+            - Luong's multiplicative style.
+            - Bahdanau's additive style.
         @param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
         @param training: not used in this layer.
         @return: 2D tensor with shape (batch_size, units)
-        @author: felixhao28, philipperemy.
+        @author: philipperemy, felixhao28.
         """
+        h_s = inputs
         if debug_flag:
-            self.build(inputs.shape)
-        # Inside dense layer
-        # hidden_states dot W => score_first_part
-        # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
-        # W is the trainable weight matrix of attention Luong's multiplicative style score
-        score_first_part = self.attention_score_vec(inputs)
-        # score_first_part dot last_hidden_state => attention_weights
-        # (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
-        h_t = self.h_t(inputs)
-        score = self.attention_score([h_t, score_first_part])
-        attention_weights = self.attention_weight(score)
-        # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
-        context_vector = self.context_vector([inputs, attention_weights])
-        pre_activation = self.attention_output([context_vector, h_t])
-        attention_vector = self.attention_vector(pre_activation)
-        return attention_vector
+            self.build(h_s.shape)
+        h_t = self.h_t(h_s)
+        if self.score == self.SCORE_LUONG:
+            # Luong's multiplicative style.
+            score = self.luong_dot([h_t, self.luong_w(h_s)])
+        else:
+            # Bahdanau's additive style.
+            self.bahdanau_w1(h_s)
+            a1 = self.bahdanau_w1(h_t)
+            a2 = self.bahdanau_w2(h_s)
+            a1 = self.bahdanau_repeat(a1)
+            score = self.bahdanau_tanh(self.bahdanau_add([a1, a2]))
+            score = self.bahdanau_v(score)
+            score = K.squeeze(score, axis=-1)
+
+        alpha_s = self.softmax_normalizer(score)
+        context_vector = self.dot_context([h_s, alpha_s])
+        a_t = self.w_c(self.concat_c_h([context_vector, h_t]))
+        return a_t
 
     def get_config(self):
-        """
-        Returns the config of a the layer. This is used for saving and loading from a model
-        :return: python dictionary with specs to rebuild layer
-        """
         config = super(Attention, self).get_config()
-        config.update({'units': self.units})
+        config.update({'units': self.units, 'score': self.score})
        return config
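
For reference, the two score branches wired up above compute the following, shown here as a sketch in plain TensorFlow ops (not part of the patch; the weight tensors, shapes, and variable names are illustrative stand-ins for the Dense kernels the layer creates in build()):

```python
import tensorflow as tf

# Illustrative shapes only: batch=2, time_steps=5, input_dim=4.
h_s = tf.random.normal((2, 5, 4))        # every source hidden state
h_t = h_s[:, -1, :]                      # last hidden state, shape (2, 4)

# Luong (multiplicative): score_s = h_t . (W h_s)
w = tf.random.normal((4, 4))             # stand-in for the luong_w Dense kernel
luong_score = tf.einsum('bi,bti->bt', h_t, tf.einsum('btj,ji->bti', h_s, w))

# Bahdanau (additive): score_s = v . tanh(W1 h_t + W2 h_s)
w1, w2, v = tf.random.normal((4, 4)), tf.random.normal((4, 4)), tf.random.normal((4, 1))
additive = tf.tanh(tf.expand_dims(h_t @ w1, axis=1) + tf.einsum('btj,ji->bti', h_s, w2))
bahdanau_score = tf.squeeze(tf.einsum('bti,ij->btj', additive, v), axis=-1)  # shape (2, 5)

# Shared tail of call(): softmax over time, context vector, then tanh(W_c [context; h_t]).
alpha_s = tf.nn.softmax(luong_score, axis=-1)      # attention weights, shape (2, 5)
context = tf.einsum('bt,bti->bi', alpha_s, h_s)    # context vector, shape (2, 4)
```

The layer expresses the same operations with Keras layers: Dot(axes=[1, 2]) plays the role of the first einsum pair, Dot(axes=[1, 1]) builds the context vector, and the final attention vector is Dense(units, activation='tanh') applied to the concatenation [context; h_t].
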
diff --git a/examples/add_two_numbers.py b/examples/add_two_numbers.py
index 3c14623..466c2b3 100644
--- a/examples/add_two_numbers.py
+++ b/examples/add_two_numbers.py
@@ -10,14 +10,14 @@
 from tensorflow.keras import Input
 from tensorflow.keras.callbacks import Callback
 from tensorflow.keras.layers import Dense, Dropout, LSTM
-from tensorflow.keras.models import load_model, Model
+from tensorflow.keras.models import Model
 from tensorflow.python.keras.utils.vis_utils import plot_model
 
 # KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
 # In debug mode, the class Attention is no longer a Keras layer.
 # What it means in practice is that we can have access to the internal values
 # of each tensor. If we don't use debug, Keras treats the object
-# as a layer and we can only get the final output.
+# as a layer, and we can only get the final output.
 # In this example we need it because we want to extract all the intermediate output values.
 os.environ['KERAS_ATTENTION_DEBUG'] = '1'
 
@@ -74,7 +74,7 @@ def main():
     # Define/compile the model.
     model_input = Input(shape=(seq_length, 1))
     x = LSTM(100, return_sequences=True)(model_input)
-    x = Attention()(x)
+    x = Attention(128, score='bahdanau')(x)
     x = Dropout(0.2)(x)
     x = Dense(1, activation='linear')(x)
     model = Model(model_input, x)
@@ -82,7 +82,7 @@ def main():
 
     # Visualize the model.
     model.summary()
-    plot_model(model)
+    plot_model(model, dpi=200, show_dtype=True, show_shapes=True, show_layer_names=True)
 
     # Will display the activation map in task_add_two_numbers/
     output_dir = Path('task_add_two_numbers')
@@ -110,14 +110,6 @@ def on_epoch_end(self, epoch, logs=None):
         callbacks=[VisualiseAttentionMap()]
     )
 
-    # test save/reload model.
-    pred1 = model.predict(x_val)
-    model.save('test_model.h5')
-    model_h5 = load_model('test_model.h5')
-    pred2 = model_h5.predict(x_val)
-    np.testing.assert_almost_equal(pred1, pred2)
-    print('Success.')
-
 
 if __name__ == '__main__':
     # pip install pydot
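
Because this example relies on KERAS_ATTENTION_DEBUG=1, a short standalone sketch of what the flag changes may help (not part of the patch; shapes and unit counts are arbitrary). With the flag set before the import, Attention is a plain Python object instead of a Keras layer, so the intermediate tensors inside call() can be inspected directly:

```python
import os
os.environ['KERAS_ATTENTION_DEBUG'] = '1'   # must be set before `attention` is imported
import tensorflow as tf
from attention import Attention

h_s = tf.random.normal((2, 5, 4))           # (batch_size, time_steps, input_dim)
a_t = Attention(units=8, score='luong')(h_s)
print(a_t.shape)                            # expected: (2, 8), one attention vector per sample
```

In this mode, stepping through call() with a debugger exposes the score, the softmaxed attention weights, and the context vector as eager tensors.
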
diff --git a/examples/example-attention.py b/examples/example-attention.py
index 2e65c44..cc3e034 100644
--- a/examples/example-attention.py
+++ b/examples/example-attention.py
@@ -6,23 +6,18 @@
 from attention import Attention
 
 
-def main():
-    # Dummy data. There is nothing to learn in this example.
-    num_samples, time_steps, input_dim, output_dim = 100, 10, 1, 1
-    data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
-    data_y = np.random.uniform(size=(num_samples, output_dim))
-
+def run_test(data_x, data_y, time_steps, input_dim, score):
     # Define/compile the model.
     model_input = Input(shape=(time_steps, input_dim))
     x = LSTM(64, return_sequences=True)(model_input)
-    x = Attention(units=32)(x)
+    x = Attention(units=32, score=score)(x)
     x = Dense(1)(x)
     model = Model(model_input, x)
     model.compile(loss='mae', optimizer='adam')
     model.summary()
 
     # train.
-    model.fit(data_x, data_y, epochs=10)
+    model.fit(data_x, data_y, epochs=30)
 
     # test save/reload model.
     pred1 = model.predict(data_x)
@@ -33,5 +28,14 @@ def main():
     print('Success.')
 
 
+def main():
+    # Dummy data. There is nothing to learn in this example.
+    num_samples, time_steps, input_dim, output_dim = 100, 10, 1, 1
+    data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
+    data_y = np.random.uniform(size=(num_samples, output_dim))
+    run_test(data_x, data_y, time_steps, input_dim, score='luong')
+    run_test(data_x, data_y, time_steps, input_dim, score='bahdanau')
+
+
 if __name__ == '__main__':
     main()
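
Since get_config() now also serializes 'score', reloading a saved model that contains this layer outside of this script requires passing the class as a custom object. A minimal round-trip sketch (not part of the patch; the file name, unit counts, and shapes are illustrative, and KERAS_ATTENTION_DEBUG is assumed unset so Attention is a real Keras layer):

```python
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Model, load_model
from attention import Attention

# Tiny model mirroring run_test() above.
model_input = Input(shape=(10, 1))
x = LSTM(16, return_sequences=True)(model_input)
x = Attention(units=8, score='bahdanau')(x)
x = Dense(1)(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')

model.save('attention_roundtrip.h5')        # illustrative file name
reloaded = load_model('attention_roundtrip.h5', custom_objects={'Attention': Attention})

# The score mode survives the round trip via get_config()/from_config().
data_x = np.random.uniform(size=(4, 10, 1))
np.testing.assert_almost_equal(model.predict(data_x), reloaded.predict(data_x))
```
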
diff --git a/setup.py b/setup.py
index d74b1a7..33879cc 100644
--- a/setup.py
+++ b/setup.py
@@ -2,8 +2,8 @@
 
 setup(
     name='attention',
-    version='4.1',
-    description='Keras Simple Attention',
+    version='5.0.0',
+    description='Keras Attention Layer',
     author='Philippe Remy',
     license='Apache 2.0',
     long_description_content_type='text/markdown',
diff --git a/tox.ini b/tox.ini
index 5d8a8a3..f8f59fd 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,12 +1,18 @@
 [tox]
-envlist = {py3}-tensorflow-{2.5.0,2.6.2,2.7.0,v2.8.0-rc0}
+envlist = {py3}-tensorflow-{2.8,2.9,2.10,2.11}
 [testenv]
+setenv =
+    PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 deps =
     -rexamples/examples-requirements.txt
-    tensorflow-2.5.0: tensorflow==2.5.0
-    tensorflow-2.6.2: tensorflow==2.6.2
-    tensorflow-2.7.0: tensorflow==2.7.0
-    tensorflow-v2.8.0-rc0: tensorflow==v2.8.0-rc0
+    flake8
+    pylint
+    tensorflow-2.8: tensorflow==2.8
+    tensorflow-2.9: tensorflow==2.9
+    tensorflow-2.10: tensorflow==2.10
+    tensorflow-2.11: tensorflow==2.11
 changedir = examples
-commands = python example-attention.py
+commands = pylint --disable=R,C,W,E1136,E0401 ../attention
+           python example-attention.py
+passenv = *
 install_command = pip install {packages}