
beam_attention_decoder


Summary

The differences from the plain attention decoder are:

  • output format: the decoder now appends argmax token ids (taken after the output projection) instead of raw outputs, and it additionally returns the beam path and beam symbols
  • state handling: the decoder state and the attention vectors are tiled beam_size times along the batch dimension after the first step
  • loop_function: its signature is extended to take log_beam_probs, beam_path, and beam_symbols (a sketch follows right after this list)
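
The extended loop_function itself is not shown in the diff. Below is a minimal sketch of what it could look like, written against the same TF 0.x era API used elsewhere on this page (nn_ops.xw_plus_b, tf.nn.top_k); the wrapper make_beam_loop_function and its arguments embedding, output_projection, beam_size, and num_symbols are illustrative assumptions, not names that appear in the diff.

import tensorflow as tf
from tensorflow.python.ops import embedding_ops, nn_ops

def make_beam_loop_function(embedding, output_projection, beam_size, num_symbols):
    # Hypothetical wrapper: all four arguments are assumptions made for this sketch.
    def loop_function(prev, i, log_beam_probs, beam_path, beam_symbols):
        # Project the previous cell output onto the vocabulary.
        prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
        log_probs = tf.nn.log_softmax(prev)
        if i > 1:
            # Add each surviving beam's running score to its expansions, then flatten
            # so top_k searches across all beam_size * num_symbols candidates.
            log_probs = tf.reshape(log_probs + log_beam_probs[-1],
                                   [-1, beam_size * num_symbols])
        best_probs, indices = tf.nn.top_k(log_probs, beam_size)
        indices = tf.reshape(indices, [-1, 1])
        best_probs = tf.reshape(best_probs, [-1, 1])

        symbols = indices % num_symbols   # token id picked by each hypothesis
        parents = indices // num_symbols  # beam each hypothesis was expanded from
        beam_symbols.append(symbols)
        beam_path.append(parents)
        log_beam_probs.append(best_probs)

        # Feed the embeddings of the chosen symbols into the next decoder step.
        emb_prev = embedding_ops.embedding_lookup(embedding, symbols)
        return tf.reshape(emb_prev, [beam_size, -1])

    return loop_function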

Code with comments

--- /Users/higepon/Desktop/a.py	2017-07-09 14:47:58.000000000 +0900
+++ /Users/higepon/Desktop/b.py	2017-07-09 14:47:34.000000000 +0900
@@ -26,7 +26,11 @@
             v.append(variable_scope.get_variable("AttnV_%d" % a,
                                                  [attention_vec_size]))
 
-        state = initial_state
+        state_size = int(initial_state.get_shape().with_rank(2)[1])
+        states = []
+        for kk in range(1):
+            states.append(initial_state)
+        state = tf.reshape(tf.concat(0, states), [-1, state_size])
 
         def attention(query):
             """Put attention masks on hidden using hidden_features and query."""
@@ -53,9 +61,15 @@
                  for _ in xrange(num_heads)]
         for a in attns:  # Ensure the second shape of attention vectors is set.
             a.set_shape([None, attn_size])
         if initial_state_attention:
-            attns = attention(initial_state)
+            attns = []
+            attns.append(attention(initial_state))
+            tmp = tf.reshape(tf.concat(0, attns), [-1, attn_size])
+            attns = []
+            attns.append(tmp)
+
+        log_beam_probs, beam_path, beam_symbols = [], [], []
         for i, inp in enumerate(decoder_inputs):
             if i > 0:
                 variable_scope.get_variable_scope().reuse_variables()
             # If loop_function is set, we use it instead of decoder_inputs.
@@ -62,11 +78,11 @@
             if loop_function is not None:
                 with variable_scope.variable_scope("loop_function", reuse=True):
                     if prev is not None:
-                        inp = loop_function(prev, i)
+                        inp = loop_function(prev, i, log_beam_probs, beam_path, beam_symbols)
 
             input_size = inp.get_shape().with_rank(2)[1]
             x = linear([inp] + attns, input_size, True)
             cell_output, state = cell(x, state)
             # Run the attention mechanism.
             if i == 0 and initial_state_attention:
                 with variable_scope.variable_scope(variable_scope.get_variable_scope(),
@@ -81,6 +96,16 @@
                 output = linear([cell_output] + attns, output_size, True)
             if loop_function is not None:
                 prev = output
-            outputs.append(output)
+            if i == 0:
+                states = []
+                for kk in range(beam_size):
+                    states.append(state)
+                state = tf.reshape(tf.concat(0, states), [-1, state_size])
+                with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True):
+                    attns = attention(state)
+
+            outputs.append(tf.argmax(nn_ops.xw_plus_b(
+                output, output_projection[0], output_projection[1]), dimension=1))
 
-    return outputs, state
+    return outputs, state, tf.reshape(tf.concat(0, beam_path), [-1, beam_size]), tf.reshape(tf.concat(0, beam_symbols),
+                                                                                            [-1, beam_size])
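
The decoder now also returns beam_path and beam_symbols, each reshaped to [-1, beam_size]. Below is a minimal sketch of how the best hypothesis might be read back once those two tensors have been fetched with session.run; backtrack_best_hypothesis is an illustrative name, and it relies on the assumption that tf.nn.top_k keeps the beams sorted by score, so beam 0 at the last recorded step is the best hypothesis.

import numpy as np

def backtrack_best_hypothesis(paths, symbols):
    # paths, symbols: numpy int arrays of shape [num_recorded_steps, beam_size],
    # obtained by evaluating the two extra return values of the decoder.
    best_beam = 0  # assumption: beams are sorted by log-probability at every step
    tokens = []
    for t in range(symbols.shape[0] - 1, -1, -1):
        tokens.append(int(symbols[t, best_beam]))  # token chosen on this step
        best_beam = int(paths[t, best_beam])       # move to its parent beam
    return list(reversed(tokens))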