#625 introduce ReplayBufferAwareRewardFn
Jan Michelfeit committed Dec 1, 2022
1 parent d1aae17 commit 3d7cfca
Showing 5 changed files with 35 additions and 19 deletions.
17 changes: 13 additions & 4 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -1,15 +1,20 @@
from typing import Tuple

import numpy as np
import torch as th
from gym.vector.utils import spaces
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.rewards.reward_function import RewardFn
from imitation.policies.replay_buffer_wrapper import (
ReplayBufferView,
ReplayBufferRewardWrapper,
)
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


class StateEntropyReward(RewardFn):
class StateEntropyReward(ReplayBufferAwareRewardFn):
def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
self.nearest_neighbor_k = nearest_neighbor_k
# TODO support n_envs > 1
@@ -20,8 +25,12 @@ def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
)

def set_replay_buffer(self, replay_buffer: ReplayBufferView):
def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
self.replay_buffer_view = replay_buffer
self.obs_shape = obs_shape

def __call__(
self,
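Note: the updated tests later in this diff call the new two-argument set_replay_buffer directly. Below is a minimal sketch of that wiring; the observation space, buffer size, and nearest_neighbor_k value are illustrative rather than taken from this commit.

import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

space = spaces.Box(low=0.0, high=1.0, shape=(4,))  # illustrative observation space
obs_shape = get_obs_shape(space)
# Observations are stored as (buffer_size, n_envs, *obs_shape), as in the tests.
all_observations = np.random.rand(1000, 1, *obs_shape)

reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=space)
# New in this commit: the obs_shape tuple is passed alongside the ReplayBufferView.
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)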
13 changes: 6 additions & 7 deletions src/imitation/policies/replay_buffer_wrapper.py
@@ -7,7 +7,7 @@
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples

from imitation.rewards.reward_function import RewardFn
from imitation.rewards.reward_function import RewardFn, ReplayBufferAwareRewardFn
from imitation.util import util


@@ -37,13 +37,13 @@ def __init__(
observations_buffer: np.ndarray,
buffer_slice_provider: Callable[[], slice],
):
self._observations_buffer = observations_buffer.view()
self._observations_buffer.flags.writeable = False
self._observations_buffer_view = observations_buffer.view()
self._observations_buffer_view.flags.writeable = False
self._buffer_slice_provider = buffer_slice_provider

@property
def observations(self):
return self._observations_buffer[self._buffer_slice_provider()]
return self._observations_buffer_view[self._buffer_slice_provider()]


class ReplayBufferRewardWrapper(ReplayBuffer):
@@ -57,7 +57,6 @@ def __init__(
*,
replay_buffer_class: Type[ReplayBuffer],
reward_fn: RewardFn,
on_initialized_callback: Callable[["ReplayBufferRewardWrapper"], None] = None,
**kwargs,
):
"""Builds ReplayBufferRewardWrapper.
@@ -88,8 +87,8 @@ def __init__(
self.reward_fn = reward_fn
_base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
if on_initialized_callback is not None:
on_initialized_callback(self)
if isinstance(reward_fn, ReplayBufferAwareRewardFn):
reward_fn.on_replay_buffer_initialized(self)

# TODO(juan) remove the type ignore once the merged PR
# https://github.com/python/mypy/pull/13475
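Note: with the on_initialized_callback parameter removed, the wrapper now detects reward functions implementing ReplayBufferAwareRewardFn and notifies them itself at the end of __init__. A minimal sketch of the resulting usage follows; the spaces and sizes are illustrative, and the constructor call mirrors the updated test at the bottom of this diff.

from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper

obs_space = spaces.Box(low=0.0, high=1.0, shape=(4,))  # illustrative spaces and sizes
act_space = spaces.Discrete(2)
reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=obs_space)

buffer = ReplayBufferRewardWrapper(
    1000,
    obs_space,
    act_space,
    replay_buffer_class=ReplayBuffer,
    reward_fn=reward_fn,
    n_envs=1,
    handle_timeout_termination=False,
)
# Because StateEntropyReward is a ReplayBufferAwareRewardFn, the wrapper has
# already called reward_fn.on_replay_buffer_initialized(buffer); no manual
# set_replay_buffer call is needed before the reward function is used.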
6 changes: 6 additions & 0 deletions src/imitation/rewards/reward_function.py
@@ -32,3 +32,9 @@ def __call__(
Returns:
Computed rewards of shape `(batch_size,)`.
""" # noqa: DAR202


class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
@abc.abstractmethod
def on_replay_buffer_initialized(self, replay_buffer: "ReplayBufferRewardWrapper"):
pass
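Note: a reward function opts into the new hook by subclassing ReplayBufferAwareRewardFn and implementing on_replay_buffer_initialized. The class below is a hypothetical example written for illustration (it is not part of this commit); it only shows the shape of the contract.

import numpy as np

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn


class ZeroReward(ReplayBufferAwareRewardFn):
    """Hypothetical reward function, for illustration only."""

    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
        # Called by ReplayBufferRewardWrapper at the end of its __init__;
        # capture whatever buffer state the reward needs here.
        self.buffer_view = replay_buffer.buffer_view
        self.obs_shape = replay_buffer.obs_shape

    def __call__(self, state, action, next_state, done):
        return np.zeros(len(state))  # placeholder rewards of shape (batch_size,)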
8 changes: 5 additions & 3 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -23,8 +23,9 @@ def test_state_entropy_reward_returns_entropy(rng):
obs_shape = get_obs_shape(SPACE)
all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))


reward_fn = StateEntropyReward(K, SPACE)
reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))
reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)

# Act
observations = rng.random((BATCH_SIZE, *obs_shape))
@@ -48,7 +49,8 @@ def test_state_entropy_reward_returns_normalized_values():
reward_fn = StateEntropyReward(K, SPACE)
all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None))
ReplayBufferView(all_observations, lambda: slice(None)),
get_obs_shape(SPACE)
)

dim = 8
@@ -79,7 +81,7 @@ def test_state_entropy_reward_can_pickle():

obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
reward_fn = StateEntropyReward(K, SPACE)
reward_fn.set_replay_buffer(replay_buffer)
reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

# Act
10 changes: 5 additions & 5 deletions tests/policies/test_replay_buffer_wrapper.py
@@ -17,6 +17,7 @@
from stable_baselines3.common.save_util import load_from_pkl

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
from imitation.util import util


@@ -175,16 +176,15 @@ def test_replay_buffer_view_provides_buffered_observations():
np.testing.assert_allclose(view.observations, expected)


def test_replay_buffer_reward_wrapper_calls_initialization_callback_with_itself():
callback = Mock()
def test_replay_buffer_reward_wrapper_calls_reward_initialization_callback():
reward_fn = Mock(spec=ReplayBufferAwareRewardFn)
buffer = ReplayBufferRewardWrapper(
10,
spaces.Discrete(2),
spaces.Discrete(2),
replay_buffer_class=ReplayBuffer,
reward_fn=Mock(),
reward_fn=reward_fn,
n_envs=2,
handle_timeout_termination=False,
on_initialized_callback=callback,
)
assert callback.call_args.args[0] is buffer
assert reward_fn.on_replay_buffer_initialized.call_args.args[0] is buffer
