
Commit

#625 fix pre-commit errors
Jan Michelfeit committed Dec 1, 2022
1 parent 2ab0780 commit f3decf1
Showing 8 changed files with 88 additions and 52 deletions.
41 changes: 24 additions & 17 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -1,29 +1,31 @@
"""Reward function for the PEBBLE training algorithm."""

from enum import Enum, auto
from typing import Tuple
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch as th

from imitation.policies.replay_buffer_wrapper import (
ReplayBufferView,
ReplayBufferRewardWrapper,
ReplayBufferView,
)
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm


class PebbleRewardPhase(Enum):
"""States representing different behaviors for PebbleStateEntropyReward"""
"""States representing different behaviors for PebbleStateEntropyReward."""

UNSUPERVISED_EXPLORATION = auto() # Entropy based reward
POLICY_AND_REWARD_LEARNING = auto() # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
"""
Reward function for implementation of the PEBBLE learning algorithm
(https://arxiv.org/pdf/2106.05091.pdf).
"""Reward function for implementation of the PEBBLE learning algorithm.
See https://arxiv.org/pdf/2106.05091.pdf .
The rewards returned by this function go through the three phases:
1. Before enough samples are collected for entropy calculation, the
@@ -38,33 +40,38 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
supplied with set_replay_buffer() or on_replay_buffer_initialized().
To transition to the last phase, unsupervised_exploration_finish() needs
to be called.
Args:
learned_reward_fn: The learned reward function used after unsupervised
exploration is finished
nearest_neighbor_k: Parameter for entropy computation (see
compute_state_entropy())
"""

# TODO #625: parametrize nearest_neighbor_k
def __init__(
self,
learned_reward_fn: RewardFn,
nearest_neighbor_k: int = 5,
):
"""Builds this class.
Args:
learned_reward_fn: The learned reward function used after unsupervised
exploration is finished
nearest_neighbor_k: Parameter for entropy computation (see
compute_state_entropy())
"""
self.learned_reward_fn = learned_reward_fn
self.nearest_neighbor_k = nearest_neighbor_k
self.entropy_stats = RunningNorm(1)
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

# These two need to be set with set_replay_buffer():
self.replay_buffer_view = None
self.obs_shape = None
self.replay_buffer_view: Optional[ReplayBufferView] = None
self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
def set_replay_buffer(
self,
replay_buffer: ReplayBufferView,
obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
):
self.replay_buffer_view = replay_buffer
self.obs_shape = obs_shape

@@ -87,7 +94,7 @@ def __call__(
def _entropy_reward(self, state, action, next_state, done):
if self.replay_buffer_view is None:
raise ValueError(
"Replay buffer must be supplied before entropy reward can be used"
"Replay buffer must be supplied before entropy reward can be used",
)
all_observations = self.replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
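For orientation, here is a minimal usage sketch of the three phases described in the docstring above, mirroring the unit tests added later in this commit. The buffer contents, shapes, and the stand-in learned reward function are illustrative assumptions, not part of the commit.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

obs_shape = (1,)  # illustrative observation shape
# Pretend replay buffer contents: (buffer_size, n_envs, *obs_shape).
all_observations = np.random.rand(20, 3, *obs_shape)


def learned_reward_fn(state, action, next_state, done):
    # Stand-in for the learned reward used after exploration finishes.
    return np.zeros(len(state))


reward_fn = PebbleStateEntropyReward(learned_reward_fn, nearest_neighbor_k=5)

# Phase 1 -> 2: attach the replay buffer so entropy rewards can be computed.
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)
obs = np.random.rand(10, *obs_shape)
placeholder = np.empty_like(obs)
entropy_reward = reward_fn(obs, placeholder, placeholder, placeholder)

# Phase 3: switch to the learned reward once unsupervised exploration is done.
reward_fn.unsupervised_exploration_finish()
learned_reward = reward_fn(obs, placeholder, placeholder, placeholder)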
26 changes: 18 additions & 8 deletions src/imitation/algorithms/preference_comparisons.py
@@ -77,8 +77,7 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
""" # noqa: DAR202

def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
"""Pre-train an agent if the trajectory generator uses one that
needs pre-training.
"""Pre-train an agent before collecting comparisons.
By default, this method does nothing and doesn't need
to be overridden in subclasses that don't require pre-training.
@@ -331,8 +330,8 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:


class PebbleAgentTrainer(AgentTrainer):
"""
Specialization of AgentTrainer for PEBBLE training.
"""Specialization of AgentTrainer for PEBBLE training.
Includes unsupervised pretraining with an entropy based reward function.
"""

@@ -344,9 +343,20 @@ def __init__(
reward_fn: PebbleStateEntropyReward,
**kwargs,
) -> None:
"""Builds PebbleAgentTrainer.
Args:
reward_fn: Pebble reward function
**kwargs: additional keyword arguments to pass on to
the parent class
Raises:
ValueError: Unexpected type of reward_fn given.
"""
if not isinstance(reward_fn, PebbleStateEntropyReward):
raise ValueError(
f"{self.__class__.__name__} expects {PebbleStateEntropyReward.__name__} reward function"
f"{self.__class__.__name__} expects "
f"{PebbleStateEntropyReward.__name__} reward function",
)
super().__init__(reward_fn=reward_fn, **kwargs)

@@ -1729,10 +1739,10 @@ def train(
###################################################
with self.logger.accumulate_means("agent"):
self.logger.log(
f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps",
)
self.trajectory_generator.unsupervised_pretrain(
unsupervised_pretrain_timesteps
unsupervised_pretrain_timesteps,
)

for i, num_pairs in enumerate(preference_query_schedule):
@@ -1811,7 +1821,7 @@ def _preference_gather_schedule(self, total_comparisons):

def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
unsupervised_pretrain_timesteps = int(
total_timesteps * self.unsupervised_agent_pretrain_frac
total_timesteps * self.unsupervised_agent_pretrain_frac,
)
timesteps_per_iteration, extra_timesteps = divmod(
total_timesteps - unsupervised_pretrain_timesteps,
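As a quick numeric illustration of the timestep split computed in _compute_timesteps above: the pretraining budget is taken off the top, and the rest is divided across the preference-comparison iterations. The concrete numbers, and the assumption that the (cut-off) divisor is the number of iterations, are mine rather than the commit's.

total_timesteps = 10_000
unsupervised_agent_pretrain_frac = 0.1
iterations = 3  # hypothetical number of preference-query iterations

unsupervised_pretrain_timesteps = int(
    total_timesteps * unsupervised_agent_pretrain_frac,
)  # 1000 timesteps of entropy-driven pretraining
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - unsupervised_pretrain_timesteps,
    iterations,
)  # 3000 timesteps per iteration, 0 left over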
20 changes: 9 additions & 11 deletions src/imitation/policies/replay_buffer_wrapper.py
@@ -7,7 +7,7 @@
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples

from imitation.rewards.reward_function import RewardFn, ReplayBufferAwareRewardFn
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
from imitation.util import util


@@ -24,19 +24,20 @@ def _samples_to_reward_fn_input(


class ReplayBufferView:
"""A read-only view over a valid records in a ReplayBuffer.
Args:
observations_buffer: Array buffer holding observations
buffer_slice_provider: Function returning slice of buffer
with valid observations
"""
"""A read-only view over a valid records in a ReplayBuffer."""

def __init__(
self,
observations_buffer: np.ndarray,
buffer_slice_provider: Callable[[], slice],
):
"""Builds ReplayBufferView.
Args:
observations_buffer: Array buffer holding observations
buffer_slice_provider: Function returning slice of buffer
with valid observations
"""
self._observations_buffer_view = observations_buffer.view()
self._observations_buffer_view.flags.writeable = False
self._buffer_slice_provider = buffer_slice_provider
@@ -67,9 +68,6 @@ def __init__(
action_space: Action space
replay_buffer_class: Class of the replay buffer.
reward_fn: Reward function for reward relabeling.
on_initialized_callback: Callback called with reference to this object after
this instance is fully initialized. This provides a hook to access the
buffer after it is created from inside a Stable Baselines algorithm.
**kwargs: keyword arguments for ReplayBuffer.
"""
# Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
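A short sketch of the ReplayBufferView contract documented above: it wraps the observation array in a read-only view, and the slice provider tells it which part of the buffer currently holds valid data. The shapes are illustrative, and the assumption that the observations property applies the provider's slice follows from the docstring rather than from code shown in this hunk.

import numpy as np

from imitation.policies.replay_buffer_wrapper import ReplayBufferView

capacity, n_envs, obs_dim = 8, 2, 3
observations = np.zeros((capacity, n_envs, obs_dim))
pos = 5  # pretend only the first 5 slots have been filled so far

view = ReplayBufferView(observations, lambda: slice(0, pos))

valid = view.observations  # assumed to cover only the valid region
print(valid.shape)

# The underlying array view is marked non-writeable, so in-place edits fail.
try:
    valid[0, 0, 0] = 1.0
except ValueError:
    print("read-only, as intended")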
7 changes: 6 additions & 1 deletion src/imitation/rewards/reward_function.py
@@ -35,6 +35,11 @@ def __call__(


class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
"""Abstract class for a reward function that needs access to a replay buffer."""

@abc.abstractmethod
def on_replay_buffer_initialized(self, replay_buffer: "ReplayBufferRewardWrapper"):
def on_replay_buffer_initialized(
self,
replay_buffer: "ReplayBufferRewardWrapper", # type: ignore[name-defined]
):
pass
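To make the new abstract class concrete, here is a minimal, purely hypothetical subclass: it does nothing until the replay buffer wrapper hands itself over via on_replay_buffer_initialized, then returns a reward that scales with how many observations the buffer holds. It only illustrates the hook and is not part of this commit.

import numpy as np

from imitation.rewards.reward_function import ReplayBufferAwareRewardFn


class BufferSizeReward(ReplayBufferAwareRewardFn):
    """Toy reward that grows with the number of stored observations."""

    def __init__(self):
        self._buffer_view = None

    def on_replay_buffer_initialized(self, replay_buffer):
        # Called by ReplayBufferRewardWrapper once the buffer exists.
        self._buffer_view = replay_buffer.buffer_view

    def __call__(self, state, action, next_state, done):
        if self._buffer_view is None:
            return np.zeros(len(state))
        n_stored = len(self._buffer_view.observations)
        return np.full(len(state), float(n_stored))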
3 changes: 2 additions & 1 deletion src/imitation/scripts/common/rl.py
@@ -89,7 +89,8 @@ def _maybe_add_relabel_buffer(
_buffer_kwargs = dict(
reward_fn=relabel_reward_fn,
replay_buffer_class=rl_kwargs.get(
"replay_buffer_class", buffers.ReplayBuffer
"replay_buffer_class",
buffers.ReplayBuffer,
),
)
rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper
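Read in isolation, the relabel-buffer wiring above has the net effect sketched below: the buffer class the user asked for becomes a constructor argument of the wrapper, while the RL algorithm itself is pointed at ReplayBufferRewardWrapper. How _buffer_kwargs reaches the wrapper lies outside the hunk, so that part is left as a comment, and the stand-in reward function is hypothetical.

from stable_baselines3.common import buffers

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper


def relabel_reward_fn(state, action, next_state, done):
    # Stand-in; the script uses reward_net.predict_processed here.
    return state.sum(axis=-1)


rl_kwargs = {}  # hypothetical user-supplied kwargs for the RL algorithm

_buffer_kwargs = dict(
    reward_fn=relabel_reward_fn,
    replay_buffer_class=rl_kwargs.get(
        "replay_buffer_class",
        buffers.ReplayBuffer,
    ),
)
rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper
# _buffer_kwargs is later handed to the wrapper as its replay-buffer kwargs
# (not shown in this hunk).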
2 changes: 2 additions & 0 deletions src/imitation/scripts/config/train_preference_comparisons.py
@@ -60,8 +60,10 @@ def train_defaults():

checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only)
query_schedule = "hyperbolic"

# Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf)
pebble_enabled = False
unsupervised_agent_pretrain_frac = 0.0


@train_preference_comparisons_ex.named_config
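The two new options default to PEBBLE being switched off. A hypothetical named config, in the same style as the ones defined in this file, could switch it on; the name and the pretraining fraction below are illustrative only.

from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)


@train_preference_comparisons_ex.named_config
def pebble_example():
    """Hypothetical named config enabling PEBBLE-style pretraining."""
    pebble_enabled = True
    unsupervised_agent_pretrain_frac = 0.05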
12 changes: 7 additions & 5 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -10,13 +10,13 @@
import numpy as np
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases, base_class, vec_env
from stable_baselines3.common import base_class, type_aliases, vec_env

from imitation.algorithms import preference_comparisons
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import types
from imitation.policies import serialize
from imitation.rewards import reward_nets, reward_function
from imitation.rewards import reward_function, reward_nets
from imitation.scripts.common import common, reward
from imitation.scripts.common import rl as rl_common
from imitation.scripts.common import train
@@ -65,15 +65,16 @@ def make_reward_function(
reward_net: reward_nets.RewardNet,
*,
pebble_enabled: bool = False,
pebble_nearest_neighbor_k: Optional[int] = None,
pebble_nearest_neighbor_k: int = 5,
):
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
update_stats=False,
)
if pebble_enabled:
relabel_reward_fn = PebbleStateEntropyReward(
relabel_reward_fn, pebble_nearest_neighbor_k
relabel_reward_fn, # type: ignore[assignment]
pebble_nearest_neighbor_k,
)
return relabel_reward_fn

@@ -92,6 +93,7 @@ def make_agent_trajectory_generator(
trajectory_generator_kwargs: Mapping[str, Any],
) -> preference_comparisons.AgentTrainer:
if pebble_enabled:
assert isinstance(relabel_reward_fn, PebbleStateEntropyReward)
return preference_comparisons.PebbleAgentTrainer(
algorithm=agent,
reward_fn=relabel_reward_fn,
@@ -138,7 +140,7 @@ def train_preference_comparisons(
allow_variable_horizon: bool,
checkpoint_interval: int,
query_schedule: Union[str, type_aliases.Schedule],
unsupervised_agent_pretrain_frac: Optional[float],
unsupervised_agent_pretrain_frac: float,
) -> Mapping[str, Any]:
"""Train a reward model using preference comparisons.
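The pebble_enabled branch of make_reward_function boils down to the wiring below. The toy reward net and the spaces are assumptions made for the sketch; the script builds these from the actual environment.

import functools

from gym.spaces import Box, Discrete

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.rewards import reward_nets

obs_space = Box(low=-1.0, high=1.0, shape=(3,))
act_space = Discrete(2)
reward_net = reward_nets.BasicRewardNet(obs_space, act_space)

# Without PEBBLE: relabeling calls the learned net directly.
relabel_reward_fn = functools.partial(
    reward_net.predict_processed,
    update_stats=False,
)

# With PEBBLE: the same callable is wrapped so entropy rewards are used during
# unsupervised pretraining and the learned reward afterwards.
pebble_reward_fn = PebbleStateEntropyReward(relabel_reward_fn, nearest_neighbor_k=5)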
29 changes: 20 additions & 9 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -1,17 +1,18 @@
"""Tests for `imitation.algorithms.entropy_reward`."""

import pickle
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch

import numpy as np
import torch as th
from gym.spaces import Discrete
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.util import util

SPACE = Discrete(4)
OBS_SHAPE = get_obs_shape(SPACE)
OBS_SHAPE = (1,)
PLACEHOLDER = np.empty(OBS_SHAPE)

BUFFER_SIZE = 20
@@ -25,7 +26,8 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):

reward_fn = PebbleStateEntropyReward(Mock(), K)
reward_fn.set_replay_buffer(
ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
ReplayBufferView(all_observations, lambda: slice(None)),
OBS_SHAPE,
)

# Act
@@ -34,17 +36,20 @@

# Assert
expected = util.compute_state_entropy(
observations, all_observations.reshape(-1, *OBS_SHAPE), K
observations,
all_observations.reshape(-1, *OBS_SHAPE),
K,
)
expected_normalized = reward_fn.entropy_stats.normalize(
th.as_tensor(expected)
th.as_tensor(expected),
).numpy()
np.testing.assert_allclose(reward, expected_normalized)


def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
with patch("imitation.util.util.compute_state_entropy") as m:
# mock entropy computation so that we can test only stats collection in this test
# mock entropy computation so that we can test
# only stats collection in this test
m.side_effect = lambda obs, all_obs, k: obs

reward_fn = PebbleStateEntropyReward(Mock(), K)
@@ -64,7 +69,10 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

normalized_reward = reward_fn(
np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
np.zeros(dim),
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)

# Assert
@@ -91,7 +99,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
# Assert
assert reward == expected_reward
learned_reward_mock.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)


