Hello!

I'm running into a reshaping error when using RL and intermediate rewards.

The output of `intermediate_rewards()` is a `# list of max_dec_step * (batch_size, k)` (line 241), and then this is stacked and has shape `(batch_size, k)`, stored in `self.sampling_discounted_rewards`.

But then in `_add_loss_op()`, you iterate k times and append:

```python
for _ in range(self._hps.k):
    self._sampled_rewards.append(self.sampling_discounted_rewards[:, :, _])  # shape (max_enc_steps, batch_size)
```

But the index `[:, :, _]` would run into a dimension error, because the shape of `self.sampling_discounted_rewards` is `(batch_size, k)`.

Am I missing something here? What should be the correct shape/reshaping? Thank you for uploading this code!
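To make the mismatch concrete, here is a minimal NumPy sketch (standalone, not the repo's actual code; the array names and sizes are illustrative). It shows that a 2-D `(batch_size, k)` array rejects the 3-D index `[:, :, _]`, while stacking the per-step reward list along a new leading axis produces `(max_dec_step, batch_size, k)`, for which that index is valid:

```python
import numpy as np

# Illustrative sizes, not taken from the repo's hyperparameters.
max_dec_step, batch_size, k = 5, 4, 3

# intermediate_rewards() is described as returning a list of
# max_dec_step arrays, each of shape (batch_size, k).
rewards = [np.random.rand(batch_size, k) for _ in range(max_dec_step)]

# If the list collapses to a single (batch_size, k) array,
# the 3-D index used in _add_loss_op() raises an error:
collapsed = sum(rewards)            # shape (batch_size, k)
try:
    _ = collapsed[:, :, 0]
except IndexError as err:
    print("IndexError:", err)       # too many indices for a 2-D array

# Stacking along a new leading axis keeps all three dimensions,
# so slicing out one sample per k works as the loss op expects:
stacked = np.stack(rewards)         # shape (max_dec_step, batch_size, k)
sampled = [stacked[:, :, i] for i in range(k)]
print(sampled[0].shape)             # (max_dec_step, batch_size)
```

This suggests the rewards need to be stacked into a 3-D tensor before the `[:, :, _]` slicing, rather than reduced to `(batch_size, k)`.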
Possible solution: change lines 414 and 427 of `attention_decoder.py` from

```python
if FLAGS.use_discounted_rewards:
```

to

```python
if FLAGS.use_discounted_rewards or FLAGS.use_intermediate_rewards:
```