Skip to content

Commit

Permalink
#625 fix entropy_reward.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jan Michelfeit committed Dec 1, 2022
1 parent ad29c34 commit ec7b853
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/imitation/algorithms/pebble/entropy_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ def __call__(

  all_observations = self.replay_buffer_view.observations
  # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
- all_observations = all_observations.reshape((-1, *self.obs_shape))
+ all_observations = all_observations.reshape((-1, *state.shape[1:])) # TODO #625: fix self.obs_shape
+ # TODO #625: deal with the conversion back and forth between np and torch
  entropies = util.compute_state_entropy(
- state,
- all_observations,
+ th.tensor(state),
+ th.tensor(all_observations),
  self.nearest_neighbor_k,
  )
- normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
+ normalized_entropies = self.entropy_stats.forward(entropies)
  return normalized_entropies.numpy()

def __getstate__(self):
Expand Down

0 comments on commit ec7b853

Please sign in to comment.