beam_attention_decoder
The differences between a.py and b.py (diff below) are:
- the output format (this has been fixed)
- the treatment of some decoder states
- the `loop_function` signature (a sketch of such a function follows this list)
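For context, here is a minimal sketch of a `loop_function` with the extended signature that the diff calls, `loop_function(prev, i, log_beam_probs, beam_path, beam_symbols)`. It follows the common beam-search loop-function pattern and uses the same pre-1.0 TensorFlow API as the diff; the factory name `make_beam_loop_function` and its arguments (`embedding`, `output_projection`, `num_symbols`, `embedding_size`, `beam_size`) are illustrative and may not match this project's actual names.

```python
import tensorflow as tf
from tensorflow.python.ops import embedding_ops, nn_ops


def make_beam_loop_function(embedding, output_projection, num_symbols,
                            embedding_size, beam_size):
    """Illustrative factory; the names and arguments are assumptions."""

    def loop_function(prev, i, log_beam_probs, beam_path, beam_symbols):
        # Project the previous cell output to the vocabulary and take log-probs.
        if output_projection is not None:
            prev = nn_ops.xw_plus_b(prev, output_projection[0],
                                    output_projection[1])
        probs = tf.log(tf.nn.softmax(prev))

        # After the first feed-back step, add each beam's accumulated
        # log-probability so top_k ranks whole partial hypotheses.
        if i > 1:
            probs = tf.reshape(probs + log_beam_probs[-1],
                               [-1, beam_size * num_symbols])

        best_probs, indices = tf.nn.top_k(probs, beam_size)
        indices = tf.reshape(indices, [-1, 1])
        best_probs = tf.reshape(best_probs, [-1, 1])

        symbols = indices % num_symbols       # token emitted by each survivor
        beam_parent = indices // num_symbols  # beam each survivor extends

        beam_symbols.append(symbols)
        beam_path.append(beam_parent)
        log_beam_probs.append(best_probs)

        # Feed the embeddings of the chosen tokens into the next decoder step.
        emb_prev = embedding_ops.embedding_lookup(embedding, symbols)
        return tf.reshape(emb_prev, [beam_size, embedding_size])

    return loop_function
```

At each step this keeps the `beam_size` best continuations, recording in `beam_path` which parent beam each survivor came from and in `beam_symbols` which token it emitted; those lists are what the decoder concatenates and returns at the end of the diff.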
--- /Users/higepon/Desktop/a.py 2017-07-09 14:47:58.000000000 +0900
+++ /Users/higepon/Desktop/b.py 2017-07-09 14:47:34.000000000 +0900
@@ -26,7 +26,11 @@
v.append(variable_scope.get_variable("AttnV_%d" % a,
[attention_vec_size]))
- state = initial_state
+ state_size = int(initial_state.get_shape().with_rank(2)[1])
+ states = []
+ for kk in range(1):
+ states.append(initial_state)
+ state = tf.reshape(tf.concat(0, states), [-1, state_size])
def attention(query):
"""Put attention masks on hidden using hidden_features and query."""
@@ -53,9 +61,15 @@
for _ in xrange(num_heads)]
for a in attns: # Ensure the second shape of attention vectors is set.
a.set_shape([None, attn_size])
if initial_state_attention:
- attns = attention(initial_state)
+ attns = []
+ attns.append(attention(initial_state))
+ tmp = tf.reshape(tf.concat(0, attns), [-1, attn_size])
+ attns = []
+ attns.append(tmp)
+
+ log_beam_probs, beam_path, beam_symbols = [], [], []
for i, inp in enumerate(decoder_inputs):
if i > 0:
variable_scope.get_variable_scope().reuse_variables()
# If loop_function is set, we use it instead of decoder_inputs.
@@ -62,11 +78,11 @@
if loop_function is not None:
with variable_scope.variable_scope("loop_function", reuse=True):
if prev is not None:
- inp = loop_function(prev, i)
+ inp = loop_function(prev, i, log_beam_probs, beam_path, beam_symbols)
input_size = inp.get_shape().with_rank(2)[1]
x = linear([inp] + attns, input_size, True)
cell_output, state = cell(x, state)
# Run the attention mechanism.
if i == 0 and initial_state_attention:
with variable_scope.variable_scope(variable_scope.get_variable_scope(),
@@ -81,6 +96,16 @@
output = linear([cell_output] + attns, output_size, True)
if loop_function is not None:
prev = output
- outputs.append(output)
+ if i == 0:
+ states = []
+ for kk in range(beam_size):
+ states.append(state)
+ state = tf.reshape(tf.concat(0, states), [-1, state_size])
+ with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True):
+ attns = attention(state)
+
+ outputs.append(tf.argmax(nn_ops.xw_plus_b(
+ output, output_projection[0], output_projection[1]), dimension=1))
- return outputs, state
+ return outputs, state, tf.reshape(tf.concat(0, beam_path), [-1, beam_size]), tf.reshape(tf.concat(0, beam_symbols),
+ [-1, beam_size])
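With this change the decoder returns two extra tensors, `beam_path` and `beam_symbols`, each reshaped to `[steps, beam_size]`. A small sketch in plain Python (names are illustrative) of how the arrays obtained by evaluating those tensors can be walked backwards to recover the token sequence of one beam:

```python
def backtrack_beam(beam_path, beam_symbols, beam_index=0):
    """Follow parent pointers from the last decode step back to the first.

    beam_path and beam_symbols are the [steps, beam_size] arrays produced by
    evaluating the two extra outputs of beam_attention_decoder.
    """
    tokens = []
    k = beam_index
    for t in reversed(range(len(beam_symbols))):
        tokens.append(int(beam_symbols[t][k]))  # token emitted at step t
        k = int(beam_path[t][k])                # move to its parent beam
    return list(reversed(tokens))
```

Because `tf.nn.top_k` returns candidates in descending order, `beam_index=0` at the final step corresponds to the highest-scoring hypothesis.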