
Commit

#625 entropy_reward can automatically detect if enough observations are present
Jan Michelfeit committed Dec 1, 2022
1 parent 88371e1 commit ddd7b2f
Showing 2 changed files with 53 additions and 73 deletions.
62 changes: 31 additions & 31 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -16,27 +16,28 @@
class PebbleRewardPhase(Enum):
"""States representing different behaviors for PebbleStateEntropyReward"""

# Collecting samples so that we have something for entropy calculation
LEARNING_START = auto()
# Entropy based reward
UNSUPERVISED_EXPLORATION = auto()
# Learned reward
POLICY_AND_REWARD_LEARNING = auto()
UNSUPERVISED_EXPLORATION = auto() # Entropy based reward
POLICY_AND_REWARD_LEARNING = auto() # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
"""
Reward function for implementation of the PEBBLE learning algorithm
(https://arxiv.org/pdf/2106.05091.pdf).
The rewards returned by this function go through the three phases
defined in PebbleRewardPhase. To transition between these phases,
unsupervised_exploration_start() and unsupervised_exploration_finish()
need to be called.
The rewards returned by this function go through the three phases:
1. Before enough samples are collected for entropy calculation, the
underlying function is returned. This shouldn't matter because
OffPolicyAlgorithms have an initialization period for `learning_starts`
timesteps.
2. During the unsupervised exploration phase, entropy based reward is returned
3. After unsupervised exploration phase is finished, the underlying learned
reward is returned.
The second phase (UNSUPERVISED_EXPLORATION) also requires that a buffer
with observations to compare against is supplied with set_replay_buffer()
or on_replay_buffer_initialized().
The second phase requires that a buffer with observations to compare against is
supplied with set_replay_buffer() or on_replay_buffer_initialized().
To transition to the last phase, unsupervised_exploration_finish() needs
to be called.
Args:
learned_reward_fn: The learned reward function used after unsupervised
@@ -51,11 +52,10 @@ def __init__(
learned_reward_fn: RewardFn,
nearest_neighbor_k: int = 5,
):
self.trained_reward_fn = learned_reward_fn
self.learned_reward_fn = learned_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k
# TODO support n_envs > 1
self.entropy_stats = RunningNorm(1)
self.state = PebbleRewardPhase.LEARNING_START
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

# These two need to be set with set_replay_buffer():
self.replay_buffer_view = None
@@ -68,10 +68,6 @@ def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
self.replay_buffer_view = replay_buffer
self.obs_shape = obs_shape

def unsupervised_exploration_start(self):
assert self.state == PebbleRewardPhase.LEARNING_START
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

def unsupervised_exploration_finish(self):
assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING
@@ -84,26 +80,30 @@ def __call__(
done: np.ndarray,
) -> np.ndarray:
if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
return self._entropy_reward(state)
return self._entropy_reward(state, action, next_state, done)
else:
return self.trained_reward_fn(state, action, next_state, done)
return self.learned_reward_fn(state, action, next_state, done)

def _entropy_reward(self, state):
def _entropy_reward(self, state, action, next_state, done):
if self.replay_buffer_view is None:
raise ValueError(
"Replay buffer must be supplied before entropy reward can be used"
)

all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape((-1, *self.obs_shape))
# TODO #625: deal with the conversion back and forth between np and torch
entropies = util.compute_state_entropy(
th.tensor(state),
th.tensor(all_observations),
self.nearest_neighbor_k,
)
normalized_entropies = self.entropy_stats.forward(entropies)

if all_observations.shape[0] < self.nearest_neighbor_k:
# not enough observations to compare to, fall back to the learned function
return self.learned_reward_fn(state, action, next_state, done)
else:
# TODO #625: deal with the conversion back and forth between np and torch
entropies = util.compute_state_entropy(
th.tensor(state),
th.tensor(all_observations),
self.nearest_neighbor_k,
)
normalized_entropies = self.entropy_stats.forward(entropies)
return normalized_entropies.numpy()

def __getstate__(self):
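
For context: the entropy reward that PEBBLE uses during unsupervised exploration is a particle-based estimate that grows with the distance from a state to its k-th nearest neighbor in the replay buffer, which is why the code above needs at least nearest_neighbor_k stored observations before it can produce a value. Below is a minimal sketch of such an estimator, assuming the k-NN formulation from the PEBBLE paper; it is not the actual util.compute_state_entropy implementation, and knn_state_entropy is an illustrative name only.

import torch as th


def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    """Entropy proxy per query state: log-distance to its k-th nearest neighbor.

    obs: (batch, *obs_shape) query states.
    all_obs: (buffer, *obs_shape) observations to compare against.
    """
    queries = obs.reshape(obs.shape[0], -1).float()
    support = all_obs.reshape(all_obs.shape[0], -1).float()
    # Pairwise Euclidean distances between queries and buffer entries: (batch, buffer).
    dists = th.cdist(queries, support)
    # Distance from each query to its k-th nearest neighbor in the buffer.
    knn_dist = th.kthvalue(dists, k, dim=1).values
    # A larger k-NN distance means a sparser region and thus a higher entropy estimate.
    return th.log(knn_dist + 1e-8)
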
64 changes: 22 additions & 42 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -20,51 +20,13 @@
VENVS = 2


def test_pebble_entropy_reward_function_returns_learned_reward_initially():
expected_reward = np.ones(1)
learned_reward_mock = Mock()
learned_reward_mock.return_value = expected_reward
reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE)

# Act
observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
assert reward == expected_reward
learned_reward_mock.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
)


def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
expected_reward = np.ones(1)
learned_reward_mock = Mock()
learned_reward_mock.return_value = expected_reward
reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE)
# move all the way to the last state
reward_fn.unsupervised_exploration_start()
reward_fn.unsupervised_exploration_finish()

# Act
observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
assert reward == expected_reward
learned_reward_mock.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
)


def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
all_observations = rng.random((BUFFER_SIZE, VENVS, *(OBS_SHAPE)))

reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
reward_fn = PebbleStateEntropyReward(Mock(), K)
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
)
reward_fn.unsupervised_exploration_start()

# Act
observations = th.rand((BATCH_SIZE, *(OBS_SHAPE)))
@@ -85,13 +47,12 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
# mock entropy computation so that we can test only stats collection in this test
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
reward_fn = PebbleStateEntropyReward(Mock(), K)
all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None)),
OBS_SHAPE,
)
reward_fn.unsupervised_exploration_start()

dim = 8
shift = 3
@@ -115,12 +76,31 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining()
)


def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
expected_reward = np.ones(1)
learned_reward_mock = Mock()
learned_reward_mock.return_value = expected_reward
reward_fn = PebbleStateEntropyReward(learned_reward_mock)
# move all the way to the last state
reward_fn.unsupervised_exploration_finish()

# Act
observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
assert reward == expected_reward
learned_reward_mock.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
)


def test_pebble_entropy_reward_can_pickle():
all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

obs1 = np.random.rand(VENVS, *OBS_SHAPE)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
reward_fn.set_replay_buffer(replay_buffer, OBS_SHAPE)
reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

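
For context, here is a usage sketch of the reward function after this change, modelled on the tests above. The ReplayBufferView import path, the dummy learned reward, and the array shapes are assumptions for illustration, not part of this commit.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView  # assumed import path

OBS_SHAPE = (4,)
K = 5


def dummy_learned_reward(state, action, next_state, done):
    # Stand-in for a learned reward model.
    return np.zeros(len(state))


reward_fn = PebbleStateEntropyReward(dummy_learned_reward, nearest_neighbor_k=K)

# Buffer shaped (buffer_size, n_envs, *obs_shape), as in the tests above.
buffer_obs = np.random.rand(100, 1, *OBS_SHAPE)
reward_fn.set_replay_buffer(ReplayBufferView(buffer_obs, lambda: slice(None)), OBS_SHAPE)

batch = np.random.rand(8, *OBS_SHAPE)
placeholder = np.empty(0)

# Unsupervised exploration phase: entropy-based reward, falling back to the learned
# reward automatically while the buffer holds fewer than K observations.
entropy_rewards = reward_fn(batch, placeholder, placeholder, placeholder)

# After unsupervised exploration finishes, the learned reward is used from then on.
reward_fn.unsupervised_exploration_finish()
learned_rewards = reward_fn(batch, placeholder, placeholder, placeholder)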
