Skip to content

Commit

Permalink
Merge branch 'dev/rnnt' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed Oct 20, 2020
2 parents 72cd5d2 + 6c107e3 commit ade7891
Show file tree
Hide file tree
Showing 11 changed files with 893 additions and 93 deletions.
2 changes: 1 addition & 1 deletion examples/conformer/train_ga_subword_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
config["decoder_config"],
corpus_files=args.subwords_corpus
)
text_featurizer.subwords.save_to_file(args.subwords_prefix)
text_featurizer.save_to_file(args.subwords)

if args.tfrecords:
train_dataset = ASRTFRecordDataset(
Expand Down
2 changes: 1 addition & 1 deletion examples/conformer/train_subword_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
config["decoder_config"],
corpus_files=args.subwords_corpus
)
text_featurizer.subwords.save_to_file(args.subwords_prefix)
text_featurizer.save_to_file(args.subwords)

if args.tfrecords:
train_dataset = ASRTFRecordDataset(
Expand Down
26 changes: 14 additions & 12 deletions examples/streaming_transducer/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ decoder_config:

model_config:
name: streaming_transducer
reduction_factor: 2
reduction_positions: [1]
encoder_dim: 320
encoder_units: 1024
encoder_layers: 8
encoder_reductions:
0: 3
1: 2
encoder_dmodel: 320
encoder_rnn_type: lstm
encoder_rnn_units: 1024
encoder_nlayers: 8
encoder_layer_norm: True
encoder_type: lstm
embed_dim: 320
embed_dropout: 0.1
num_rnns: 1
rnn_units: 320
rnn_type: lstm
layer_norm: True
prediction_embed_dim: 320
prediction_embed_dropout: 0.0
prediction_num_rnns: 2
prediction_rnn_units: 1024
prediction_rnn_type: lstm
prediction_projection_units: 320
prediction_layer_norm: True
joint_dim: 320

learning_config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
config["decoder_config"],
corpus_files=args.subwords_corpus
)
text_featurizer.subwords.save_to_file(args.subwords_prefix)
text_featurizer.save_to_file(args.subwords)

if args.tfrecords:
train_dataset = ASRTFRecordDataset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
config["decoder_config"],
corpus_files=args.subwords_corpus
)
text_featurizer.subwords.save_to_file(args.subwords_prefix)
text_featurizer.save_to_file(args.subwords)

if args.tfrecords:
train_dataset = ASRTFRecordDataset(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.2.7",
version="0.2.8",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_asr/featurizers/text_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ def load_from_file(cls, decoder_config: dict, filename: str = None):
subwords = tds.features.text.SubwordTextEncoder.load_from_file(filename_prefix)
return cls(decoder_config, subwords)

def save_to_file(self, filename: str = None):
    """
    Save the subword vocabulary held by this featurizer

    Args:
        filename (str, optional): target path. When given, it is passed
            through ``preprocess_paths`` and its extension is dropped to
            build the prefix handed to the subword encoder. When ``None``,
            the ``"vocabulary"`` entry of ``self.decoder_config`` is used
            unchanged (no extension stripping).

    Returns:
        Whatever ``self.subwords.save_to_file`` returns for that prefix.

    Raises:
        ValueError: if ``filename`` is ``None`` and the decoder config has
            no ``"vocabulary"`` entry, so there is nowhere to save to.
    """
    if filename is not None:
        filename_prefix = os.path.splitext(preprocess_paths(filename))[0]
    else:
        filename_prefix = self.decoder_config.get("vocabulary", None)
    if filename_prefix is None:
        # Fail here with an actionable message instead of handing None
        # to the underlying encoder.
        raise ValueError(
            "No output path for subwords: pass `filename` or set "
            "'vocabulary' in the decoder config"
        )
    return self.subwords.save_to_file(filename_prefix)

def extract(self, text: str) -> tf.Tensor:
"""
Convert string to a list of integers
Expand Down
142 changes: 68 additions & 74 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,55 @@
import collections
import tensorflow as tf

from .layers.merge_two_last_dims import Merge2LastDims
from .layers.subsampling import TimeReduction
from .transducer import Transducer, BeamHypothesis
from ..utils.utils import get_rnn, get_shape_invariants
from ..utils.utils import get_rnn, get_shape_invariants, merge_two_last_dims

Hypothesis = collections.namedtuple(
"Hypothesis",
("index", "prediction", "encoder_states", "prediction_states")
)


class Reshape(tf.keras.layers.Layer):
    """
    Keras layer that delegates to the ``merge_two_last_dims`` utility

    NOTE(review): judging by the helper's name this collapses the two
    trailing axes of the input into one; confirm against ``utils.utils``.
    """

    def call(self, inputs):
        """Return the result of ``merge_two_last_dims`` applied to ``inputs``"""
        merged = merge_two_last_dims(inputs)
        return merged


class StreamingTransducerBlock(tf.keras.Model):
def __init__(self,
reduction_factor: int = 3,
apply_reduction: bool = False,
encoder_dim: int = 320,
encoder_type: str = "lstm",
encoder_units: int = 1024,
encoder_layer_norm: bool = True,
apply_projection: bool = True,
reduction_factor: int = 0,
dmodel: int = 640,
rnn_type: str = "lstm",
rnn_units: int = 2048,
layer_norm: bool = True,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
super(StreamingTransducerBlock, self).__init__(**kwargs)

if apply_reduction:
if reduction_factor > 0:
self.reduction = TimeReduction(reduction_factor, name=f"{self.name}_reduction")
else:
self.reduction = None

RNN = get_rnn(encoder_type)
RNN = get_rnn(rnn_type)
self.rnn = RNN(
units=encoder_units, return_sequences=True,
units=rnn_units, return_sequences=True,
name=f"{self.name}_rnn", return_state=True,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer
)

if encoder_layer_norm:
if layer_norm:
self.ln = tf.keras.layers.LayerNormalization(name=f"{self.name}_ln")
else:
self.ln = None

if apply_projection:
self.projection = tf.keras.layers.Dense(
encoder_dim, name=f"{self.name}_projection",
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer
)
else:
self.projection = None
self.projection = tf.keras.layers.Dense(
dmodel, name=f"{self.name}_projection",
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer
)

def call(self, inputs, training=False):
outputs = inputs
Expand All @@ -76,8 +74,7 @@ def call(self, inputs, training=False):
outputs = outputs[0]
if self.ln is not None:
outputs = self.ln(outputs, training=training)
if self.projection is not None:
outputs = self.projection(outputs, training=training)
outputs = self.projection(outputs, training=training)
return outputs

def recognize(self, inputs, states):
Expand All @@ -89,12 +86,11 @@ def recognize(self, inputs, states):
outputs = outputs[0]
if self.ln is not None:
outputs = self.ln(outputs, training=False)
if self.projection is not None:
outputs = self.projection(outputs, training=False)
outputs = self.projection(outputs, training=False)
return outputs, new_states

def get_config(self):
conf = super(StreamingTransducerBlock, self).get_config()
conf = {}
if self.reduction is not None:
conf.update(self.reduction.get_config())
conf.update(self.rnn.get_config())
Expand All @@ -106,38 +102,36 @@ def get_config(self):

class StreamingTransducerEncoder(tf.keras.Model):
def __init__(self,
reduction_factor: int = 3,
reduction_positions: list = [1],
encoder_dim: int = 320,
encoder_layers: int = 8,
encoder_type: str = "lstm",
encoder_units: int = 1024,
encoder_layer_norm: bool = True,
reductions: dict = {0: 3, 1: 2},
dmodel: int = 640,
nlayers: int = 8,
rnn_type: str = "lstm",
rnn_units: int = 2048,
layer_norm: bool = True,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
super(StreamingTransducerEncoder, self).__init__(**kwargs)

self.merge = Merge2LastDims(name=f"{self.name}_merge")
self.reshape = Reshape(name=f"{self.name}_reshape")

self.blocks = [
StreamingTransducerBlock(
reduction_factor=reduction_factor,
apply_reduction=(i in reduction_positions),
apply_projection=(i != encoder_layers - 1),
encoder_dim=encoder_dim,
encoder_type=encoder_type,
encoder_units=encoder_units,
encoder_layer_norm=encoder_layer_norm,
reduction_factor=reductions.get(i, 0), # key is index, value is the factor
dmodel=dmodel,
rnn_type=rnn_type,
rnn_units=rnn_units,
layer_norm=layer_norm,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=f"{self.name}_{i}"
) for i in range(encoder_layers)
name=f"{self.name}_block_{i}"
) for i in range(nlayers)
]

self.time_reduction_factor = 1
for i in range(encoder_layers):
if i in reduction_positions: self.time_reduction_factor *= reduction_factor
for i in range(nlayers):
reduction_factor = reductions.get(i, 0)
if reduction_factor > 0: self.time_reduction_factor *= reduction_factor

def get_initial_state(self):
"""Get zeros states
Expand All @@ -157,7 +151,7 @@ def get_initial_state(self):
return tf.stack(states, axis=0)

def call(self, inputs, training=False):
    """
    Run the encoder forward pass

    The inputs go through ``self.reshape`` once, then through every block
    in ``self.blocks`` in order; each block receives the previous block's
    output and the ``training`` flag.

    Args:
        inputs: input passed to ``self.reshape``
        training (bool): forwarded to every block

    Returns:
        The output of the last block (or of ``self.reshape`` when there
        are no blocks).
    """
    # Only the reshape layer is applied here: the former ``self.merge``
    # call produced a value that was overwritten immediately afterwards.
    outputs = self.reshape(inputs)
    for block in self.blocks:
        outputs = block(outputs, training=training)
    return outputs
Expand All @@ -173,60 +167,60 @@ def recognize(self, inputs, states):
tf.Tensor: outputs with shape [1, T, E]
tf.Tensor: new states with shape [num_lstms, 1 or 2, 1, P]
"""
outputs = self.merge(inputs)
outputs = self.reshape(inputs)
new_states = []
for i, block in enumerate(self.blocks):
outputs, block_states = block.recognize(outputs, states=tf.unstack(states[i], axis=0))
new_states.append(block_states)
return outputs, tf.stack(new_states, axis=0)

def get_config(self):
    """
    Collect the configuration of the reshape layer and of every block

    Starts from ``self.reshape.get_config()`` and merges in each block's
    config in order, so on duplicate keys the value from the last block
    that defines the key wins.

    Returns:
        dict: the merged configuration
    """
    # Start directly from the reshape config; the empty dict that used to
    # be assigned first was never read.
    conf = self.reshape.get_config()
    for block in self.blocks:
        conf.update(block.get_config())
    return conf


class StreamingTransducer(Transducer):
def __init__(self,
vocabulary_size: int,
reduction_factor: int = 2,
reduction_positions: list = [1],
encoder_dim: int = 320,
encoder_layers: int = 8,
encoder_type: str = "lstm",
encoder_units: int = 1024,
encoder_reductions: dict = {0: 3, 1: 2},
encoder_dmodel: int = 640,
encoder_nlayers: int = 8,
encoder_rnn_type: str = "lstm",
encoder_rnn_units: int = 2048,
encoder_layer_norm: bool = True,
embed_dim: int = 320,
embed_dropout: float = 0,
num_rnns: int = 2,
rnn_units: int = 1024,
rnn_type: str = "lstm",
layer_norm: bool = True,
joint_dim: int = 320,
prediction_embed_dim: int = 320,
prediction_embed_dropout: float = 0,
prediction_num_rnns: int = 2,
prediction_rnn_units: int = 2048,
prediction_rnn_type: str = "lstm",
prediction_layer_norm: bool = True,
prediction_projection_units: int = 640,
joint_dim: int = 640,
kernel_regularizer = None,
bias_regularizer = None,
name = "StreamingTransducer",
**kwargs):
super(StreamingTransducer, self).__init__(
encoder=StreamingTransducerEncoder(
reduction_factor=reduction_factor,
reduction_positions=reduction_positions,
encoder_dim=encoder_dim,
encoder_layers=encoder_layers,
encoder_type=encoder_type,
encoder_units=encoder_units,
encoder_layer_norm=encoder_layer_norm,
reductions=encoder_reductions,
dmodel=encoder_dmodel,
nlayers=encoder_nlayers,
rnn_type=encoder_rnn_type,
rnn_units=encoder_rnn_units,
layer_norm=encoder_layer_norm,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=f"{name}_encoder"
),
vocabulary_size=vocabulary_size,
embed_dim=embed_dim,
embed_dropout=embed_dropout,
num_rnns=num_rnns,
rnn_units=rnn_units,
rnn_type=rnn_type,
layer_norm=layer_norm,
embed_dim=prediction_embed_dim,
embed_dropout=prediction_embed_dropout,
num_rnns=prediction_num_rnns,
rnn_units=prediction_rnn_units,
rnn_type=prediction_rnn_type,
layer_norm=prediction_layer_norm,
projection_units=prediction_projection_units,
joint_dim=joint_dim,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
Expand Down
Loading

0 comments on commit ade7891

Please sign in to comment.