
Commit

#641 fix static analysis and tests
Jan Michelfeit committed Dec 10, 2022
1 parent 531b353 commit 74ba96b
Showing 6 changed files with 74 additions and 25 deletions.
32 changes: 26 additions & 6 deletions src/imitation/algorithms/pebble/entropy_reward.py
@@ -1,7 +1,7 @@
"""Reward function for the PEBBLE training algorithm."""

import enum
from typing import Optional, Tuple
from typing import Any, Callable, Optional, Tuple

import gym
import numpy as np
@@ -18,10 +18,16 @@


class InsufficientObservations(RuntimeError):
"""Error signifying not enough observations for entropy calculation."""

pass


class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
"""RewardNet wrapping entropy reward function."""

__call__: Callable[..., Any] # Needed to appease pytype

def __init__(
self,
nearest_neighbor_k: int,
@@ -53,6 +59,9 @@ def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper)
This method needs to be called, e.g., after unpickling.
See also __getstate__() / __setstate__().
Args:
replay_buffer: replay buffer with history of observations
"""
assert self.observation_space == replay_buffer.observation_space
assert self.action_space == replay_buffer.action_space
@@ -72,16 +81,18 @@ def forward(
all_observations = self._replay_buffer_view.observations
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
all_observations = all_observations.reshape(
(-1,) + self.observation_space.shape
(-1,) + self.observation_space.shape,
)

if all_observations.shape[0] < self.nearest_neighbor_k:
raise InsufficientObservations(
"Insufficient observations for entropy calculation"
"Insufficient observations for entropy calculation",
)

return util.compute_state_entropy(
state, all_observations, self.nearest_neighbor_k
state,
all_observations,
self.nearest_neighbor_k,
)

def preprocess(
@@ -95,6 +106,15 @@ def preprocess(
We also know forward() only works with state, so no need to convert
other tensors.
Args:
state: The observation input.
action: The action input.
next_state: The observation input.
done: Whether the episode has terminated.
Returns:
Observations preprocessed by converting them to Tensor.
"""
state_th = util.safe_to_tensor(state).to(self.device)
action_th = next_state_th = done_th = th.empty(0)
@@ -172,8 +192,8 @@ def __call__(
try:
return self.entropy_reward_fn(state, action, next_state, done)
except InsufficientObservations:
# not enough observations to compare to, fall back to the learned function;
# (falling back to a constant may also be ok)
# not enough observations to compare to, fall back to the learned
# function; (falling back to a constant may also be ok)
return self.learned_reward_fn(state, action, next_state, done)
else:
return self.learned_reward_fn(state, action, next_state, done)
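
Note: the implementation of util.compute_state_entropy used in forward() above is not part of this diff. The following self-contained numpy sketch illustrates the kind of k-nearest-neighbor particle estimate such a function typically computes; the Euclidean metric, the log(1 + d) transform, and the name knn_state_entropy are illustrative assumptions, not imitation's actual code. It also mirrors the InsufficientObservations guard shown above.

import numpy as np

def knn_state_entropy(states: np.ndarray, all_obs: np.ndarray, k: int) -> np.ndarray:
    """Entropy-style reward: log distance from each state to its k-th nearest neighbor."""
    if all_obs.shape[0] < k:
        # Same condition under which forward() raises InsufficientObservations.
        raise RuntimeError("Insufficient observations for entropy calculation")
    # Pairwise Euclidean distances between query states and buffered observations.
    dists = np.linalg.norm(states[:, None, :] - all_obs[None, :, :], axis=-1)
    # k-th smallest distance for each query state.
    kth_dist = np.partition(dists, k - 1, axis=1)[:, k - 1]
    return np.log(kth_dist + 1.0)

states = np.random.rand(4, 3)   # batch of query states
buffer = np.random.rand(64, 3)  # flattened replay-buffer observations
print(knn_state_entropy(states, buffer, k=5))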
6 changes: 5 additions & 1 deletion src/imitation/algorithms/preference_comparisons.py
@@ -96,13 +96,17 @@ def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
"""Pre-train an agent before collecting comparisons.
Override this behavior in subclasses that implement pre-training.
If not overriden, this method raises ValueError when non-zero steps are
If not overridden, this method raises ValueError when non-zero steps are
allocated for pre-training.
Args:
steps: number of environment steps to train for.
**kwargs: additional keyword arguments to pass on to
the training procedure.
Raises:
ValueError: Unsupervised pre-training not implemented but non-zero
steps are allocated for pre-training.
"""
if steps > 0:
raise ValueError(
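
The docstring above spells out the contract: the base implementation refuses a non-zero pre-training budget, so only subclasses that actually implement pre-training (such as the PEBBLE agent trainer) should be given one. A minimal standalone sketch of that contract, using hypothetical class names rather than imitation's real ones:

class TrajectoryGeneratorSketch:
    # Hypothetical stand-in for the base class edited above.
    def unsupervised_pretrain(self, steps: int, **kwargs) -> None:
        if steps > 0:
            raise ValueError("Unsupervised pre-training not implemented but steps > 0.")

class PebbleTrainerSketch(TrajectoryGeneratorSketch):
    # A subclass that supports pre-training overrides the method instead of raising.
    def unsupervised_pretrain(self, steps: int, **kwargs) -> None:
        print(f"pre-training for {steps} environment steps")  # e.g. entropy-driven exploration

PebbleTrainerSketch().unsupervised_pretrain(1000)  # no ValueError raised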
22 changes: 13 additions & 9 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -7,6 +7,7 @@
import pathlib
from typing import Any, Mapping, Optional, Type, Union

import gym
import numpy as np
import torch as th
from sacred.observers import FileStorageObserver
@@ -24,6 +25,7 @@
ReplayBufferRewardWrapper,
)
from imitation.rewards import reward_function, reward_nets
from imitation.rewards.reward_function import RewardFn
from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.scripts.common import common, reward
from imitation.scripts.common import rl as rl_common
@@ -80,21 +82,22 @@ def make_reward_function(
reward_net.predict_processed,
update_stats=False,
)
observation_space = reward_net.observation_space
action_space = reward_net.action_space
if pebble_enabled:
relabel_reward_fn = create_pebble_reward_fn(
relabel_reward_fn,
relabel_reward_fn, # type: ignore[assignment]
pebble_nearest_neighbor_k,
action_space,
observation_space,
reward_net.action_space,
reward_net.observation_space,
)
return relabel_reward_fn


def create_pebble_reward_fn(
relabel_reward_fn, pebble_nearest_neighbor_k, action_space, observation_space
):
relabel_reward_fn: RewardFn,
pebble_nearest_neighbor_k: int,
action_space: gym.Space,
observation_space: gym.Space,
) -> PebbleStateEntropyReward:
entropy_reward_net = EntropyRewardNet(
nearest_neighbor_k=pebble_nearest_neighbor_k,
observation_space=observation_space,
@@ -111,13 +114,14 @@ def __call__(self, *args, **kwargs) -> np.ndarray:
return normalized_entropy_reward_net.predict_processed(*args, **kwargs)

def on_replay_buffer_initialized(
self, replay_buffer: ReplayBufferRewardWrapper
self,
replay_buffer: ReplayBufferRewardWrapper,
):
entropy_reward_net.on_replay_buffer_initialized(replay_buffer)

return PebbleStateEntropyReward(
EntropyRewardFn(),
relabel_reward_fn, # type: ignore[assignment]
relabel_reward_fn,
)


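
For reference, a hedged usage sketch of the refactored create_pebble_reward_fn, mirroring the pebble_agent_trainer test fixture later in this diff. It assumes the imitation branch from this PR is installed; the CartPole-like spaces, the learned_reward_fn stub, and the mocked replay buffer are placeholders for a real reward_net.predict_processed and ReplayBufferRewardWrapper.

import gym
import numpy as np
from unittest.mock import Mock

from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn

observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
action_space = gym.spaces.Discrete(2)

def learned_reward_fn(state, action, next_state, done) -> np.ndarray:
    # Stand-in for reward_net.predict_processed.
    return np.zeros(len(state))

reward_fn = create_pebble_reward_fn(
    learned_reward_fn,
    5,  # pebble_nearest_neighbor_k
    action_space,
    observation_space,
)

# Wire in the replay buffer once the RL algorithm has created it; a Mock with
# the same attributes the test fixture sets stands in here.
replay_buffer_mock = Mock()
replay_buffer_mock.observation_space = observation_space
replay_buffer_mock.action_space = action_space
replay_buffer_mock.obs_shape = (4,)
reward_fn.on_replay_buffer_initialized(replay_buffer_mock)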
25 changes: 20 additions & 5 deletions tests/algorithms/pebble/test_entropy_reward.py
@@ -40,7 +40,10 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining():

np.testing.assert_allclose(reward, expected_result)
entropy_fn.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)


@@ -57,7 +60,10 @@ def test_pebble_entropy_reward_returns_learned_rew_on_insufficient_observations(

np.testing.assert_allclose(reward, expected_result)
learned_fn.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)


@@ -74,7 +80,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin

np.testing.assert_allclose(reward, expected_result)
learned_fn.assert_called_once_with(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)


@@ -97,7 +106,10 @@ def test_entropy_reward_net_returns_entropy_for_pretraining(rng):

# Act
reward = reward_net.predict_processed(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)

# Assert
@@ -118,7 +130,10 @@ def test_entropy_reward_net_raises_on_insufficient_observations(rng):
# Act
with pytest.raises(InsufficientObservations):
reward_net.predict_processed(
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
observations,
PLACEHOLDER,
PLACEHOLDER,
PLACEHOLDER,
)


12 changes: 8 additions & 4 deletions tests/algorithms/test_preference_comparisons.py
@@ -18,12 +18,12 @@

import imitation.testing.reward_nets as testing_reward_nets
from imitation.algorithms import preference_comparisons
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import types
from imitation.data.types import TrajectoryWithRew
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets
from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
from imitation.util import networks, util

UNCERTAINTY_ON = ["logit", "probability", "label"]
@@ -84,9 +84,13 @@ def replay_buffer(rng):
def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
replay_buffer_mock = Mock()
replay_buffer_mock.buffer_view = replay_buffer
replay_buffer_mock.obs_shape = (4,)
reward_fn = PebbleStateEntropyReward(
reward_net.predict_processed, venv.observation_space, venv.action_space
replay_buffer_mock.observation_space = venv.observation_space
replay_buffer_mock.action_space = venv.action_space
reward_fn = create_pebble_reward_fn(
reward_net.predict_processed,
5,
venv.action_space,
venv.observation_space,
)
reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
return preference_comparisons.PebbleAgentTrainer(
2 changes: 2 additions & 0 deletions tests/scripts/test_train_preference_comparisons.py
@@ -1,3 +1,5 @@
"""Tests train_preferences_comparisons helper methods."""

from unittest.mock import Mock, patch

import numpy as np
