From f3decf1185fb827259d878992fffcfb2fe940fa6 Mon Sep 17 00:00:00 2001 From: Jan Michelfeit Date: Thu, 1 Dec 2022 23:04:26 +0100 Subject: [PATCH] #625 fix pre-commit errors --- .../algorithms/pebble/entropy_reward.py | 41 +++++++++++-------- .../algorithms/preference_comparisons.py | 26 ++++++++---- .../policies/replay_buffer_wrapper.py | 20 ++++----- src/imitation/rewards/reward_function.py | 7 +++- src/imitation/scripts/common/rl.py | 3 +- .../config/train_preference_comparisons.py | 2 + .../scripts/train_preference_comparisons.py | 12 +++--- .../algorithms/pebble/test_entropy_reward.py | 29 +++++++++---- 8 files changed, 88 insertions(+), 52 deletions(-) diff --git a/src/imitation/algorithms/pebble/entropy_reward.py b/src/imitation/algorithms/pebble/entropy_reward.py index e0d94c171..7570d369f 100644 --- a/src/imitation/algorithms/pebble/entropy_reward.py +++ b/src/imitation/algorithms/pebble/entropy_reward.py @@ -1,12 +1,14 @@ +"""Reward function for the PEBBLE training algorithm.""" + from enum import Enum, auto -from typing import Tuple +from typing import Dict, Optional, Tuple, Union import numpy as np import torch as th from imitation.policies.replay_buffer_wrapper import ( - ReplayBufferView, ReplayBufferRewardWrapper, + ReplayBufferView, ) from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn from imitation.util import util @@ -14,16 +16,16 @@ class PebbleRewardPhase(Enum): - """States representing different behaviors for PebbleStateEntropyReward""" + """States representing different behaviors for PebbleStateEntropyReward.""" UNSUPERVISED_EXPLORATION = auto() # Entropy based reward POLICY_AND_REWARD_LEARNING = auto() # Learned reward class PebbleStateEntropyReward(ReplayBufferAwareRewardFn): - """ - Reward function for implementation of the PEBBLE learning algorithm - (https://arxiv.org/pdf/2106.05091.pdf). + """Reward function for implementation of the PEBBLE learning algorithm. + + See https://arxiv.org/pdf/2106.05091.pdf . The rewards returned by this function go through the three phases: 1. Before enough samples are collected for entropy calculation, the @@ -38,33 +40,38 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn): supplied with set_replay_buffer() or on_replay_buffer_initialized(). To transition to the last phase, unsupervised_exploration_finish() needs to be called. - - Args: - learned_reward_fn: The learned reward function used after unsupervised - exploration is finished - nearest_neighbor_k: Parameter for entropy computation (see - compute_state_entropy()) """ - # TODO #625: parametrize nearest_neighbor_k def __init__( self, learned_reward_fn: RewardFn, nearest_neighbor_k: int = 5, ): + """Builds this class. 
+ + Args: + learned_reward_fn: The learned reward function used after unsupervised + exploration is finished + nearest_neighbor_k: Parameter for entropy computation (see + compute_state_entropy()) + """ self.learned_reward_fn = learned_reward_fn self.nearest_neighbor_k = nearest_neighbor_k self.entropy_stats = RunningNorm(1) self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION # These two need to be set with set_replay_buffer(): - self.replay_buffer_view = None - self.obs_shape = None + self.replay_buffer_view: Optional[ReplayBufferView] = None + self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper): self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape) - def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple): + def set_replay_buffer( + self, + replay_buffer: ReplayBufferView, + obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]], + ): self.replay_buffer_view = replay_buffer self.obs_shape = obs_shape @@ -87,7 +94,7 @@ def __call__( def _entropy_reward(self, state, action, next_state, done): if self.replay_buffer_view is None: raise ValueError( - "Replay buffer must be supplied before entropy reward can be used" + "Replay buffer must be supplied before entropy reward can be used", ) all_observations = self.replay_buffer_view.observations # ReplayBuffer sampling flattens the venv dimension, let's adapt to that diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ea1bfc2e8..cca614320 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -77,8 +77,7 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]: """ # noqa: DAR202 def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None: - """Pre-train an agent if the trajectory generator uses one that - needs pre-training. + """Pre-train an agent before collecting comparisons. By default, this method does nothing and doesn't need to be overridden in subclasses that don't require pre-training. @@ -331,8 +330,8 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None: class PebbleAgentTrainer(AgentTrainer): - """ - Specialization of AgentTrainer for PEBBLE training. + """Specialization of AgentTrainer for PEBBLE training. + Includes unsupervised pretraining with an entropy based reward function. """ @@ -344,9 +343,20 @@ def __init__( reward_fn: PebbleStateEntropyReward, **kwargs, ) -> None: + """Builds PebbleAgentTrainer. + + Args: + reward_fn: Pebble reward function + **kwargs: additional keyword arguments to pass on to + the parent class + + Raises: + ValueError: Unexpected type of reward_fn given. 
+ """ if not isinstance(reward_fn, PebbleStateEntropyReward): raise ValueError( - f"{self.__class__.__name__} expects {PebbleStateEntropyReward.__name__} reward function" + f"{self.__class__.__name__} expects " + f"{PebbleStateEntropyReward.__name__} reward function", ) super().__init__(reward_fn=reward_fn, **kwargs) @@ -1729,10 +1739,10 @@ def train( ################################################### with self.logger.accumulate_means("agent"): self.logger.log( - f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps" + f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps", ) self.trajectory_generator.unsupervised_pretrain( - unsupervised_pretrain_timesteps + unsupervised_pretrain_timesteps, ) for i, num_pairs in enumerate(preference_query_schedule): @@ -1811,7 +1821,7 @@ def _preference_gather_schedule(self, total_comparisons): def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]: unsupervised_pretrain_timesteps = int( - total_timesteps * self.unsupervised_agent_pretrain_frac + total_timesteps * self.unsupervised_agent_pretrain_frac, ) timesteps_per_iteration, extra_timesteps = divmod( total_timesteps - unsupervised_pretrain_timesteps, diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index e96148c39..81a3d579e 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -7,7 +7,7 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples -from imitation.rewards.reward_function import RewardFn, ReplayBufferAwareRewardFn +from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn from imitation.util import util @@ -24,19 +24,20 @@ def _samples_to_reward_fn_input( class ReplayBufferView: - """A read-only view over a valid records in a ReplayBuffer. - - Args: - observations_buffer: Array buffer holding observations - buffer_slice_provider: Function returning slice of buffer - with valid observations - """ + """A read-only view over a valid records in a ReplayBuffer.""" def __init__( self, observations_buffer: np.ndarray, buffer_slice_provider: Callable[[], slice], ): + """Builds ReplayBufferView. + + Args: + observations_buffer: Array buffer holding observations + buffer_slice_provider: Function returning slice of buffer + with valid observations + """ self._observations_buffer_view = observations_buffer.view() self._observations_buffer_view.flags.writeable = False self._buffer_slice_provider = buffer_slice_provider @@ -67,9 +68,6 @@ def __init__( action_space: Action space replay_buffer_class: Class of the replay buffer. reward_fn: Reward function for reward relabeling. - on_initialized_callback: Callback called with reference to this object after - this instance is fully initialized. This provides a hook to access the - buffer after it is created from inside a Stable Baselines algorithm. **kwargs: keyword arguments for ReplayBuffer. 
""" # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of diff --git a/src/imitation/rewards/reward_function.py b/src/imitation/rewards/reward_function.py index e9d7bed30..00b1da958 100644 --- a/src/imitation/rewards/reward_function.py +++ b/src/imitation/rewards/reward_function.py @@ -35,6 +35,11 @@ def __call__( class ReplayBufferAwareRewardFn(RewardFn, abc.ABC): + """Abstract class for a reward function that needs access to a replay buffer.""" + @abc.abstractmethod - def on_replay_buffer_initialized(self, replay_buffer: "ReplayBufferRewardWrapper"): + def on_replay_buffer_initialized( + self, + replay_buffer: "ReplayBufferRewardWrapper", # type: ignore[name-defined] + ): pass diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index e879bbaf8..d71e35211 100644 --- a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -89,7 +89,8 @@ def _maybe_add_relabel_buffer( _buffer_kwargs = dict( reward_fn=relabel_reward_fn, replay_buffer_class=rl_kwargs.get( - "replay_buffer_class", buffers.ReplayBuffer + "replay_buffer_class", + buffers.ReplayBuffer, ), ) rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ca0e786ff..9876ee952 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -60,8 +60,10 @@ def train_defaults(): checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only) query_schedule = "hyperbolic" + # Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf) pebble_enabled = False + unsupervised_agent_pretrain_frac = 0.0 @train_preference_comparisons_ex.named_config diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index c848a6d09..659b47a74 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -10,13 +10,13 @@ import numpy as np import torch as th from sacred.observers import FileStorageObserver -from stable_baselines3.common import type_aliases, base_class, vec_env +from stable_baselines3.common import base_class, type_aliases, vec_env from imitation.algorithms import preference_comparisons from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward from imitation.data import types from imitation.policies import serialize -from imitation.rewards import reward_nets, reward_function +from imitation.rewards import reward_function, reward_nets from imitation.scripts.common import common, reward from imitation.scripts.common import rl as rl_common from imitation.scripts.common import train @@ -65,7 +65,7 @@ def make_reward_function( reward_net: reward_nets.RewardNet, *, pebble_enabled: bool = False, - pebble_nearest_neighbor_k: Optional[int] = None, + pebble_nearest_neighbor_k: int = 5, ): relabel_reward_fn = functools.partial( reward_net.predict_processed, @@ -73,7 +73,8 @@ def make_reward_function( ) if pebble_enabled: relabel_reward_fn = PebbleStateEntropyReward( - relabel_reward_fn, pebble_nearest_neighbor_k + relabel_reward_fn, # type: ignore[assignment] + pebble_nearest_neighbor_k, ) return relabel_reward_fn @@ -92,6 +93,7 @@ def make_agent_trajectory_generator( trajectory_generator_kwargs: Mapping[str, Any], ) -> preference_comparisons.AgentTrainer: if 
pebble_enabled: + assert isinstance(relabel_reward_fn, PebbleStateEntropyReward) return preference_comparisons.PebbleAgentTrainer( algorithm=agent, reward_fn=relabel_reward_fn, @@ -138,7 +140,7 @@ def train_preference_comparisons( allow_variable_horizon: bool, checkpoint_interval: int, query_schedule: Union[str, type_aliases.Schedule], - unsupervised_agent_pretrain_frac: Optional[float], + unsupervised_agent_pretrain_frac: float, ) -> Mapping[str, Any]: """Train a reward model using preference comparisons. diff --git a/tests/algorithms/pebble/test_entropy_reward.py b/tests/algorithms/pebble/test_entropy_reward.py index 918222382..84b59107a 100644 --- a/tests/algorithms/pebble/test_entropy_reward.py +++ b/tests/algorithms/pebble/test_entropy_reward.py @@ -1,17 +1,18 @@ +"""Tests for `imitation.algorithms.entropy_reward`.""" + import pickle -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch import numpy as np import torch as th from gym.spaces import Discrete -from stable_baselines3.common.preprocessing import get_obs_shape from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward from imitation.policies.replay_buffer_wrapper import ReplayBufferView from imitation.util import util SPACE = Discrete(4) -OBS_SHAPE = get_obs_shape(SPACE) +OBS_SHAPE = (1,) PLACEHOLDER = np.empty(OBS_SHAPE) BUFFER_SIZE = 20 @@ -25,7 +26,8 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng): reward_fn = PebbleStateEntropyReward(Mock(), K) reward_fn.set_replay_buffer( - ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE + ReplayBufferView(all_observations, lambda: slice(None)), + OBS_SHAPE, ) # Act @@ -34,17 +36,20 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng): # Assert expected = util.compute_state_entropy( - observations, all_observations.reshape(-1, *OBS_SHAPE), K + observations, + all_observations.reshape(-1, *OBS_SHAPE), + K, ) expected_normalized = reward_fn.entropy_stats.normalize( - th.as_tensor(expected) + th.as_tensor(expected), ).numpy() np.testing.assert_allclose(reward, expected_normalized) def test_pebble_entropy_reward_returns_normalized_values_for_pretraining(): with patch("imitation.util.util.compute_state_entropy") as m: - # mock entropy computation so that we can test only stats collection in this test + # mock entropy computation so that we can test + # only stats collection in this test m.side_effect = lambda obs, all_obs, k: obs reward_fn = PebbleStateEntropyReward(Mock(), K) @@ -64,7 +69,10 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining(): reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER) normalized_reward = reward_fn( - np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER + np.zeros(dim), + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, ) # Assert @@ -91,7 +99,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin # Assert assert reward == expected_reward learned_reward_mock.assert_called_once_with( - observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER + observations, + PLACEHOLDER, + PLACEHOLDER, + PLACEHOLDER, )
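
A minimal usage sketch for reviewers (not part of the patch), based on the PebbleStateEntropyReward class docstring and on tests/algorithms/pebble/test_entropy_reward.py added above. The dummy learned reward, buffer sizes, and observation shape below are assumptions chosen for illustration only; they are not values used anywhere in this change.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

OBS_SHAPE = (1,)                  # assumed observation shape, as in the tests
BUFFER_SIZE, VENVS, K = 20, 2, 5  # assumed sizes for this example


def dummy_learned_reward(state, action, next_state, done):
    # Stand-in for reward_net.predict_processed; any RewardFn-compatible
    # callable returning one reward per state works here.
    return np.zeros(len(state))


reward_fn = PebbleStateEntropyReward(dummy_learned_reward, nearest_neighbor_k=K)

# Unsupervised exploration phase: rewards are normalized state entropies
# computed against the observations currently stored in the replay buffer.
all_observations = np.random.rand(BUFFER_SIZE, VENVS, *OBS_SHAPE)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    OBS_SHAPE,
)
batch = np.random.rand(10, *OBS_SHAPE)
placeholder = np.empty(OBS_SHAPE)
entropy_rewards = reward_fn(batch, placeholder, placeholder, placeholder)

# Reward-learning phase: once unsupervised exploration is finished, calls are
# delegated to the learned reward function supplied in the constructor.
reward_fn.unsupervised_exploration_finish()
learned_rewards = reward_fn(batch, placeholder, placeholder, placeholder)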