Building blocks for PEBBLE #625

Open · wants to merge 55 commits into master

Commits (55)
8d5900a
Welfords alg and test
dan-pandori Nov 10, 2022
4aac074
Next func
dan-pandori Nov 10, 2022
383fce0
Test update
dan-pandori Nov 10, 2022
055fa67
compute_state_entropy and test
dan-pandori Nov 11, 2022
5c278f4
Sketch of the entropy reward replay buffer
dan-pandori Nov 11, 2022
49dc26f
Batchify state entropy func
dan-pandori Nov 11, 2022
394ad56
Final sketch of replay entropy buffer.
dan-pandori Nov 11, 2022
21da532
First test
dan-pandori Nov 11, 2022
15dad99
Test cleanup
dan-pandori Nov 11, 2022
0c28079
Update
dan-pandori Nov 11, 2022
5ab9d28
Commit for diff
dan-pandori Nov 12, 2022
9410c31
Push final-ish state
dan-pandori Nov 12, 2022
fdcdf0d
#625 refactor RunningMeanAndVar
Nov 29, 2022
0cd1255
#625 use RunningNorm instead of RunningMeanAndVar
Nov 29, 2022
d88ba44
#625 make copy of train_preference_comparisons.py for pebble
Nov 29, 2022
2d836de
#625 use an OffPolicy for pebble
Nov 29, 2022
ec5f67e
#625 fix assumptions about shapes in ReplayBufferEntropyRewardWrapper
Nov 30, 2022
da228bd
#625 entropy reward as a function
Nov 30, 2022
1ec645a
#625 make entropy reward serializable with pickle
Dec 1, 2022
4e16c42
#625 revert change of compute_state_entropy() from tensors to numpy
Dec 1, 2022
acb51be
#625 extract _preference_feedback_schedule()
Dec 1, 2022
8143ba3
#625 introduce parameter for pretraining steps
Dec 1, 2022
184e191
#625 add initialized callback to ReplayBufferRewardWrapper
Dec 1, 2022
52d914a
#625 fix entropy_reward.py
Dec 1, 2022
1f01a7a
#625 remove ReplayBufferEntropyRewardWrapper
Dec 1, 2022
1fbc590
#625 introduce ReplayBufferAwareRewardFn
Dec 1, 2022
e19dd85
#625 rename PebbleStateEntropyReward
Dec 1, 2022
da77f5c
#625 PebbleStateEntropyReward can switch from unsupervised pretraining
Dec 1, 2022
a11e775
#625 add optional pretraining to PreferenceComparisons
Dec 1, 2022
7b12162
#625 PebbleStateEntropyReward supports the initial phase before repla…
Dec 1, 2022
e354e16
#625 entropy_reward can automatically detect if enough observations a…
Dec 1, 2022
b8ccf2f
#625 fix entropy shape
Dec 1, 2022
c5f1dba
#625 rename unsupervised_agent_pretrain_frac parameter
Dec 1, 2022
0ba8959
#625 specialized PebbleAgentTrainer to distinguish from old preferenc…
Dec 1, 2022
c55fee7
#625 merge pebble to train_preference_comparisons.py and configure on…
Dec 1, 2022
1f9642a
#625 plug in pebble according to parameters
Dec 1, 2022
6f05b1d
#625 fix pre-commit errors
Dec 1, 2022
c787877
#625 add test for pebble agent trainer
Dec 1, 2022
b9c5614
#625 fix more pre-commit errors
Dec 1, 2022
40e7387
#625 fix even more pre-commit errors
Dec 2, 2022
aad2e7c
code review - Update src/imitation/policies/replay_buffer_wrapper.py
mifeet Dec 2, 2022
e0aea61
#625 code review
Dec 2, 2022
f0a3359
#625 code review: do not allocate timesteps for pretraining if there …
Dec 2, 2022
8cb2449
Update src/imitation/algorithms/preference_comparisons.py
mifeet Dec 2, 2022
378baa8
#625 code review: remove ignore
Dec 2, 2022
d7ad414
#625 code review - skip pretrainining if zero timesteps
Dec 2, 2022
412550d
#625 code review: separate pebble and environment configuration
Dec 2, 2022
7c3470e
#625 fix even even more pre-commit errors
Dec 2, 2022
73b1e36
#625 fix even even more pre-commit errors
Dec 2, 2022
6daa473
#641 code review: remove set_replay_buffer
Dec 7, 2022
c80fb80
#641 code review: fix comment
Dec 7, 2022
50577b0
#641 code review: replace RunningNorm with NormalizedRewardNet
Dec 10, 2022
531b353
#641 code review: refactor PebbleStateEntropyReward so that inner Rew…
Dec 10, 2022
74ba96b
#641 fix static analysis and tests
Dec 10, 2022
b344cbd
#641 increase coverage
Dec 12, 2022
1 change: 1 addition & 0 deletions src/imitation/algorithms/pebble/__init__.py
@@ -0,0 +1 @@
"""PEBBLE specific algorithms."""
Member:
It feels a bit odd that we have preference_comparisons.py in a single file but PEBBLE (much smaller) split across several files. That's probably a sign we should split up preference_comparisons.py rather than aggregate PEBBLE, though.

Contributor:
We can do that; e.g., the classes for working with fragments and for preference gathering seem like independent pieces of logic. Probably for another PR, though.

199 changes: 199 additions & 0 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -0,0 +1,199 @@
"""Reward function for the PEBBLE training algorithm."""

import enum
from typing import Any, Callable, Optional, Tuple

import gym
import numpy as np
import torch as th

from imitation.policies.replay_buffer_wrapper import (
ReplayBufferAwareRewardFn,
ReplayBufferRewardWrapper,
ReplayBufferView,
)
from imitation.rewards.reward_function import RewardFn
from imitation.rewards.reward_nets import RewardNet
from imitation.util import util


class InsufficientObservations(RuntimeError):
"""Error signifying not enough observations for entropy calculation."""

pass


class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
"""RewardNet wrapping entropy reward function."""

__call__: Callable[..., Any] # Needed to appease pytype

def __init__(
self,
nearest_neighbor_k: int,
observation_space: gym.Space,
action_space: gym.Space,
normalize_images: bool = True,
replay_buffer_view: Optional[ReplayBufferView] = None,
):
"""Initialize the RewardNet.

Args:
nearest_neighbor_k: Parameter for entropy computation (see
compute_state_entropy())
observation_space: the observation space of the environment
action_space: the action space of the environment
normalize_images: whether to automatically normalize
image observations to [0, 1] (from 0 to 255). Defaults to True.
replay_buffer_view: Replay buffer view with observations to compare
against when computing entropy. If None is given, the buffer needs to
be set with on_replay_buffer_initialized() before EntropyRewardNet can
be used
"""
super().__init__(observation_space, action_space, normalize_images)
self.nearest_neighbor_k = nearest_neighbor_k
self._replay_buffer_view = replay_buffer_view

def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
"""Sets replay buffer.

This method needs to be called, e.g., after unpickling.
See also __getstate__() / __setstate__().

Args:
replay_buffer: replay buffer with history of observations
"""
assert self.observation_space == replay_buffer.observation_space
assert self.action_space == replay_buffer.action_space
self._replay_buffer_view = replay_buffer.buffer_view

def forward(
self,
state: th.Tensor,
action: th.Tensor,
next_state: th.Tensor,
done: th.Tensor,
) -> th.Tensor:
assert (
self._replay_buffer_view is not None
), "Missing replay buffer (possibly after unpickle)"

all_observations = self._replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape(
(-1,) + self.observation_space.shape,
)

if all_observations.shape[0] < self.nearest_neighbor_k:
raise InsufficientObservations(
"Insufficient observations for entropy calculation",
)

return util.compute_state_entropy(
state,
all_observations,
self.nearest_neighbor_k,
)

def preprocess(
self,
state: np.ndarray,
action: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
"""Override default preprocessing to avoid the default one-hot encoding.

We also know forward() only works with state, so no need to convert
other tensors.

Args:
state: The observation input.
action: The action input.
next_state: The observation input.
done: Whether the episode has terminated.

Returns:
Observations preprocessed by converting them to Tensor.
"""
state_th = util.safe_to_tensor(state).to(self.device)
action_th = next_state_th = done_th = th.empty(0)
return state_th, action_th, next_state_th, done_th

def __getstate__(self):
state = self.__dict__.copy()
del state["_replay_buffer_view"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._replay_buffer_view = None


class PebbleRewardPhase(enum.Enum):
    """States representing different behaviors for PebbleStateEntropyReward."""

    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy-based reward
    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
    """Reward function for implementation of the PEBBLE learning algorithm.

    See https://arxiv.org/abs/2106.05091 .

    The rewards returned by this function go through three phases:
    1. Before enough samples are collected for entropy calculation, the
       underlying learned reward function is returned. This shouldn't matter
       because OffPolicyAlgorithms have an initialization period of
       `learning_starts` timesteps.
    2. During the unsupervised exploration phase, an entropy-based reward is
       returned.
    3. After the unsupervised exploration phase is finished, the underlying
       learned reward is returned.

    The second phase requires that a buffer with observations to compare against
    is supplied with on_replay_buffer_initialized(). To transition to the last
    phase, unsupervised_exploration_finish() needs to be called.
    """

    def __init__(
        self,
        entropy_reward_fn: RewardFn,
        learned_reward_fn: RewardFn,
    ):
        """Builds this class.

        Args:
            entropy_reward_fn: The entropy-based reward function used during
                unsupervised exploration.
            learned_reward_fn: The learned reward function used after unsupervised
                exploration is finished.
        """
        self.entropy_reward_fn = entropy_reward_fn
        self.learned_reward_fn = learned_reward_fn
        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
        if isinstance(self.entropy_reward_fn, ReplayBufferAwareRewardFn):
            self.entropy_reward_fn.on_replay_buffer_initialized(replay_buffer)

    def unsupervised_exploration_finish(self):
        assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
        self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING

    def __call__(
        self,
        state: np.ndarray,
        action: np.ndarray,
        next_state: np.ndarray,
        done: np.ndarray,
    ) -> np.ndarray:
        if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
            try:
                return self.entropy_reward_fn(state, action, next_state, done)
            except InsufficientObservations:
                # Not enough observations to compare to; fall back to the learned
                # function (falling back to a constant may also be OK).
                return self.learned_reward_fn(state, action, next_state, done)
        else:
            return self.learned_reward_fn(state, action, next_state, done)
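Note (not part of the diff): a minimal usage sketch of how the phases described in the docstring above play out at call time. It assumes this PR's module path (imitation.algorithms.pebble.entropy_reward) is importable and stubs both reward functions with hypothetical stand-ins; in the real setup the entropy function would be an EntropyRewardNet wired to the replay buffer and the learned function a preference-trained reward net.

import numpy as np

from imitation.algorithms.pebble.entropy_reward import (
    InsufficientObservations,
    PebbleStateEntropyReward,
)


def entropy_reward_fn(state, action, next_state, done):
    # Hypothetical stand-in for the entropy reward; the real EntropyRewardNet
    # raises InsufficientObservations while the replay buffer is too small.
    raise InsufficientObservations("not enough observations yet")


def learned_reward_fn(state, action, next_state, done):
    # Hypothetical stand-in for the learned, preference-based reward.
    return np.zeros(state.shape[0])


reward_fn = PebbleStateEntropyReward(entropy_reward_fn, learned_reward_fn)
batch = (np.zeros((4, 3)), np.zeros((4, 1)), np.zeros((4, 3)), np.zeros((4,)))

# Phase 1: still in unsupervised exploration, but the entropy function reports
# too few observations, so the call falls back to the learned reward.
print(reward_fn(*batch))  # -> [0. 0. 0. 0.]

# Phase 3: after pretraining ends, the learned reward is used unconditionally.
reward_fn.unsupervised_exploration_finish()
print(reward_fn(*batch))  # -> [0. 0. 0. 0.]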