Add support for the bahdanau score (#66)
* Add support for the bahdanau score

* update TF versions on the tox file

* protobuf constraint for <v2.8

* protobuf constraint for <v2.8

* add pylint

* add pylint

* Remove E0401 from pylint

* Remove E0401 from pylint

* 5.0.0
philipperemy authored Mar 19, 2023
1 parent 482b0c9 commit 0600c95
Showing 6 changed files with 100 additions and 65 deletions.
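
In short: the Attention layer gains a score argument that selects between Luong's multiplicative score ('luong', the previous behaviour and still the default) and Bahdanau's additive score ('bahdanau'). A minimal usage sketch, modelled on examples/example-attention.py from this commit (the layer sizes and the random data are illustrative only):

import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, LSTM

from attention import Attention

# Dummy data, as in the example further down; there is nothing to learn here.
num_samples, time_steps, input_dim = 100, 10, 1
data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
data_y = np.random.uniform(size=(num_samples, 1))

model_input = Input(shape=(time_steps, input_dim))
x = LSTM(64, return_sequences=True)(model_input)
x = Attention(units=32, score='bahdanau')(x)  # new in 5.0.0; score='luong' is the default
x = Dense(1)(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')
model.fit(data_x, data_y, epochs=10)
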
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
@@ -1,4 +1,4 @@
name: Simple Keras Attention CI
name: Keras Attention Layer CI

on: [ push, pull_request ]

@@ -9,19 +9,19 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [ 3.8 ]
python-version: [ "3.8", "3.9", "3.10" ]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies and package
run: |
python -m pip install --upgrade pip
pip install flake8 pylint tox
- name: Static Analysis with Flake8
- name: Static Analysis
run: |
flake8 . --ignore E402 --count --max-complexity 10 --max-line-length 127 --select=E9,F63,F7,F82 --show-source --statistics
- name: Run example
97 changes: 65 additions & 32 deletions attention/attention.py
@@ -1,33 +1,64 @@
import os

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer, RepeatVector, Add

# KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
# In debug mode, the class Attention is no longer a Keras layer.
# What it means in practice is that we can have access to the internal values
# of each tensor. If we don't use debug, Keras treats the object
# as a layer and we can only get the final output.
# as a layer, and we can only get the final output.
debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))


# References:
# - https://arxiv.org/pdf/1508.04025.pdf (Luong).
# - https://arxiv.org/pdf/1409.0473.pdf (Bahdanau).
# - https://machinelearningmastery.com/the-bahdanau-attention-mechanism/ (Some more explanation).

class Attention(object if debug_flag else Layer):
SCORE_LUONG = 'luong'
SCORE_BAHDANAU = 'bahdanau'

def __init__(self, units=128, **kwargs):
def __init__(self, units: int = 128, score: str = 'luong', **kwargs):
super(Attention, self).__init__(**kwargs)
if score not in {self.SCORE_LUONG, self.SCORE_BAHDANAU}:
raise ValueError(f'Possible values for score are: [{self.SCORE_LUONG}] and [{self.SCORE_BAHDANAU}].')
self.units = units
self.score = score

# noinspection PyAttributeOutsideInit
def build(self, input_shape):
input_dim = int(input_shape[-1])
with K.name_scope(self.name if not debug_flag else 'attention'):
self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
# W in W*h_S.
if self.score == self.SCORE_LUONG:
self.luong_w = Dense(input_dim, use_bias=False, name='luong_w')
# dot : last hidden state H_t and every hidden state H_s.
self.luong_dot = Dot(axes=[1, 2], name='attention_score')
else:
# Dense implements the operation: output = activation(dot(input, kernel) + bias)
self.bahdanau_v = Dense(1, use_bias=False, name='bahdanau_v')
self.bahdanau_w1 = Dense(input_dim, use_bias=False, name='bahdanau_w1')
self.bahdanau_w2 = Dense(input_dim, use_bias=False, name='bahdanau_w2')
self.bahdanau_repeat = RepeatVector(input_shape[1])
self.bahdanau_tanh = Activation('tanh', name='bahdanau_tanh')
self.bahdanau_add = Add()

self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
self.attention_score = Dot(axes=[1, 2], name='attention_score')
self.attention_weight = Activation('softmax', name='attention_weight')
self.context_vector = Dot(axes=[1, 1], name='context_vector')
self.attention_output = Concatenate(name='attention_output')
self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')

# exp / sum(exp) -> softmax.
self.softmax_normalizer = Activation('softmax', name='attention_weight')

# dot : score * every hidden state H_s.
# dot product. SUM(v1*v2). H_s = every source hidden state.
self.dot_context = Dot(axes=[1, 1], name='context_vector')

# [Ct; ht]
self.concat_c_h = Concatenate(name='attention_output')

# x -> tanh(w_c(x))
self.w_c = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
if not debug_flag:
# debug: the call to build() is done in call().
super(Attention, self).build(input_shape)
@@ -44,35 +75,37 @@ def __call__(self, inputs, training=None, **kwargs):
# noinspection PyUnusedLocal
def call(self, inputs, training=None, **kwargs):
"""
Many-to-one attention mechanism for Keras.
Many-to-one attention mechanism for Keras. Supports:
- Luong's multiplicative style.
- Bahdanau's additive style.
@param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
@param training: not used in this layer.
@return: 2D tensor with shape (batch_size, units)
@author: felixhao28, philipperemy.
@author: philipperemy, felixhao28.
"""
h_s = inputs
if debug_flag:
self.build(inputs.shape)
# Inside dense layer
# hidden_states dot W => score_first_part
# (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
# W is the trainable weight matrix of attention Luong's multiplicative style score
score_first_part = self.attention_score_vec(inputs)
# score_first_part dot last_hidden_state => attention_weights
# (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
h_t = self.h_t(inputs)
score = self.attention_score([h_t, score_first_part])
attention_weights = self.attention_weight(score)
# (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
context_vector = self.context_vector([inputs, attention_weights])
pre_activation = self.attention_output([context_vector, h_t])
attention_vector = self.attention_vector(pre_activation)
return attention_vector
self.build(h_s.shape)
h_t = self.h_t(h_s)
if self.score == self.SCORE_LUONG:
# Luong's multiplicative style.
score = self.luong_dot([h_t, self.luong_w(h_s)])
else:
# Bahdanau's additive style.
self.bahdanau_w1(h_s)
a1 = self.bahdanau_w1(h_t)
a2 = self.bahdanau_w2(h_s)
a1 = self.bahdanau_repeat(a1)
score = self.bahdanau_tanh(self.bahdanau_add([a1, a2]))
score = self.bahdanau_v(score)
score = K.squeeze(score, axis=-1)

alpha_s = self.softmax_normalizer(score)
context_vector = self.dot_context([h_s, alpha_s])
a_t = self.w_c(self.concat_c_h([context_vector, h_t]))
return a_t

def get_config(self):
"""
Returns the config of the layer. This is used for saving and loading from a model
:return: python dictionary with specs to rebuild layer
"""
config = super(Attention, self).get_config()
config.update({'units': self.units})
config.update({'units': self.units, 'score': self.score})
return config
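
For reference, the two score functions wired up above follow the papers cited at the top of the file. Writing h_t for the last hidden state and \bar{h}_s for the source hidden states (the notation and weight symbols below are mine, chosen to mirror the layer names in the diff):

    Luong (multiplicative):  \mathrm{score}(h_t, \bar{h}_s) = h_t^\top W \bar{h}_s
    Bahdanau (additive):     \mathrm{score}(h_t, \bar{h}_s) = v^\top \tanh(W_1 h_t + W_2 \bar{h}_s)

Both branches then share the same tail: \alpha_s = \mathrm{softmax}(\mathrm{score}), the context vector c_t = \sum_s \alpha_s \bar{h}_s, and the output a_t = \tanh(W_c\,[c_t; h_t]), which is what call() returns with shape (batch_size, units).
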
16 changes: 4 additions & 12 deletions examples/add_two_numbers.py
@@ -10,14 +10,14 @@
from tensorflow.keras import Input
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.models import Model
from tensorflow.python.keras.utils.vis_utils import plot_model

# KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
# In debug mode, the class Attention is no longer a Keras layer.
# What it means in practice is that we can have access to the internal values
# of each tensor. If we don't use debug, Keras treats the object
# as a layer and we can only get the final output.
# as a layer, and we can only get the final output.

# In this example we need it because we want to extract all the intermediate output values.
os.environ['KERAS_ATTENTION_DEBUG'] = '1'
@@ -74,15 +74,15 @@ def main():
# Define/compile the model.
model_input = Input(shape=(seq_length, 1))
x = LSTM(100, return_sequences=True)(model_input)
x = Attention()(x)
x = Attention(128, score='bahdanau')(x)
x = Dropout(0.2)(x)
x = Dense(1, activation='linear')(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')

# Visualize the model.
model.summary()
plot_model(model)
plot_model(model, dpi=200, show_dtype=True, show_shapes=True, show_layer_names=True)

# Will display the activation map in task_add_two_numbers/
output_dir = Path('task_add_two_numbers')
@@ -110,14 +110,6 @@ def on_epoch_end(self, epoch, logs=None):
callbacks=[VisualiseAttentionMap()]
)

# test save/reload model.
pred1 = model.predict(x_val)
model.save('test_model.h5')
model_h5 = load_model('test_model.h5')
pred2 = model_h5.predict(x_val)
np.testing.assert_almost_equal(pred1, pred2)
print('Success.')


if __name__ == '__main__':
# pip install pydot
20 changes: 12 additions & 8 deletions examples/example-attention.py
@@ -6,23 +6,18 @@
from attention import Attention


def main():
# Dummy data. There is nothing to learn in this example.
num_samples, time_steps, input_dim, output_dim = 100, 10, 1, 1
data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
data_y = np.random.uniform(size=(num_samples, output_dim))

def run_test(data_x, data_y, time_steps, input_dim, score):
# Define/compile the model.
model_input = Input(shape=(time_steps, input_dim))
x = LSTM(64, return_sequences=True)(model_input)
x = Attention(units=32)(x)
x = Attention(units=32, score=score)(x)
x = Dense(1)(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')
model.summary()

# train.
model.fit(data_x, data_y, epochs=10)
model.fit(data_x, data_y, epochs=30)

# test save/reload model.
pred1 = model.predict(data_x)
@@ -33,5 +28,14 @@ def main():
print('Success.')


def main():
# Dummy data. There is nothing to learn in this example.
num_samples, time_steps, input_dim, output_dim = 100, 10, 1, 1
data_x = np.random.uniform(size=(num_samples, time_steps, input_dim))
data_y = np.random.uniform(size=(num_samples, output_dim))
run_test(data_x, data_y, time_steps, input_dim, score='luong')
run_test(data_x, data_y, time_steps, input_dim, score='bahdanau')


if __name__ == '__main__':
main()
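
Because get_config() now serializes 'score' alongside 'units' (see attention/attention.py above), a saved model keeps its scoring style across a save/reload round trip. A hedged sketch, assuming the model was saved to test_model.h5 as in the example; depending on the TF/Keras version, the custom layer may need to be passed explicitly:

from tensorflow.keras.models import load_model
from attention import Attention

# custom_objects maps the layer name stored in the file back to the class.
reloaded = load_model('test_model.h5', custom_objects={'Attention': Attention})
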
4 changes: 2 additions & 2 deletions setup.py
@@ -2,8 +2,8 @@

setup(
name='attention',
version='4.1',
description='Keras Simple Attention',
version='5.0.0',
description='Keras Attention Layer',
author='Philippe Remy',
license='Apache 2.0',
long_description_content_type='text/markdown',
18 changes: 12 additions & 6 deletions tox.ini
@@ -1,12 +1,18 @@
[tox]
envlist = {py3}-tensorflow-{2.5.0,2.6.2,2.7.0,v2.8.0-rc0}
envlist = {py3}-tensorflow-{2.8,2.9,2.10,2.11}

[testenv]
setenv =
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
deps = -rexamples/examples-requirements.txt
tensorflow-2.5.0: tensorflow==2.5.0
tensorflow-2.6.2: tensorflow==2.6.2
tensorflow-2.7.0: tensorflow==2.7.0
tensorflow-v2.8.0-rc0: tensorflow==v2.8.0-rc0
flake8
pylint
tensorflow-2.8: tensorflow==2.8
tensorflow-2.9: tensorflow==2.9
tensorflow-2.10: tensorflow==2.10
tensorflow-2.11: tensorflow==2.11
changedir = examples
commands = python example-attention.py
commands = pylint --disable=R,C,W,E1136,E0401 ../attention
python example-attention.py
passenv = *
install_command = pip install {packages}
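
With this matrix, running tox with no arguments should exercise all four environments; a single one can be selected with, for example, tox -e py3-tensorflow-2.11 (the generated env names combine the py3 factor with the TensorFlow factor).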
