#625 PebbleStateEntropyReward can switch from unsupervised pretraining
Jan Michelfeit committed Dec 1, 2022
1 parent d348534 commit 9090b0c
Showing 2 changed files with 63 additions and 14 deletions.
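
For context, a minimal usage sketch (not part of this commit) of the switching behavior introduced below: the reward function starts in the unsupervised-exploration phase and returns a state-entropy reward, and after on_unsupervised_exploration_finished() it delegates to the wrapped reward function. The learned_reward stub and the Box observation space are illustrative placeholders, and the gym-style spaces import is an assumption about the surrounding setup.

import numpy as np
from gym import spaces  # assumed: imitation used gym-style spaces at this time

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward


def learned_reward(state, action, next_state, done):
    # Placeholder for a reward model trained from preferences (not part of this commit).
    return np.zeros(len(state))


observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
reward_fn = PebbleStateEntropyReward(
    trained_reward_fn=learned_reward,
    observation_space=observation_space,
    nearest_neighbor_k=5,
)

# During unsupervised pretraining the reward is the state entropy estimated
# against the replay buffer (wired up via on_replay_buffer_initialized).
# Once pretraining ends, switch to the trained reward function:
reward_fn.on_unsupervised_exploration_finished()

obs = np.zeros((8, 4), dtype=np.float32)
acts = np.zeros((8, 2), dtype=np.float32)
dones = np.zeros(8, dtype=bool)
rewards = reward_fn(obs, acts, obs, dones)  # now delegates to learned_reward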
26 changes: 22 additions & 4 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -9,14 +9,21 @@
ReplayBufferView,
ReplayBufferRewardWrapper,
)
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
# TODO #625: get rid of the observation_space parameter
def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
# TODO #625: parametrize nearest_neighbor_k
def __init__(
self,
trained_reward_fn: RewardFn,
observation_space: spaces.Space,
nearest_neighbor_k: int = 5,
):
self.trained_reward_fn = trained_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k
# TODO support n_envs > 1
self.entropy_stats = RunningNorm(1)
@@ -25,24 +32,35 @@ def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
self.replay_buffer_view = ReplayBufferView(
np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
)
# This indicates that the training is in the "Unsupervised exploration"
# phase of the Pebble algorithm, where entropy is used as reward
self.unsupervised_exploration_active = True

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape:Tuple):
def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
self.replay_buffer_view = replay_buffer
self.obs_shape = obs_shape

def on_unsupervised_exploration_finished(self):
self.unsupervised_exploration_active = False

def __call__(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> np.ndarray:
if self.unsupervised_exploration_active:
return self._entropy_reward(state)
else:
return self.trained_reward_fn(state, action, next_state, done)

def _entropy_reward(self, state):
# TODO: should this work with torch instead of numpy internally?
# (The RewardFn protocol requires numpy)

all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape(
51 changes: 41 additions & 10 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -1,5 +1,5 @@
import pickle
from unittest.mock import patch
from unittest.mock import patch, Mock

import numpy as np
import torch as th
@@ -19,13 +19,14 @@
VENVS = 2


def test_state_entropy_reward_returns_entropy(rng):
def test_pebble_entropy_reward_returns_entropy(rng):
obs_shape = get_obs_shape(SPACE)
all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))


reward_fn = PebbleStateEntropyReward(K, SPACE)
reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None)), obs_shape
)

# Act
observations = rng.random((BATCH_SIZE, *obs_shape))
@@ -41,16 +42,16 @@ def test_state_entropy_reward_returns_entropy(rng):
np.testing.assert_allclose(reward, expected_normalized)


def test_state_entropy_reward_returns_normalized_values():
def test_pebble_entropy_reward_returns_normalized_values():
with patch("imitation.util.util.compute_state_entropy") as m:
# mock entropy computation so that we can test only stats collection in this test
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = PebbleStateEntropyReward(K, SPACE)
reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None)),
get_obs_shape(SPACE)
get_obs_shape(SPACE),
)

dim = 8
@@ -75,12 +76,12 @@ def test_state_entropy_reward_returns_normalized_values():
)


def test_state_entropy_reward_can_pickle():
def test_pebble_entropy_reward_can_pickle():
all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
reward_fn = PebbleStateEntropyReward(K, SPACE)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

@@ -94,3 +95,33 @@ def test_state_entropy_reward_can_pickle():
expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
np.testing.assert_allclose(actual_result, expected_result)


def test_pebble_entropy_reward_function_switches_to_inner():
obs_shape = get_obs_shape(SPACE)

expected_reward = np.ones(1)
reward_fn_mock = Mock()
reward_fn_mock.return_value = expected_reward
reward_fn = PebbleStateEntropyReward(reward_fn_mock, SPACE)

# Act
reward_fn.on_unsupervised_exploration_finished()
observations = np.ones((BATCH_SIZE, *obs_shape))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
assert reward == expected_reward
reward_fn_mock.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
)


def reward_fn_stub(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> np.ndarray:
return state
