First pass at a TPU loop for Transformer (tensorflow#4296)
* port changes from previous branch now that transformer util changes are in master

fix incorrect count

correct (hopefully) treatment of batch_size

set eval_metrics to a dummy function for now

add some comments

start bringing metrics to transformer TPU

resolve logits shape

metrics are now working except for tf.py_func metrics

increase batch_size for tpu, and create summary host call

fix host call

reduce tpu default batch size

further tune batch sizes

add minibatch loss to summary

handle case of single_iteration_train_steps > number of points in an epoch

begin to incorporate hooks

add sleep workarounds

disable hooks altogether

generalize host call function and move to newly created tpu utils module

remove all traces of params as an object

switch from  to

address some PR comments, and change the number of data points.

minor tweaks

add tpu dry run for testing, and use matmul for TPU embedding

infeed/outfeed queue issue is fixed. Sleeps are no longer necessary

add some documentation.

cleanup and address PR comments

delint

add accelerator __init__

fix embedding

missed PR comment

address PR comments

fix validator bug

rewrite cloud storage validator, and add oauth dependency to requirements.txt

* delint
Taylor Robie authored Jun 4, 2018
1 parent bd56a06 commit 2eeb85f
Showing 17 changed files with 727 additions and 180 deletions.
2 changes: 2 additions & 0 deletions official/requirements.txt
@@ -2,4 +2,6 @@ numpy
pandas
psutil>=5.4.3
py-cpuinfo>=3.3.0
google-api-python-client>=1.6.7
google-cloud-bigquery>=0.31.0
oauth2client>=4.1.2
5 changes: 5 additions & 0 deletions official/transformer/README.md
@@ -19,6 +19,7 @@ The model also applies embeddings on the input and output tokens, and adds a con
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
* [Compute official BLEU score](#compute-official-bleu-score)
* [TPU](#tpu)
* [Implementation overview](#implementation-overview)
* [Model Definition](#model-definition)
* [Model Estimator](#model-estimator)
@@ -200,6 +201,10 @@ big | 28.9
* `--reference`: Path to file containing reference translations.
* Use the `--help` or `-h` flag to get a full list of possible arguments.

5. ### TPU
TPU support for this version of Transformer is experimental. Currently it is present for
demonstration purposes only, but will be optimized in the coming weeks.

## Implementation overview

A brief look at each component in the code:
38 changes: 30 additions & 8 deletions official/transformer/model/embedding_layer.py
@@ -21,15 +21,31 @@
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.transformer.model import model_utils
from official.utils.accelerator import tpu as tpu_utils


class EmbeddingSharedWeights(tf.layers.Layer):
"""Calculates input embeddings and pre-softmax linear with shared weights."""

def __init__(self, vocab_size, hidden_size):
def __init__(self, vocab_size, hidden_size, method="gather"):
"""Specify characteristic parameters of embedding layer.
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
method: Strategy for performing embedding lookup. "gather" uses tf.gather
which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul"
one-hot encodes the indices and formulates the embedding as a sparse
matrix multiplication. The matmul formulation is wasteful as it does
extra work; however, matrix multiplication is very fast on TPUs, which
makes "matmul" considerably faster than "gather" on TPUs.
"""
super(EmbeddingSharedWeights, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
if method not in ("gather", "matmul"):
raise ValueError("method {} must be 'gather' or 'matmul'".format(method))
self.method = method

def build(self, _):
with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
@@ -53,19 +69,25 @@ def call(self, x):
locations of the padding tokens in x.
"""
with tf.name_scope("embedding"):
embeddings = tf.gather(self.shared_weights, x)
# Create binary mask of size [batch_size, length]
mask = tf.to_float(tf.not_equal(x, 0))

if self.method == "gather":
embeddings = tf.gather(self.shared_weights, x)
else: # matmul
embeddings = tpu_utils.embedding_matmul(
embedding_table=self.shared_weights,
values=tf.cast(x, dtype=tf.int32),
mask=mask
)
embeddings *= tf.expand_dims(mask, -1)

# Scale embedding by the sqrt of the hidden size
embeddings *= self.hidden_size ** 0.5

# Create binary array of size [batch_size, length]
# where 1 = padding, 0 = not padding
padding = model_utils.get_padding(x)

# Set all padding embedding values to 0
embeddings *= tf.expand_dims(1 - padding, -1)
return embeddings


def linear(self, x):
"""Computes logits by running x through a linear layer.
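The `tpu_utils.embedding_matmul` helper referenced above lives in the new `official/utils/accelerator/tpu` module, which is not shown in this excerpt. A minimal sketch of the one-hot matmul lookup the docstring describes, assuming a float embedding table and int32 token ids (the function name and details below are illustrative, not the exact helper):

```python
import tensorflow as tf

def embedding_matmul_sketch(embedding_table, values, mask):
  """One-hot matmul embedding lookup (illustrative sketch).

  Args:
    embedding_table: float tensor with shape [vocab_size, hidden_size].
    values: int32 token ids with shape [batch_size, length].
    mask: float tensor with shape [batch_size, length]; 1.0 for real tokens,
      0.0 for padding.
  """
  vocab_size = embedding_table.get_shape().as_list()[0]

  # One-hot encode the ids: [batch_size, length, vocab_size].
  one_hot = tf.one_hot(values, depth=vocab_size, dtype=embedding_table.dtype)

  # Select embedding rows with a dense matmul, which TPUs execute quickly:
  # [batch_size, length, vocab_size] x [vocab_size, hidden_size]
  #   -> [batch_size, length, hidden_size].
  embeddings = tf.einsum("blv,vh->blh", one_hot, embedding_table)

  # Zero out the embeddings of padding tokens.
  return embeddings * tf.expand_dims(mask, -1)
```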
10 changes: 7 additions & 3 deletions official/transformer/model/ffn_layer.py
@@ -24,12 +24,13 @@
class FeedFowardNetwork(tf.layers.Layer):
"""Fully connected feedforward network."""

def __init__(self, hidden_size, filter_size, relu_dropout, train):
def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad):
super(FeedFowardNetwork, self).__init__()
self.hidden_size = hidden_size
self.filter_size = filter_size
self.relu_dropout = relu_dropout
self.train = train
self.allow_pad = allow_pad

self.filter_dense_layer = tf.layers.Dense(
filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
@@ -42,13 +43,16 @@ def call(self, x, padding=None):
Args:
x: tensor with shape [batch_size, length, hidden_size]
padding: (optional) If set, the padding values are temporarily removed
from x. The padding values are placed back in the output tensor in the
same locations. shape [batch_size, length]
from x (provided self.allow_pad is set). The padding values are placed
back in the output tensor in the same locations.
shape [batch_size, length]
Returns:
Output of the feedforward network.
tensor with shape [batch_size, length, hidden_size]
"""
padding = None if not self.allow_pad else padding

# Retrieve dynamically known shapes
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
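The `allow_pad` flag gates the remove-padding optimization the docstring describes: non-padding positions are compacted into a smaller tensor before the two dense layers run, and the results are scattered back to their original locations afterwards. A rough sketch of that pattern, assuming `padding` marks padded positions with 1.0 (the helper name and signature are illustrative, and dropout is omitted):

```python
import tensorflow as tf

def ffn_without_padding(x, padding, filter_layer, output_layer):
  """Apply a feedforward network only to non-padding positions (sketch).

  Args:
    x: float tensor with shape [batch_size, length, hidden_size].
    padding: float tensor with shape [batch_size, length]; 1.0 marks padding.
    filter_layer, output_layer: callables such as tf.layers.Dense instances.
  """
  hidden_size = x.get_shape().as_list()[-1]
  batch_size, length = tf.shape(x)[0], tf.shape(x)[1]

  # Flatten to [batch_size * length, hidden_size] and keep real tokens only.
  flat_padding = tf.reshape(padding, [-1])
  nonpad_ids = tf.to_int32(tf.where(flat_padding < 1e-9))
  flat_x = tf.reshape(x, [-1, hidden_size])
  compact = tf.gather_nd(flat_x, indices=nonpad_ids)
  compact.set_shape([None, hidden_size])  # restore the static hidden dimension

  # Run the feedforward layers on the compacted tensor.
  outputs = output_layer(filter_layer(compact))

  # Scatter results back to their original positions; padded rows stay zero.
  restored = tf.scatter_nd(
      indices=nonpad_ids,
      updates=outputs,
      shape=[batch_size * length, hidden_size])
  return tf.reshape(restored, [batch_size, length, hidden_size])
```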
103 changes: 61 additions & 42 deletions official/transformer/model/model_params.py
@@ -15,45 +15,64 @@
"""Defines Transformer model parameters."""


class TransformerBaseParams(object):
"""Parameters for the base Transformer model."""
# Input params
batch_size = 2048 # Maximum number of tokens per batch of examples.
max_length = 256 # Maximum number of tokens per example.

# Model params
initializer_gain = 1.0 # Used in trainable variable initialization.
vocab_size = 33708 # Number of tokens defined in the vocabulary file.
hidden_size = 512 # Model dimension in the hidden layers.
num_hidden_layers = 6 # Number of layers in the encoder and decoder stacks.
num_heads = 8 # Number of heads to use in multi-headed attention.
filter_size = 2048 # Inner layer dimensionality in the feedforward network.

# Dropout values (only used when training)
layer_postprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1

# Training params
label_smoothing = 0.1
learning_rate = 2.0
learning_rate_decay_rate = 1.0
learning_rate_warmup_steps = 16000

# Optimizer params
optimizer_adam_beta1 = 0.9
optimizer_adam_beta2 = 0.997
optimizer_adam_epsilon = 1e-09

# Default prediction params
extra_decode_length = 50
beam_size = 4
alpha = 0.6 # used to calculate length normalization in beam search


class TransformerBigParams(TransformerBaseParams):
"""Parameters for the big Transformer model."""
batch_size = 4096
hidden_size = 1024
filter_size = 4096
num_heads = 16
BASE_PARAMS = dict(
# Input params
default_batch_size=2048, # Maximum number of tokens per batch of examples.
default_batch_size_tpu=32768,
max_length=256, # Maximum number of tokens per example.

# Model params
initializer_gain=1.0, # Used in trainable variable initialization.
vocab_size=33708, # Number of tokens defined in the vocabulary file.
hidden_size=512, # Model dimension in the hidden layers.
num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
num_heads=8, # Number of heads to use in multi-headed attention.
filter_size=2048, # Inner layer dimension in the feedforward network.

# Dropout values (only used when training)
layer_postprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,

# Training params
label_smoothing=0.1,
learning_rate=2.0,
learning_rate_decay_rate=1.0,
learning_rate_warmup_steps=16000,

# Optimizer params
optimizer_adam_beta1=0.9,
optimizer_adam_beta2=0.997,
optimizer_adam_epsilon=1e-09,

# Default prediction params
extra_decode_length=50,
beam_size=4,
alpha=0.6, # used to calculate length normalization in beam search

# TPU specific parameters
use_tpu=False,
static_batch=False,
allow_ffn_pad=True,
)

BIG_PARAMS = dict(BASE_PARAMS)
BIG_PARAMS.update(dict(
default_batch_size=4096,

# The default TPU batch size is smaller than for BASE_PARAMS due to memory limits.
default_batch_size_tpu=16384,

hidden_size=1024,
filter_size=4096,
num_heads=16,
))

TINY_PARAMS = dict(BASE_PARAMS)
TINY_PARAMS.update(dict(
default_batch_size=1024,
default_batch_size_tpu=1024,
hidden_size=32,
num_heads=4,
filter_size=256,
))
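With the parameter classes replaced by plain dicts, callers copy a base set and overlay run-specific values instead of subclassing, and attribute access like `params.hidden_size` becomes `params["hidden_size"]`. A rough sketch of the intended usage (the actual flag handling in transformer_main.py is not part of this excerpt, so the override logic below is an assumption):

```python
from official.transformer.model import model_params

# Copy the base dict so per-run overrides do not mutate the module-level constant.
params = dict(model_params.BASE_PARAMS)

# Hypothetical run-time overrides.
params["use_tpu"] = True
params["batch_size"] = (params["default_batch_size_tpu"] if params["use_tpu"]
                        else params["default_batch_size"])

# Dictionary access replaces the former attribute access (params.hidden_size).
hidden_size = params["hidden_size"]
```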
56 changes: 31 additions & 25 deletions official/transformer/model/transformer.py
@@ -57,7 +57,8 @@ def __init__(self, params, train):
self.params = params

self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
params.vocab_size, params.hidden_size)
params["vocab_size"], params["hidden_size"],
method="matmul" if params["tpu"] else "gather")
self.encoder_stack = EncoderStack(params, train)
self.decoder_stack = DecoderStack(params, train)

@@ -79,7 +80,7 @@ def __call__(self, inputs, targets=None):
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
initializer = tf.variance_scaling_initializer(
self.params.initializer_gain, mode="fan_avg", distribution="uniform")
self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
with tf.variable_scope("Transformer", initializer=initializer):
# Calculate attention bias for encoder self-attention and decoder
# multi-headed attention layers.
@@ -116,12 +117,12 @@ def encode(self, inputs, attention_bias):
with tf.name_scope("add_pos_encoding"):
length = tf.shape(embedded_inputs)[1]
pos_encoding = model_utils.get_position_encoding(
length, self.params.hidden_size)
length, self.params["hidden_size"])
encoder_inputs = embedded_inputs + pos_encoding

if self.train:
encoder_inputs = tf.nn.dropout(
encoder_inputs, 1 - self.params.layer_postprocess_dropout)
encoder_inputs, 1 - self.params["layer_postprocess_dropout"])

return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)

@@ -149,10 +150,10 @@ def decode(self, targets, encoder_outputs, attention_bias):
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
decoder_inputs += model_utils.get_position_encoding(
length, self.params.hidden_size)
length, self.params["hidden_size"])
if self.train:
decoder_inputs = tf.nn.dropout(
decoder_inputs, 1 - self.params.layer_postprocess_dropout)
decoder_inputs, 1 - self.params["layer_postprocess_dropout"])

# Run values
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
@@ -167,7 +168,7 @@ def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""

timing_signal = model_utils.get_position_encoding(
max_decode_length + 1, self.params.hidden_size)
max_decode_length + 1, self.params["hidden_size"])
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length)

@@ -206,7 +207,7 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params.extra_decode_length
max_decode_length = input_length + self.params["extra_decode_length"]

symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

@@ -216,9 +217,9 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
# Create cache storing decoder attention values for each layer.
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params.hidden_size]),
"v": tf.zeros([batch_size, 0, self.params.hidden_size]),
} for layer in range(self.params.num_hidden_layers)}
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
} for layer in range(self.params["num_hidden_layers"])}

# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
@@ -229,9 +230,9 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self.params.vocab_size,
beam_size=self.params.beam_size,
alpha=self.params.alpha,
vocab_size=self.params["vocab_size"],
beam_size=self.params["beam_size"],
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=EOS_ID)

@@ -268,11 +269,11 @@ class PrePostProcessingWrapper(object):

def __init__(self, layer, params, train):
self.layer = layer
self.postprocess_dropout = params.layer_postprocess_dropout
self.postprocess_dropout = params["layer_postprocess_dropout"]
self.train = train

# Create normalization layer
self.layer_norm = LayerNormalization(params.hidden_size)
self.layer_norm = LayerNormalization(params["hidden_size"])

def __call__(self, x, *args, **kwargs):
# Preprocessing: apply layer normalization
@@ -299,19 +300,21 @@ class EncoderStack(tf.layers.Layer):
def __init__(self, params, train):
super(EncoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
for _ in range(params["num_hidden_layers"]):
# Create sublayers for each layer.
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])

self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])

# Create final layer normalization layer.
self.output_normalization = LayerNormalization(params.hidden_size)
self.output_normalization = LayerNormalization(params["hidden_size"])

def call(self, encoder_inputs, attention_bias, inputs_padding):
"""Return the output of the encoder layer stacks.
Expand Down Expand Up @@ -354,20 +357,23 @@ class DecoderStack(tf.layers.Layer):
def __init__(self, params, train):
super(DecoderStack, self).__init__()
self.layers = []
for _ in range(params.num_hidden_layers):
for _ in range(params["num_hidden_layers"]):
self_attention_layer = attention_layer.SelfAttention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
enc_dec_attention_layer = attention_layer.Attention(
params.hidden_size, params.num_heads, params.attention_dropout, train)
params["hidden_size"], params["num_heads"],
params["attention_dropout"], train)
feed_forward_network = ffn_layer.FeedFowardNetwork(
params.hidden_size, params.filter_size, params.relu_dropout, train)
params["hidden_size"], params["filter_size"],
params["relu_dropout"], train, params["allow_ffn_pad"])

self.layers.append([
PrePostProcessingWrapper(self_attention_layer, params, train),
PrePostProcessingWrapper(enc_dec_attention_layer, params, train),
PrePostProcessingWrapper(feed_forward_network, params, train)])

self.output_normalization = LayerNormalization(params.hidden_size)
self.output_normalization = LayerNormalization(params["hidden_size"])

def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias, cache=None):
