diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py
new file mode 100644
index 000000000..724fbf314
--- /dev/null
+++ b/src/imitation/algorithms/pebble/entropy_reward.py
@@ -0,0 +1,46 @@
+import numpy as np
+import torch as th
+from gym.vector.utils import spaces
+from stable_baselines3.common.preprocessing import get_obs_shape
+
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
+from imitation.rewards.reward_function import RewardFn
+from imitation.util import util
+from imitation.util.networks import RunningNorm
+
+
+class StateEntropyReward(RewardFn):
+    """Reward function based on the k-NN state entropy of observations."""
+
+    def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
+        self.nearest_neighbor_k = nearest_neighbor_k
+        # TODO support n_envs > 1
+        self.entropy_stats = RunningNorm(1)
+        self.obs_shape = get_obs_shape(observation_space)
+        self.replay_buffer_view = ReplayBufferView(
+            np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
+        )
+
+    def set_buffer_view(self, replay_buffer_view: ReplayBufferView):
+        self.replay_buffer_view = replay_buffer_view
+
+    def __call__(
+        self,
+        state: np.ndarray,
+        action: np.ndarray,
+        next_state: np.ndarray,
+        done: np.ndarray,
+    ) -> np.ndarray:
+        # TODO: should this work with torch instead of numpy internally?
+        # (The RewardFn protocol requires numpy)
+
+        all_observations = self.replay_buffer_view.observations
+        # ReplayBuffer sampling flattens the venv dimension, so adapt to that here
+        all_observations = all_observations.reshape((-1, *self.obs_shape))
+        entropies = util.compute_state_entropy(
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
+        )
+        normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
+        return normalized_entropies.numpy()
diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py
index 60dcf24b6..b032704fc 100644
--- a/src/imitation/policies/replay_buffer_wrapper.py
+++ b/src/imitation/policies/replay_buffer_wrapper.py
@@ -24,6 +24,29 @@ def _samples_to_reward_fn_input(
     )
 
 
+class ReplayBufferView:
+    """A read-only view over the valid records in a ReplayBuffer.
+
+    Args:
+        observations_buffer: Array buffer holding observations.
+        buffer_slice_provider: Function returning a slice of the buffer
+            with valid observations.
+    """
+
+    def __init__(
+        self,
+        observations_buffer: np.ndarray,
+        buffer_slice_provider: Callable[[], slice],
+    ):
+        self._observations_buffer = observations_buffer.view()
+        self._observations_buffer.flags.writeable = False
+        self._buffer_slice_provider = buffer_slice_provider
+
+    @property
+    def observations(self):
+        return self._observations_buffer[self._buffer_slice_provider()]
+
+
 class ReplayBufferRewardWrapper(ReplayBuffer):
     """Relabel the rewards in transitions sampled from a ReplayBuffer."""
 
@@ -83,6 +106,13 @@ def full(self) -> bool:  # type: ignore[override]
     def full(self, full: bool):
         self.replay_buffer.full = full
 
+    @property
+    def buffer_view(self) -> ReplayBufferView:
+        def valid_buffer_slice():
+            return slice(None) if self.full else slice(self.pos)
+
+        return ReplayBufferView(self.replay_buffer.observations, valid_buffer_slice)
+
     def sample(self, *args, **kwargs):
         samples = self.replay_buffer.sample(*args, **kwargs)
         rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))
@@ -171,7 +201,7 @@ def sample(self, *args, **kwargs):
         all_obs = all_obs.reshape((-1, *self.obs_shape))
         entropies = util.compute_state_entropy(
             samples.observations,
-            all_obs.reshape((-1, *self.obs_shape)),
+            all_obs,
             self.k,
         )
 
diff --git a/src/imitation/util/networks.py b/src/imitation/util/networks.py
index 0517f14bb..6e97db1cf 100644
--- a/src/imitation/util/networks.py
+++ b/src/imitation/util/networks.py
@@ -86,6 +86,9 @@ def forward(self, x: th.Tensor) -> th.Tensor:
         with th.no_grad():
             self.update_stats(x)
 
+        return self.normalize(x)
+
+    def normalize(self, x: th.Tensor) -> th.Tensor:
         return (x - self.running_mean) / th.sqrt(self.running_var + self.eps)
 
     @abc.abstractmethod
diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py
index df8eb6a6a..d88f775cd 100644
--- a/src/imitation/util/util.py
+++ b/src/imitation/util/util.py
@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
 
 
 def compute_state_entropy(
-    obs: th.Tensor,
-    all_obs: th.Tensor,
+    obs: np.ndarray,
+    all_obs: np.ndarray,
     k: int,
-) -> th.Tensor:
+) -> np.ndarray:
     """Compute the state entropy given by KNN distance.
 
     Args:
@@ -379,14 +379,20 @@ def compute_state_entropy(
     assert obs.shape[1:] == all_obs.shape[1:]
     with th.no_grad():
         non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
-        distances_tensor = th.linalg.vector_norm(
+        distances_tensor = np.linalg.norm(
            obs[:, None] - all_obs[None, :],
-            dim=non_batch_dimensions,
+            axis=non_batch_dimensions,
             ord=2,
         )
 
         # Note that we take the k+1'th value because the closest neighbor to
         # a point is itself, which we want to skip.
-        knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
+        knn_dists = kth_value(distances_tensor, k + 1)
         state_entropy = knn_dists
-        return state_entropy.unsqueeze(1)
+        return np.expand_dims(state_entropy, axis=1)
+
+
+def kth_value(x: np.ndarray, k: int) -> np.ndarray:
+    """Return the k-th smallest value along the last axis (k is 1-indexed)."""
+    assert k > 0
+    return np.partition(x, k - 1, axis=-1)[..., k - 1]
diff --git a/tests/algorithms/pebble/test_entropy_reward.py b/tests/algorithms/pebble/test_entropy_reward.py
new file mode 100644
index 000000000..777a9b9d6
--- /dev/null
+++ b/tests/algorithms/pebble/test_entropy_reward.py
@@ -0,0 +1,70 @@
+from unittest.mock import patch
+
+import numpy as np
+import torch as th
+from gym.spaces import Discrete
+from stable_baselines3.common.preprocessing import get_obs_shape
+
+from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
+from imitation.util import util
+
+SPACE = Discrete(4)
+PLACEHOLDER = np.empty(get_obs_shape(SPACE))
+
+BUFFER_SIZE = 20
+K = 4
+BATCH_SIZE = 8
+VENVS = 2
+
+
+def test_state_entropy_reward_returns_entropy(rng):
+    obs_shape = get_obs_shape(SPACE)
+    all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
+
+    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))
+
+    # Act
+    observations = rng.random((BATCH_SIZE, *obs_shape))
+    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Assert
+    expected = util.compute_state_entropy(
+        observations, all_observations.reshape(-1, *obs_shape), K
+    )
+    expected_normalized = reward_fn.entropy_stats.normalize(th.as_tensor(expected)).numpy()
+    np.testing.assert_allclose(reward, expected_normalized)
+
+
+def test_state_entropy_reward_returns_normalized_values():
+    with patch("imitation.util.util.compute_state_entropy") as m:
+        # mock entropy computation so that we can test only stats collection in this test
+        m.side_effect = lambda obs, all_obs, k: obs
+
+        reward_fn = StateEntropyReward(K, SPACE)
+        all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
+        reward_fn.set_buffer_view(
+            ReplayBufferView(all_observations, lambda: slice(None))
+        )
+
+        dim = 8
+        shift = 3
+        scale = 2
+
+        # Act
+        for _ in range(1000):
+            state = th.randn(dim) * scale + shift
+            reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+        normalized_reward = reward_fn(
+            np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        )
+
+        # Assert
+        np.testing.assert_allclose(
+            normalized_reward,
+            np.repeat(-shift / scale, dim),
+            rtol=0.05,
+            atol=0.05,
+        )
diff --git a/tests/policies/test_replay_buffer_wrapper.py b/tests/policies/test_replay_buffer_wrapper.py
index 5d06139aa..668208b58 100644
--- a/tests/policies/test_replay_buffer_wrapper.py
+++ b/tests/policies/test_replay_buffer_wrapper.py
@@ -2,6 +2,7 @@
 
 import os.path as osp
 from typing import Type
+from unittest.mock import Mock
 
 import gym
 import numpy as np
@@ -10,7 +11,9 @@
 import torch as th
 from gym import spaces
 from stable_baselines3.common import buffers, off_policy_algorithm, policies
+from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
 from stable_baselines3.common.save_util import load_from_pkl
 from stable_baselines3.common.vec_env import DummyVecEnv
 
@@ -225,3 +228,39 @@ def test_entropy_wrapper_class(tmpdir, rng):
         k=k,
     )
     assert trained_entropy.mean() > initial_entropy.mean()
+
+
+def test_replay_buffer_view_provides_buffered_observations():
+    space = spaces.Box(np.array([0]), np.array([5]))
+    n_envs = 2
+    buffer_size = 10
+    action = np.empty((n_envs, get_action_dim(space)))
+
+    obs_shape = get_obs_shape(space)
+    wrapper = ReplayBufferRewardWrapper(
+        buffer_size,
+        space,
+        space,
+        replay_buffer_class=ReplayBuffer,
+        reward_fn=Mock(),
+        n_envs=n_envs,
+        handle_timeout_termination=False,
+    )
+    view = wrapper.buffer_view
+
+    # initially empty
+    assert len(view.observations) == 0
+
+    # after adding observation
+    obs1 = np.random.random((n_envs, *obs_shape))
+    wrapper.add(obs1, obs1, action, np.empty(n_envs), np.empty(n_envs), [])
+    np.testing.assert_allclose(view.observations, np.array([obs1]))
+
+    # after filling buffer
+    observations = np.random.random((buffer_size // n_envs, n_envs, *obs_shape))
+    for obs in observations:
+        wrapper.add(obs, obs, action, np.empty(n_envs), np.empty(n_envs), [])
+
+    # ReplayBuffer internally uses a circular buffer
+    expected = np.roll(observations, 1, axis=0)
+    np.testing.assert_allclose(view.observations, expected)
diff --git a/tests/util/test_util.py b/tests/util/test_util.py
index 28678dc8b..be2487aee 100644
--- a/tests/util/test_util.py
+++ b/tests/util/test_util.py
@@ -11,6 +11,7 @@
 
 from imitation.util import sacred as sacred_util
 from imitation.util import util
+from imitation.util.util import kth_value
 
 
 def test_endless_iter():
@@ -144,3 +145,14 @@ def test_compute_state_entropy_2d():
         util.compute_state_entropy(obs, all_obs, k=3),
         np.sqrt(20**2 + 2**2),
     )
+
+
+def test_kth_value():
+    arr1 = np.arange(0, 10, 1)
+    np.random.shuffle(arr1)
+    arr2 = np.arange(0, 100, 10)
+    np.random.shuffle(arr2)
+    arr = np.stack([arr1, arr2])
+
+    result = kth_value(arr, 3)
+    np.testing.assert_array_equal(result, np.array([2, 20]))
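
The tests above exercise the new pieces in isolation. As a rough usage sketch (not part of this diff), the intended composition is: construct a StateEntropyReward, hand it the read-only buffer_view of a ReplayBufferRewardWrapper, and let it return normalized k-NN state entropies for whatever the buffer currently holds. The Box space, buffer size, n_envs, and nearest_neighbor_k=1 below are illustrative assumptions mirroring the tests, not values prescribed by the change.

import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper

# Illustrative values: a 1-D Box space, a small buffer, and k=1 so a couple of
# stored observations are enough for the k-NN lookup.
obs_space = spaces.Box(np.array([0.0]), np.array([5.0]))
n_envs = 2

entropy_reward = StateEntropyReward(nearest_neighbor_k=1, observation_space=obs_space)
wrapper = ReplayBufferRewardWrapper(
    100,
    obs_space,
    obs_space,  # action space reused for brevity, as in the wrapper test above
    replay_buffer_class=ReplayBuffer,
    reward_fn=entropy_reward,
    n_envs=n_envs,
    handle_timeout_termination=False,
)
# The reward function keeps only a read-only view; it sees whatever the buffer
# holds at call time, without copying observations.
entropy_reward.set_buffer_view(wrapper.buffer_view)

obs = np.zeros((n_envs, 1), dtype=np.float32)
act = np.zeros((n_envs, 1), dtype=np.float32)
wrapper.add(obs, obs, act, np.zeros(n_envs), np.zeros(n_envs), [])

# Normalized k-NN state entropies for a batch of states, shape (n_envs, 1).
rewards = entropy_reward(obs, act, obs, np.zeros(n_envs))

The expectation of roughly -shift / scale in test_state_entropy_reward_returns_normalized_values follows from the same wiring: over many calls the RunningNorm statistics converge to mean ≈ shift and std ≈ scale, so a subsequent all-zero input normalizes to (0 - shift) / scale.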