Update lib + misc (#63)
* update attention

* update attention

* small improvements

* CI

* CI

* 4.1

* update attention

* update attention

* update attention

* update attention
philipperemy authored Jan 23, 2022
1 parent 0f8b440 commit 08095bf
Showing 9 changed files with 154 additions and 59 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/ci.yml
@@ -9,7 +9,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [ 3.7 ]
python-version: [ 3.8 ]

steps:
- uses: actions/checkout@v2
@@ -20,13 +20,10 @@ jobs:
- name: Install dependencies and package
run: |
python -m pip install --upgrade pip
pip install flake8
pip install -r examples/examples-requirements.txt
pip install -e .
- name: Lint with flake8
pip install flake8 pylint tox
- name: Static Analysis with Flake8
run: |
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . --count --max-complexity 10 --max-line-length 127 --statistics
flake8 . --ignore E402 --count --max-complexity 10 --max-line-length 127 --select=E9,F63,F7,F82 --show-source --statistics
- name: Run example
run: |
cd examples && python example-attention.py 10
tox
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.png
77 changes: 57 additions & 20 deletions attention/attention.py
@@ -1,41 +1,78 @@
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate
from tensorflow.keras.layers import Layer
import os

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer

class Attention(Layer):
# KERAS_ATTENTION_DEBUG: if set to 1, switches to debug mode.
# In debug mode, the class Attention is no longer a Keras layer, which
# in practice means we can access the internal values of each tensor.
# Without debug mode, Keras treats the object as a layer and we can only
# get the final output.
debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))


class Attention(object if debug_flag else Layer):

def __init__(self, units=128, **kwargs):
super(Attention, self).__init__(**kwargs)
self.units = units
super().__init__(**kwargs)

def __call__(self, inputs):
# noinspection PyAttributeOutsideInit
def build(self, input_shape):
input_dim = int(input_shape[-1])
with K.name_scope(self.name if not debug_flag else 'attention'):
self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
self.attention_score = Dot(axes=[1, 2], name='attention_score')
self.attention_weight = Activation('softmax', name='attention_weight')
self.context_vector = Dot(axes=[1, 1], name='context_vector')
self.attention_output = Concatenate(name='attention_output')
self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
if not debug_flag:
# in debug mode, build() is called from call() instead.
super(Attention, self).build(input_shape)

def compute_output_shape(self, input_shape):
return input_shape[0], self.units

def __call__(self, inputs, training=None, **kwargs):
if debug_flag:
return self.call(inputs, training, **kwargs)
else:
return super(Attention, self).__call__(inputs, training, **kwargs)

# noinspection PyUnusedLocal
def call(self, inputs, training=None, **kwargs):
"""
Many-to-one attention mechanism for Keras.
@param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
@return: 2D tensor with shape (batch_size, 128)
@param training: not used in this layer.
@return: 2D tensor with shape (batch_size, units)
@author: felixhao28, philipperemy.
"""
hidden_states = inputs
hidden_size = int(hidden_states.shape[2])
if debug_flag:
self.build(inputs.shape)
# Inside dense layer
# hidden_states dot W => score_first_part
# (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
# W is the trainable weight matrix of attention Luong's multiplicative style score
score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states)
score_first_part = self.attention_score_vec(inputs)
# score_first_part dot last_hidden_state => attention_weights
# (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states)
score = Dot(axes=[1, 2], name='attention_score')([h_t, score_first_part])
attention_weights = Activation('softmax', name='attention_weight')(score)
h_t = self.h_t(inputs)
score = self.attention_score([h_t, score_first_part])
attention_weights = self.attention_weight(score)
# (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
context_vector = Dot(axes=[1, 1], name='context_vector')([hidden_states, attention_weights])
pre_activation = Concatenate(name='attention_output')([context_vector, h_t])
attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')(pre_activation)
context_vector = self.context_vector([inputs, attention_weights])
pre_activation = self.attention_output([context_vector, h_t])
attention_vector = self.attention_vector(pre_activation)
return attention_vector

def get_config(self):
return {'units': self.units}

@classmethod
def from_config(cls, config):
return cls(**config)
"""
Returns the config of the layer. This is used for saving and loading a model that contains this layer.
:return: a Python dictionary with the specs needed to rebuild the layer.
"""
config = super(Attention, self).get_config()
config.update({'units': self.units})
return config
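
For readers skimming the diff, here is a minimal usage sketch of the updated layer (illustrative only, not part of the commit; sizes mirror examples/example-attention.py). The layer takes a 3D tensor of shape (batch_size, time_steps, input_dim), typically the return_sequences=True output of an LSTM, and returns a 2D tensor of shape (batch_size, units). With the updated get_config(), a saved model reloads through load_model(..., custom_objects={'Attention': Attention}), as the example-attention.py change further down also shows.

# Illustrative sketch, not part of this commit; shapes and hyperparameters are arbitrary.
import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import load_model
from attention import Attention

time_steps, input_dim = 20, 1
model_input = Input(shape=(time_steps, input_dim))
x = LSTM(64, return_sequences=True)(model_input)  # (batch_size, time_steps, 64)
x = Attention(units=32)(x)                        # (batch_size, 32)
x = Dense(1)(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')

model.fit(np.random.uniform(size=(256, time_steps, input_dim)),
          np.random.uniform(size=(256, 1)), epochs=1, verbose=0)

# get_config() now delegates to the base Layer, so reloading only needs custom_objects.
model.save('attention_model.h5')
reloaded = load_model('attention_model.h5', custom_objects={'Attention': Attention})
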
33 changes: 25 additions & 8 deletions examples/add_two_numbers.py
@@ -1,3 +1,4 @@
import os
import shutil
import sys
from pathlib import Path
@@ -12,11 +13,21 @@
from tensorflow.keras.models import load_model, Model
from tensorflow.python.keras.utils.vis_utils import plot_model

# KERAS_ATTENTION_DEBUG: if set to 1, switches to debug mode.
# In debug mode, the class Attention is no longer a Keras layer, which
# in practice means we can access the internal values of each tensor.
# Without debug mode, Keras treats the object as a layer and we can only
# get the final output.

# In this example we need it because we want to extract all the intermediate output values.
os.environ['KERAS_ATTENTION_DEBUG'] = '1'
from attention import Attention


def task_add_two_numbers_after_delimiter(n: int, seq_length: int, delimiter: float = 0.0,
index_1: int = None, index_2: int = None) -> (np.array, np.array):
def task_add_two_numbers_after_delimiter(
n: int, seq_length: int, delimiter: float = 0.0,
index_1: int = None, index_2: int = None
) -> (np.array, np.array):
"""
Task: Add the two numbers that come right after the delimiter.
x = [1, 2, 3, 0, 4, 5, 6, 0, 7, 8]. Result is y = 4 + 7 = 11.
@@ -43,7 +54,7 @@ def task_add_two_numbers_after_delimiter(n: int, seq_length: int, delimiter: flo

def main():
numpy.random.seed(7)
max_epoch = int(sys.argv[1]) if len(sys.argv) > 1 else 100
max_epoch = int(sys.argv[1]) if len(sys.argv) > 1 else 150

# data. definition of the problem.
seq_length = 20
@@ -70,7 +81,7 @@ def main():
model.compile(loss='mae', optimizer='adam')

# Visualize the model.
print(model.summary())
model.summary()
plot_model(model)

# Will display the activation map in task_add_two_numbers/
@@ -87,13 +98,17 @@ def on_epoch_end(self, epoch, logs=None):
iteration_no = str(epoch).zfill(3)
plt.axis('off')
plt.title(f'Iteration {iteration_no} / {max_epoch}')
plt.savefig(f'{output_dir}/epoch_{iteration_no}.png')
output_filename = f'{output_dir}/epoch_{iteration_no}.png'
print(f'Saving to {output_filename}.')
plt.savefig(output_filename)
plt.close()

# train.
model.fit(x_train, y_train, validation_data=(x_val, y_val),
epochs=max_epoch, verbose=2, batch_size=64,
callbacks=[VisualiseAttentionMap()])
model.fit(
x_train, y_train, validation_data=(x_val, y_val),
epochs=max_epoch, verbose=2, batch_size=64,
callbacks=[VisualiseAttentionMap()]
)

# test save/reload model.
pred1 = model.predict(x_val)
@@ -105,4 +120,6 @@ def on_epoch_end(self, epoch, logs=None):


if __name__ == '__main__':
# pip install pydot
# pip install keract
main()
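
A note on the debug switch introduced above, with a minimal sketch (illustrative, not part of the commit): attention/attention.py reads KERAS_ATTENTION_DEBUG once, at module import time, so the environment variable must be set before the import, exactly as add_two_numbers.py does.

# Minimal sketch of the debug-mode pattern; the ordering is what matters here.
import os

os.environ['KERAS_ATTENTION_DEBUG'] = '1'  # must be set BEFORE the import below,
from attention import Attention            # because the module reads the flag at import time.

# In debug mode, Attention is a plain Python object rather than a Keras Layer, so its
# internal layers (e.g. 'attention_weight') end up directly in the model graph and can be
# inspected with keract.get_activations(model, x, layer_names='attention_weight').

The trade-off, noted in find_max.py below, is that a debug-mode Attention cannot sit inside a Sequential model; the functional API has to be used instead.
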
6 changes: 3 additions & 3 deletions examples/example-attention.py
@@ -15,19 +15,19 @@ def main():
# Define/compile the model.
model_input = Input(shape=(time_steps, input_dim))
x = LSTM(64, return_sequences=True)(model_input)
x = Attention(32)(x)
x = Attention(units=32)(x)
x = Dense(1)(x)
model = Model(model_input, x)
model.compile(loss='mae', optimizer='adam')
print(model.summary())
model.summary()

# train.
model.fit(data_x, data_y, epochs=10)

# test save/reload model.
pred1 = model.predict(data_x)
model.save('test_model.h5')
model_h5 = load_model('test_model.h5')
model_h5 = load_model('test_model.h5', custom_objects={'Attention': Attention})
pred2 = model_h5.predict(data_x)
np.testing.assert_almost_equal(pred1, pred2)
print('Success.')
44 changes: 34 additions & 10 deletions examples/find_max.py
@@ -1,12 +1,19 @@
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from keract import get_activations
from tensorflow.keras import Sequential
from keras import Input, Model
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense, LSTM

# In this example we need it because we want to extract all the intermediate output values.
os.environ['KERAS_ATTENTION_DEBUG'] = '1'
from attention import Attention

matplotlib.rcParams.update({'font.size': 8})


class VisualizeAttentionMap(Callback):

@@ -18,9 +25,10 @@ def __init__(self, model, x):
def on_epoch_begin(self, epoch, logs=None):
attention_map = get_activations(self.model, self.x, layer_names='attention_weight')['attention_weight']
x = self.x[..., 0]
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(5, 6))
plt.close()
fig, axes = plt.subplots(nrows=3, figsize=(10, 8))
maps = [attention_map, create_argmax_mask(attention_map), create_argmax_mask(x)]
maps_names = ['attention layer', 'attention layer - argmax()', 'ground truth - argmax()']
maps_names = ['attention layer (continuous)', 'attention layer - argmax (discrete)', 'ground truth (discrete)']
for i, ax in enumerate(axes.flat):
im = ax.imshow(maps[i], interpolation='none', cmap='jet')
ax.set_ylabel(maps_names[i] + '\n#sample axis')
@@ -29,8 +37,13 @@ def on_epoch_begin(self, epoch, logs=None):
ax.yaxis.set_ticks([])
cbar_ax = fig.add_axes([0.75, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
fig.suptitle(f'Epoch {epoch} - training')
plt.show()
fig.suptitle(f'Epoch {epoch} - training\nEach plot shows a 2-D matrix (x-axis: sequence length, '
f'y-axis: batch/sample axis).\nThe first matrix contains the attention weights (softmax).'
f'\nWe manually apply argmax on the attention weights to see which time step has '
f'the strongest weight.\nFinally, the last matrix displays the ground truth. The task '
f'is solved when the second and third matrices match.')
plt.draw()
plt.pause(0.001)


def create_argmax_mask(x):
@@ -48,11 +61,22 @@ def main():
# If all the max(s) are concentrated around 1, then it makes the task easy for the model.
x_data = np.random.beta(a=1 / seq_length, b=1, size=(num_samples, seq_length, 1))
y_data = np.max(x_data, axis=1)
model = Sequential([
LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
Attention(),
Dense(1, activation='linear')
])

# NOTE: Sequential can't be used in debug mode (KERAS_ATTENTION_DEBUG=1),
# because the Attention layer is then no longer a Keras layer.
# That's the trick that lets us "see" the internal outputs of each tensor in the attention module.
# In practice, you can use Sequential when debug mode is disabled ;)
# model = Sequential([
# LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
# Attention(),
# Dense(1, activation='linear')
# ])
model_input = Input(shape=(seq_length, 1))
x = LSTM(128, return_sequences=True)(model_input)
x = Attention()(x)
x = Dense(1, activation='linear')(x)
model = Model(model_input, x)

model.compile(loss='mae')
max_epoch = 100
# visualize the attention on the first samples.
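
To make the comparison between the second and third panels concrete, here is a short sketch (illustrative, not part of the commit; attention_argmax_accuracy is a hypothetical helper) that measures how often the attention peak coincides with the true position of the maximum, assuming a model and x_data built as in find_max.py with debug mode enabled.

import numpy as np
from keract import get_activations

def attention_argmax_accuracy(model, x_data):
    """Fraction of samples where the attention peak matches the true max position."""
    weights = get_activations(model, x_data, layer_names='attention_weight')['attention_weight']
    predicted_pos = np.argmax(weights, axis=1)    # time step with the largest attention weight
    true_pos = np.argmax(x_data[..., 0], axis=1)  # time step that actually holds the maximum
    return float(np.mean(predicted_pos == true_pos))

# e.g. print(attention_argmax_accuracy(model, x_data)) after training the model above.
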
21 changes: 14 additions & 7 deletions examples/imdb.py
@@ -29,11 +29,10 @@ def train_and_evaluate_model_on_imdb(add_attention=True):
else [LSTM(100), Dense(350, activation='relu')]),
Dropout(0.5),
Dense(1, activation='sigmoid')
]
)
])

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())
model.summary()

class RecordBestTestAccuracy(Callback):

@@ -46,14 +45,22 @@ def on_epoch_end(self, epoch, logs=None):
self.val_accuracies.append(logs['val_accuracy'])
self.val_losses.append(logs['val_loss'])

rbta = RecordBestTestAccuracy()
model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, batch_size=64, callbacks=[rbta])
record_callback = RecordBestTestAccuracy()
model.fit(
x_train, y_train,
verbose=2,
validation_data=(x_test, y_test),
epochs=10,
batch_size=64,
callbacks=[record_callback]
)

print(f"Max Test Accuracy: {100 * np.max(rbta.val_accuracies):.2f} %")
print(f"Mean Test Accuracy: {100 * np.mean(rbta.val_accuracies):.2f} %")
print(f"Max Test Accuracy: {100 * np.max(record_callback.val_accuracies):.2f} %")
print(f"Mean Test Accuracy: {100 * np.mean(record_callback.val_accuracies):.2f} %")


def main():
# Make sure to run on a GPU!
# 10 epochs.
# Max Test Accuracy: 88.02 %
# Mean Test Accuracy: 87.26 %
6 changes: 3 additions & 3 deletions setup.py
@@ -1,14 +1,14 @@
from setuptools import setup
from setuptools import setup, find_packages

setup(
name='attention',
version='4.0',
version='4.1',
description='Keras Simple Attention',
author='Philippe Remy',
license='Apache 2.0',
long_description_content_type='text/markdown',
long_description=open('README.md').read(),
packages=['attention'],
packages=find_packages(),
install_requires=[
'numpy>=1.18.1',
'tensorflow>=2.1'
12 changes: 12 additions & 0 deletions tox.ini
@@ -0,0 +1,12 @@
[tox]
envlist = {py3}-tensorflow-{2.5.0,2.6.2,2.7.0,v2.8.0-rc0}

[testenv]
deps = -rexamples/examples-requirements.txt
tensorflow-2.5.0: tensorflow==2.5.0
tensorflow-2.6.2: tensorflow==2.6.2
tensorflow-2.7.0: tensorflow==2.7.0
tensorflow-v2.8.0-rc0: tensorflow==v2.8.0-rc0
changedir = examples
commands = python example-attention.py
install_command = pip install {packages}
