Skip to content

Commit

Permalink
✍️ update streaming transducer encoder recognize
Browse files (browse the repository at this point in the history)
  • Loading branch information
nglehuy committed Apr 9, 2021
1 parent 9a67d87 commit 8769192
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
7 changes: 5 additions & 2 deletions tensorflow_asr/models/keras/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .transducer import Transducer
from ..streaming_transducer import StreamingTransducerEncoder
+ from ...utils.utils import shape_list


class StreamingTransducer(Transducer):
Expand Down Expand Up @@ -113,7 +114,8 @@ def recognize(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
- encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
+ batch_size, _, _, _ = shape_list(features)
+ encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
return self._perform_greedy_batch(encoded, input_length,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down Expand Up @@ -179,7 +181,8 @@ def recognize_beam(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
- encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
+ batch_size, _, _, _ = shape_list(features)
+ encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
return self._perform_beam_search_batch(encoded, input_length, lm,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down
12 changes: 7 additions & 5 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .layers.subsampling import TimeReduction
from .transducer import Transducer
- from ..utils.utils import get_rnn, merge_two_last_dims
+ from ..utils.utils import get_rnn, merge_two_last_dims, shape_list


class Reshape(tf.keras.layers.Layer):
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(self,
reduction_factor = reductions.get(i, 0)
if reduction_factor > 0: self.time_reduction_factor *= reduction_factor

- def get_initial_state(self):
+ def get_initial_state(self, batch_size=1):
"""Get zeros states
Returns:
Expand All @@ -138,7 +138,7 @@ def get_initial_state(self):
states.append(
tf.stack(
block.rnn.get_initial_state(
- tf.zeros([1, 1, 1], dtype=tf.float32)
+ tf.zeros([batch_size, 1, 1], dtype=tf.float32)
), axis=0
)
)
Expand Down Expand Up @@ -269,7 +269,8 @@ def recognize(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
- encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
+ batch_size, _, _, _ = shape_list(features)
+ encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
return self._perform_greedy_batch(encoded, input_length,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down Expand Up @@ -335,7 +336,8 @@ def recognize_beam(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
- encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
+ batch_size, _, _, _ = shape_list(features)
+ encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
return self._perform_beam_search_batch(encoded, input_length, lm,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down

0 comments on commit 8769192

Please sign in to comment.