#625 entropy reward as a function
Jan Michelfeit committed Nov 30, 2022
1 parent 27b8a55 commit 2dec99f
Showing 7 changed files with 211 additions and 8 deletions.
44 changes: 44 additions & 0 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -0,0 +1,44 @@
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.rewards.reward_function import RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


class StateEntropyReward(RewardFn):
    """k-nearest-neighbor state entropy reward used for PEBBLE pre-training."""

def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
self.nearest_neighbor_k = nearest_neighbor_k
# TODO support n_envs > 1
self.entropy_stats = RunningNorm(1)
self.obs_shape = get_obs_shape(observation_space)
self.replay_buffer_view = ReplayBufferView(
np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
)

def set_buffer_view(self, replay_buffer_view: ReplayBufferView):
self.replay_buffer_view = replay_buffer_view

def __call__(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> np.ndarray:
# TODO: should this work with torch instead of numpy internally?
# (The RewardFn protocol requires numpy)

all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, so adapt to that here
all_observations = all_observations.reshape((-1, *self.obs_shape))
entropies = util.compute_state_entropy(
state,
all_observations,
self.nearest_neighbor_k,
)
normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
return normalized_entropies.numpy()
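For context, a minimal usage sketch of the new reward function (the observation space, buffer contents, and batch below are illustrative stand-ins, not part of the commit):

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

obs_space = Box(low=0.0, high=1.0, shape=(3,))
# Stand-in for replay storage with layout (buffer_size, n_envs, *obs_shape).
stored_obs = np.random.random((100, 1, 3)).astype(np.float32)

reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=obs_space)
# During training this view would come from ReplayBufferRewardWrapper.buffer_view.
reward_fn.set_buffer_view(ReplayBufferView(stored_obs, lambda: slice(None)))

batch = np.random.random((8, 3)).astype(np.float32)
placeholder = np.empty(0)  # action, next_state and done are ignored by this reward
rewards = reward_fn(batch, placeholder, placeholder, placeholder)
print(rewards.shape)  # one normalized entropy estimate per observation in the batch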
32 changes: 31 additions & 1 deletion src/imitation/policies/replay_buffer_wrapper.py
@@ -24,6 +24,29 @@ def _samples_to_reward_fn_input(
)


class ReplayBufferView:
"""A read-only view over a valid records in a ReplayBuffer.
Args:
observations_buffer: Array buffer holding observations
buffer_slice_provider: Function returning slice of buffer
with valid observations
"""

def __init__(
self,
observations_buffer: np.ndarray,
buffer_slice_provider: Callable[[], slice],
):
self._observations_buffer = observations_buffer.view()
self._observations_buffer.flags.writeable = False
self._buffer_slice_provider = buffer_slice_provider

@property
def observations(self):
return self._observations_buffer[self._buffer_slice_provider()]


class ReplayBufferRewardWrapper(ReplayBuffer):
"""Relabel the rewards in transitions sampled from a ReplayBuffer."""

@@ -83,6 +106,13 @@ def full(self) -> bool: # type: ignore[override]
def full(self, full: bool):
self.replay_buffer.full = full

@property
def buffer_view(self) -> ReplayBufferView:
    """A read-only view over the valid observations in the wrapped buffer."""
def valid_buffer_slice():
return slice(None) if self.full else slice(self.pos)

return ReplayBufferView(self.replay_buffer.observations, valid_buffer_slice)

def sample(self, *args, **kwargs):
samples = self.replay_buffer.sample(*args, **kwargs)
rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))
@@ -171,7 +201,7 @@ def sample(self, *args, **kwargs):
all_obs = all_obs.reshape((-1, *self.obs_shape))
entropies = util.compute_state_entropy(
samples.observations,
all_obs.reshape((-1, *self.obs_shape)),
all_obs,
self.k,
)

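To make the new buffer_view property concrete, here is a minimal sketch of the slicing rule implemented by valid_buffer_slice above (the array shape is an illustrative assumption):

import numpy as np

observations = np.zeros((10, 2, 3))  # (buffer_size, n_envs, *obs_shape)
pos, full = 4, False

# Before the circular buffer wraps around, only the first `pos` rows are valid;
# once `full` is True, every row holds real data.
valid = slice(None) if full else slice(pos)
print(observations[valid].shape)  # (4, 2, 3)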
3 changes: 3 additions & 0 deletions src/imitation/util/networks.py
@@ -86,6 +86,9 @@ def forward(self, x: th.Tensor) -> th.Tensor:
with th.no_grad():
self.update_stats(x)

return self.normalize(x)

def normalize(self, x: th.Tensor) -> th.Tensor:
    """Normalize x with the current running statistics, without updating them."""
return (x - self.running_mean) / th.sqrt(self.running_var + self.eps)

@abc.abstractmethod
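The split between forward and the new normalize helper is what the entropy reward and its tests rely on: in training mode forward first updates the running statistics and then normalizes, while normalize only applies the statistics collected so far. A rough sketch, assuming the RunningNorm(1) construction used in entropy_reward.py:

import torch as th

from imitation.util.networks import RunningNorm

stats = RunningNorm(1)
batch = th.tensor([[1.0], [3.0]])

y_updating = stats.forward(batch)  # updates running mean/var, then normalizes
y_frozen = stats.normalize(batch)  # reuses the stats gathered so far, no update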
19 changes: 12 additions & 7 deletions src/imitation/util/util.py
@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:


def compute_state_entropy(
obs: th.Tensor,
all_obs: th.Tensor,
obs: np.ndarray,
all_obs: np.ndarray,
k: int,
) -> th.Tensor:
) -> np.ndarray:
"""Compute the state entropy given by KNN distance.
Args:
@@ -379,14 +379,19 @@ def compute_state_entropy(
assert obs.shape[1:] == all_obs.shape[1:]
with th.no_grad():
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
distances_tensor = th.linalg.vector_norm(
distances_tensor = np.linalg.norm(
obs[:, None] - all_obs[None, :],
dim=non_batch_dimensions,
axis=non_batch_dimensions,
ord=2,
)

# Note that we take the k+1'th value because the closest neighbor to
# a point is itself, which we want to skip.
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
knn_dists = kth_value(distances_tensor, k + 1)
state_entropy = knn_dists
return state_entropy.unsqueeze(1)
return np.expand_dims(state_entropy, axis=1)


def kth_value(x: np.ndarray, k: int) -> np.ndarray:
    """Return the k-th smallest value along the last axis (k is 1-indexed)."""
    assert k > 0
    return np.partition(x, k - 1, axis=-1)[..., k - 1]
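Because every observation in all_obs is potentially compared against itself at distance zero, compute_state_entropy asks kth_value for the (k+1)-th smallest distance. A small numeric check of the helper:

import numpy as np

from imitation.util.util import kth_value

# Two query points, four stored points; the zeros are the self-distances.
dists = np.array([[0.0, 2.0, 5.0, 1.0],
                  [0.0, 4.0, 3.0, 6.0]])
print(kth_value(dists, 2))  # [1. 3.] -- the nearest non-self neighbors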
70 changes: 70 additions & 0 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -0,0 +1,70 @@
from unittest.mock import patch

import numpy as np
import torch as th
from gym.spaces import Discrete
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.util import util

SPACE = Discrete(4)
PLACEHOLDER = np.empty(get_obs_shape(SPACE))

BUFFER_SIZE = 20
K = 4
BATCH_SIZE = 8
VENVS = 2


def test_state_entropy_reward_returns_entropy(rng):
obs_shape = get_obs_shape(SPACE)
all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))

reward_fn = StateEntropyReward(K, SPACE)
reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))

# Act
observations = rng.random((BATCH_SIZE, *obs_shape))
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Assert
expected = util.compute_state_entropy(
observations, all_observations.reshape(-1, *obs_shape), K
)
expected_normalized = reward_fn.entropy_stats.normalize(
    th.as_tensor(expected)
).numpy()
np.testing.assert_allclose(reward, expected_normalized)


def test_state_entropy_reward_returns_normalized_values():
with patch("imitation.util.util.compute_state_entropy") as m:
# mock entropy computation so that we can test only stats collection in this test
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = StateEntropyReward(K, SPACE)
all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
reward_fn.set_buffer_view(
ReplayBufferView(all_observations, lambda: slice(None))
)

dim = 8
shift = 3
scale = 2

# Act
for _ in range(1000):
state = th.randn(dim) * scale + shift
reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

normalized_reward = reward_fn(
np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
)

# Assert
np.testing.assert_allclose(
normalized_reward,
np.repeat(-shift / scale, dim),
rtol=0.05,
atol=0.05,
)
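The asserted value follows from the normalization formula with converged statistics: after roughly 1000 batches drawn from a distribution with mean shift and standard deviation scale, entropy_stats holds mean ≈ 3 and variance ≈ 4, so a zero input normalizes to approximately:

shift, scale = 3, 2
print((0 - shift) / scale)  # -1.5, matching np.repeat(-shift / scale, dim) within tolerance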
39 changes: 39 additions & 0 deletions tests/policies/test_replay_buffer_wrapper.py
@@ -2,6 +2,7 @@

import os.path as osp
from typing import Type
from unittest.mock import Mock

import gym
import numpy as np
@@ -10,7 +10,9 @@
import torch as th
from gym import spaces
from stable_baselines3.common import buffers, off_policy_algorithm, policies
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.save_util import load_from_pkl
from stable_baselines3.common.vec_env import DummyVecEnv

Expand Down Expand Up @@ -225,3 +228,39 @@ def test_entropy_wrapper_class(tmpdir, rng):
k=k,
)
assert trained_entropy.mean() > initial_entropy.mean()


def test_replay_buffer_view_provides_buffered_observations():
space = spaces.Box(np.array([0]), np.array([5]))
n_envs = 2
buffer_size = 10
action = np.empty((n_envs, get_action_dim(space)))

obs_shape = get_obs_shape(space)
wrapper = ReplayBufferRewardWrapper(
buffer_size,
space,
space,
replay_buffer_class=ReplayBuffer,
reward_fn=Mock(),
n_envs=n_envs,
handle_timeout_termination=False,
)
view = wrapper.buffer_view

# initially empty
assert len(view.observations) == 0

# after adding observation
obs1 = np.random.random((n_envs, *obs_shape))
wrapper.add(obs1, obs1, action, np.empty(n_envs), np.empty(n_envs), [])
np.testing.assert_allclose(view.observations, np.array([obs1]))

# after filling buffer
observations = np.random.random((buffer_size // n_envs, n_envs, *obs_shape))
for obs in observations:
wrapper.add(obs, obs, action, np.empty(n_envs), np.empty(n_envs), [])

# ReplayBuffer internally uses a circular buffer
expected = np.roll(observations, 1, axis=0)
np.testing.assert_allclose(view.observations, expected)
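The np.roll expectation encodes the circular write pattern of the wrapped buffer: assuming SB3 sizes the buffer as buffer_size // n_envs = 5 entries per environment, the sixth add call (the last row of observations) overwrites obs1 at position 0. A sketch of that bookkeeping with scalar stand-ins for the observations:

import numpy as np

capacity = 5                 # buffer_size // n_envs
writes = [1, 2, 3, 4, 5, 6]  # obs1 followed by the five rows of `observations`
buffer = np.zeros(capacity)
for i, value in enumerate(writes):
    buffer[i % capacity] = value
print(buffer)                # [6. 2. 3. 4. 5.] == np.roll([2, 3, 4, 5, 6], 1)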
12 changes: 12 additions & 0 deletions tests/util/test_util.py
@@ -11,6 +11,7 @@

from imitation.util import sacred as sacred_util
from imitation.util import util
from imitation.util.util import kth_value


def test_endless_iter():
@@ -144,3 +145,14 @@ def test_compute_state_entropy_2d():
util.compute_state_entropy(obs, all_obs, k=3),
np.sqrt(20**2 + 2**2),
)


def test_kth_value():
arr1 = np.arange(0, 10, 1)
np.random.shuffle(arr1)
arr2 = np.arange(0, 100, 10)
np.random.shuffle(arr2)
arr = np.stack([arr1, arr2])

result = kth_value(arr, 3)
np.testing.assert_array_equal(result, np.array([2, 20]))
