diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py
index 074281e90..eba53405b 100644
--- a/src/imitation/algorithms/pebble/entropy_reward.py
+++ b/src/imitation/algorithms/pebble/entropy_reward.py
@@ -1,7 +1,7 @@
 """Reward function for the PEBBLE training algorithm."""
 import enum
-from typing import Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import gym
 import numpy as np
@@ -18,10 +18,16 @@
 
 
 class InsufficientObservations(RuntimeError):
+    """Error signifying not enough observations for entropy calculation."""
+
     pass
 
 
 class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
+    """RewardNet wrapping the entropy reward function."""
+
+    __call__: Callable[..., Any]  # Needed to appease pytype
+
     def __init__(
         self,
         nearest_neighbor_k: int,
@@ -53,6 +59,9 @@ def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper)
 
         This method needs to be called, e.g., after unpickling.
        See also __getstate__() / __setstate__().
+
+        Args:
+            replay_buffer: replay buffer with the history of observations.
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
@@ -72,16 +81,18 @@ def forward(
         all_observations = self._replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
         all_observations = all_observations.reshape(
-            (-1,) + self.observation_space.shape
+            (-1,) + self.observation_space.shape,
         )
 
         if all_observations.shape[0] < self.nearest_neighbor_k:
             raise InsufficientObservations(
-                "Insufficient observations for entropy calculation"
+                "Insufficient observations for entropy calculation",
             )
 
         return util.compute_state_entropy(
-            state, all_observations, self.nearest_neighbor_k
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
         )
 
     def preprocess(
@@ -95,6 +106,15 @@ def preprocess(
 
         We also know forward() only works with state, so no need to convert
         other tensors.
+
+        Args:
+            state: The observation input.
+            action: The action input.
+            next_state: The observation of the next state.
+            done: Whether the episode has terminated.
+
+        Returns:
+            Observations preprocessed by converting them to Tensors.
         """
         state_th = util.safe_to_tensor(state).to(self.device)
         action_th = next_state_th = done_th = th.empty(0)
@@ -172,8 +192,8 @@ def __call__(
             try:
                 return self.entropy_reward_fn(state, action, next_state, done)
             except InsufficientObservations:
-                # not enough observations to compare to, fall back to the learned function;
-                # (falling back to a constant may also be ok)
+                # not enough observations to compare to, fall back to the learned
+                # function; (falling back to a constant may also be ok)
                 return self.learned_reward_fn(state, action, next_state, done)
         else:
             return self.learned_reward_fn(state, action, next_state, done)
diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py
index 72f5da5cf..fccd7958d 100644
--- a/src/imitation/algorithms/preference_comparisons.py
+++ b/src/imitation/algorithms/preference_comparisons.py
@@ -96,13 +96,17 @@ def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
         """Pre-train an agent before collecting comparisons.
 
         Override this behavior in subclasses that implement pre-training.
-        If not overriden, this method raises ValueError when non-zero steps are
+        If not overridden, this method raises ValueError when non-zero steps are
         allocated for pre-training.
 
         Args:
             steps: number of environment steps to train for.
             **kwargs: additional keyword arguments to pass on to the training
                 procedure.
+
+        Raises:
+            ValueError: Unsupervised pre-training not implemented but non-zero
+                steps are allocated for pre-training.
         """
         if steps > 0:
             raise ValueError(
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index 524734713..5e07b094c 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -7,6 +7,7 @@
 import pathlib
 from typing import Any, Mapping, Optional, Type, Union
 
+import gym
 import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
@@ -24,6 +25,7 @@
     ReplayBufferRewardWrapper,
 )
 from imitation.rewards import reward_function, reward_nets
+from imitation.rewards.reward_function import RewardFn
 from imitation.rewards.reward_nets import NormalizedRewardNet
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
@@ -80,21 +82,22 @@ def make_reward_function(
             reward_net.predict_processed,
             update_stats=False,
         )
-    observation_space = reward_net.observation_space
-    action_space = reward_net.action_space
     if pebble_enabled:
         relabel_reward_fn = create_pebble_reward_fn(
-            relabel_reward_fn,
+            relabel_reward_fn,  # type: ignore[assignment]
             pebble_nearest_neighbor_k,
-            action_space,
-            observation_space,
+            reward_net.action_space,
+            reward_net.observation_space,
         )
     return relabel_reward_fn
 
 
 def create_pebble_reward_fn(
-    relabel_reward_fn, pebble_nearest_neighbor_k, action_space, observation_space
-):
+    relabel_reward_fn: RewardFn,
+    pebble_nearest_neighbor_k: int,
+    action_space: gym.Space,
+    observation_space: gym.Space,
+) -> PebbleStateEntropyReward:
     entropy_reward_net = EntropyRewardNet(
         nearest_neighbor_k=pebble_nearest_neighbor_k,
         observation_space=observation_space,
@@ -111,13 +114,14 @@ def __call__(self, *args, **kwargs) -> np.ndarray:
             return normalized_entropy_reward_net.predict_processed(*args, **kwargs)
 
         def on_replay_buffer_initialized(
-            self, replay_buffer: ReplayBufferRewardWrapper
+            self,
+            replay_buffer: ReplayBufferRewardWrapper,
         ):
             entropy_reward_net.on_replay_buffer_initialized(replay_buffer)
 
     return PebbleStateEntropyReward(
         EntropyRewardFn(),
-        relabel_reward_fn,  # type: ignore[assignment]
+        relabel_reward_fn,
     )
diff --git a/tests/algorithms/pebble/test_entropy_reward.py b/tests/algorithms/pebble/test_entropy_reward.py
index 833a9ba94..b598ac75e 100644
--- a/tests/algorithms/pebble/test_entropy_reward.py
+++ b/tests/algorithms/pebble/test_entropy_reward.py
@@ -40,7 +40,10 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining():
 
     np.testing.assert_allclose(reward, expected_result)
     entropy_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )
 
 
@@ -57,7 +60,10 @@ def test_pebble_entropy_reward_returns_learned_rew_on_insufficient_observations(
 
     np.testing.assert_allclose(reward, expected_result)
     learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )
 
 
@@ -74,7 +80,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
 
     np.testing.assert_allclose(reward, expected_result)
     learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )
 
 
@@ -97,7 +106,10 @@ def test_entropy_reward_net_returns_entropy_for_pretraining(rng):
 
     # Act
     reward = reward_net.predict_processed(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )
 
     # Assert
@@ -118,7 +130,10 @@ def test_entropy_reward_net_raises_on_insufficient_observations(rng):
     # Act
     with pytest.raises(InsufficientObservations):
         reward_net.predict_processed(
-            observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+            observations,
+            PLACEHOLDER,
+            PLACEHOLDER,
+            PLACEHOLDER,
         )
diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py
index f31fdceb8..d863cc4b0 100644
--- a/tests/algorithms/test_preference_comparisons.py
+++ b/tests/algorithms/test_preference_comparisons.py
@@ -18,12 +18,12 @@
 
 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
-from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
+from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
 from imitation.util import networks, util
 
 UNCERTAINTY_ON = ["logit", "probability", "label"]
@@ -84,9 +84,13 @@ def replay_buffer(rng):
 def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
     replay_buffer_mock = Mock()
     replay_buffer_mock.buffer_view = replay_buffer
-    replay_buffer_mock.obs_shape = (4,)
-    reward_fn = PebbleStateEntropyReward(
-        reward_net.predict_processed, venv.observation_space, venv.action_space
+    replay_buffer_mock.observation_space = venv.observation_space
+    replay_buffer_mock.action_space = venv.action_space
+    reward_fn = create_pebble_reward_fn(
+        reward_net.predict_processed,
+        5,
+        venv.action_space,
+        venv.observation_space,
     )
     reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
     return preference_comparisons.PebbleAgentTrainer(
diff --git a/tests/scripts/test_train_preference_comparisons.py b/tests/scripts/test_train_preference_comparisons.py
index d05ebd27a..c4390dd6b 100644
--- a/tests/scripts/test_train_preference_comparisons.py
+++ b/tests/scripts/test_train_preference_comparisons.py
@@ -1,3 +1,5 @@
+"""Tests train_preference_comparisons helper methods."""
+
 from unittest.mock import Mock, patch
 
 import numpy as np
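
For reference, a minimal usage sketch of the new create_pebble_reward_fn helper introduced by this patch, mirroring the updated pebble_agent_trainer fixture above. The names reward_net, venv, and replay_buffer_wrapper are assumed to already exist (a RewardNet, a vectorized environment, and the RL algorithm's ReplayBufferRewardWrapper); they are illustrative placeholders, not part of the patch.

    from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn

    # PEBBLE reward: state entropy during unsupervised pre-training, falling back
    # to the learned reward when the replay buffer holds fewer observations than
    # nearest_neighbor_k, and using the learned reward after pre-training.
    pebble_reward_fn = create_pebble_reward_fn(
        reward_net.predict_processed,  # learned reward function (assumed)
        5,                             # pebble_nearest_neighbor_k
        venv.action_space,             # venv assumed to exist
        venv.observation_space,
    )

    # The entropy term needs access to the replay buffer; call this once the
    # buffer exists, e.g. after the RL algorithm is constructed, as the tests do.
    pebble_reward_fn.on_replay_buffer_initialized(replay_buffer_wrapper)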