From ec7b853cf0c944091fecfeeca79bc55677db232e Mon Sep 17 00:00:00 2001
From: Jan Michelfeit
Date: Thu, 1 Dec 2022 16:26:16 +0100
Subject: [PATCH] #625 fix entropy_reward.py

---
 src/imitation/algorithms/pebble/entropy_reward.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py
index a1fff0e46..01c2f9a9f 100644
--- a/src/imitation/algorithms/pebble/entropy_reward.py
+++ b/src/imitation/algorithms/pebble/entropy_reward.py
@@ -35,13 +35,14 @@ def __call__(
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
-        all_observations = all_observations.reshape((-1, *self.obs_shape))
+        all_observations = all_observations.reshape((-1, *state.shape[1:]))  # TODO #625: fix self.obs_shape
+        # TODO #625: deal with the conversion back and forth between np and torch
         entropies = util.compute_state_entropy(
-            state,
-            all_observations,
+            th.tensor(state),
+            th.tensor(all_observations),
             self.nearest_neighbor_k,
         )
-        normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
+        normalized_entropies = self.entropy_stats.forward(entropies)
         return normalized_entropies.numpy()
 
     def __getstate__(self):
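
Note (illustration, not part of the patch): the second TODO concerns the numpy<->torch
round-trip inside the patched __call__. The sketch below shows that round-trip in a
self-contained form, assuming flat (1-D) observations. knn_state_entropy is a
hypothetical stand-in for util.compute_state_entropy (assumed here to be a k-nearest-
neighbor particle entropy estimate, as in PEBBLE), and the entropy_stats normalization
step is omitted.

import numpy as np
import torch as th


def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    """Illustrative k-NN entropy estimate: log distance to the k-th nearest neighbor.

    Assumes obs and all_obs are 2-D float tensors of shape (batch, obs_dim) and
    (buffer_size, obs_dim); the exact formula in the library may differ.
    """
    # Pairwise Euclidean distances between the batch and the flattened replay buffer.
    dists = th.cdist(obs, all_obs)  # shape: (batch, buffer_size)
    # Distance to the k-th nearest neighbor for each observation in the batch.
    knn_dists = th.kthvalue(dists, k=k, dim=1).values
    return th.log(knn_dists + 1.0)


def entropy_reward(state: np.ndarray, all_observations: np.ndarray, k: int = 5) -> np.ndarray:
    """Hypothetical mirror of the patched __call__ body: numpy in, numpy out."""
    # Flatten the venv dimension the same way ReplayBuffer sampling does.
    all_observations = all_observations.reshape((-1, *state.shape[1:]))
    # Convert to torch only for the entropy computation, then back to numpy.
    entropies = knn_state_entropy(th.tensor(state), th.tensor(all_observations), k)
    return entropies.numpy()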