diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py
index 2531757be..9b87aaea3 100755
--- a/ci/clean_notebooks.py
+++ b/ci/clean_notebooks.py
@@ -18,6 +18,7 @@ class UncleanNotebookError(Exception):
"metadata": {"do": "constant", "value": dict()},
"source": {"do": "keep"},
"id": {"do": "keep"},
+ "attachments": {"do": "constant", "value": {}},
}
code_structure: Dict[str, Dict[str, Any]] = {
@@ -76,7 +77,8 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:
if key not in structure[cell["cell_type"]]:
if check_only:
raise UncleanNotebookError(
- f"Notebook {file} has unknown cell key {key}",
+ f"Notebook {file} has unknown cell key {key} for cell type "
+ + f"{cell['cell_type']}",
)
del cell[key]
was_dirty = True
@@ -108,7 +110,12 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:
def parse_args():
- """Parse command-line arguments."""
+ """Parse command-line arguments.
+
+ Returns:
+ parser: The parser object.
+ args: The parsed arguments.
+ """
# if the argument --check has been passed, check if the notebooks are clean
# otherwise, clean them in-place
parser = argparse.ArgumentParser()
@@ -125,7 +132,14 @@ def parse_args():
def get_files(input_paths: List):
- """Build list of files to scan from list of paths and files."""
+ """Build list of files to scan from list of paths and files.
+
+ Args:
+ input_paths: List of paths and files to scan.
+
+ Returns:
+ files: List of files to scan.
+ """
files = []
for file in input_paths:
if file.is_dir():
diff --git a/docs/index.rst b/docs/index.rst
index 3b3a9e1be..a0c061a48 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -86,6 +86,7 @@ If you use ``imitation`` in your research project, please cite our paper to help
tutorials/4_train_airl
tutorials/5_train_preference_comparisons
tutorials/5a_train_preference_comparisons_with_cnn
+ tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback
tutorials/6_train_mce
tutorials/7_train_density
tutorials/8_train_sqil
diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb
new file mode 100644
index 000000000..5f768784d
--- /dev/null
+++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb
@@ -0,0 +1,251 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/5_train_preference_comparisons.ipynb)\n",
+ "# Learning a Reward Function using Preference Comparisons with Synchronous Human Feedback\n",
+ "\n",
+ "You can request human feedback via synchronous CLI or Notebook interactions as well. The setup is only slightly different than it would be with a synthetic preference gatherer."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here's the starting setup. The major differences from the synthetic setup are indicated with comments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "import random\n",
+ "import tempfile\n",
+ "from imitation.algorithms import preference_comparisons\n",
+ "from imitation.rewards.reward_nets import BasicRewardNet\n",
+ "from imitation.util import video_wrapper\n",
+ "from imitation.util.networks import RunningNorm\n",
+ "from imitation.util.util import make_vec_env\n",
+ "from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n",
+ "import gym\n",
+ "from stable_baselines3 import PPO\n",
+ "import numpy as np\n",
+ "\n",
+ "# Add a temporary directory for video recordings of trajectories. Unfortunately Jupyter\n",
+ "# won't play videos outside the current directory, so we have to put them here. We'll\n",
+ "# delete them at the end of the script.\n",
+ "video_dir = tempfile.mkdtemp(dir=\".\", prefix=\"videos_\")\n",
+ "\n",
+ "rng = np.random.default_rng(0)\n",
+ "\n",
+ "# Add a video wrapper to the environment. This will record videos of the agent's\n",
+ "# trajectories so we can review them later.\n",
+ "venv = make_vec_env(\n",
+ " \"Pendulum-v1\",\n",
+ " rng=rng,\n",
+ " post_wrappers={\n",
+ " \"VideoWrapper\": video_wrapper.video_wrapper_factory(\n",
+ " pathlib.Path(video_dir), single_video=False\n",
+ " )\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "reward_net = BasicRewardNet(\n",
+ " venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n",
+ ")\n",
+ "\n",
+ "fragmenter = preference_comparisons.RandomFragmenter(\n",
+ " warning_threshold=0,\n",
+ " rng=rng,\n",
+ ")\n",
+ "\n",
+ "querent = preference_comparisons.PreferenceQuerent()\n",
+ "\n",
+ "# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n",
+ "# which one is better. It will then use the user's feedback to train the reward network.\n",
+ "gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n",
+ "\n",
+ "preference_model = preference_comparisons.PreferenceModel(reward_net)\n",
+ "reward_trainer = preference_comparisons.BasicRewardTrainer(\n",
+ " preference_model=preference_model,\n",
+ " loss=preference_comparisons.CrossEntropyRewardLoss(),\n",
+ " epochs=3,\n",
+ " rng=rng,\n",
+ ")\n",
+ "\n",
+ "agent = PPO(\n",
+ " policy=FeedForward32Policy,\n",
+ " policy_kwargs=dict(\n",
+ " features_extractor_class=NormalizeFeaturesExtractor,\n",
+ " features_extractor_kwargs=dict(normalize_class=RunningNorm),\n",
+ " ),\n",
+ " env=venv,\n",
+ " seed=0,\n",
+ " n_steps=2048 // venv.num_envs,\n",
+ " batch_size=64,\n",
+ " ent_coef=0.0,\n",
+ " learning_rate=0.0003,\n",
+ " n_epochs=10,\n",
+ ")\n",
+ "\n",
+ "trajectory_generator = preference_comparisons.AgentTrainer(\n",
+ " algorithm=agent,\n",
+ " reward_fn=reward_net,\n",
+ " venv=venv,\n",
+ " exploration_frac=0.0,\n",
+ " rng=rng,\n",
+ ")\n",
+ "\n",
+ "pref_comparisons = preference_comparisons.PreferenceComparisons(\n",
+ " trajectory_generator,\n",
+ " reward_net,\n",
+ " num_iterations=5,\n",
+ " fragmenter=fragmenter,\n",
+ " preference_querent=querent,\n",
+ " preference_gatherer=gatherer,\n",
+ " reward_trainer=reward_trainer,\n",
+ " fragment_length=100,\n",
+ " transition_oversampling=1,\n",
+ " initial_comparison_frac=0.1,\n",
+ " allow_variable_horizon=False,\n",
+ " initial_epoch_multiplier=1,\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We're going to train with only 20 comparisons to make it faster for you to evaluate. The videos will appear in-line in this notebook for you to watch, and a text input will appear for you to choose one."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pref_comparisons.train(\n",
+ " total_timesteps=5_000, # For good performance this should be 1_000_000\n",
+ " total_comparisons=20, # For good performance this should be 5_000\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "From this point onward, this notebook is the same as [the synthetic gatherer notebook](5_train_preference_comparisons.ipynb).\n",
+ "\n",
+ "After we trained the reward network using the preference comparisons algorithm, we can wrap our environment with that learned reward."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n",
+ "\n",
+ "\n",
+ "learned_reward_venv = RewardVecEnvWrapper(venv, reward_net.predict)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can train an agent, that only sees those learned reward."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from stable_baselines3 import PPO\n",
+ "from stable_baselines3.ppo import MlpPolicy\n",
+ "\n",
+ "learner = PPO(\n",
+ " policy=MlpPolicy,\n",
+ " env=learned_reward_venv,\n",
+ " seed=0,\n",
+ " batch_size=64,\n",
+ " ent_coef=0.0,\n",
+ " learning_rate=0.0003,\n",
+ " n_epochs=10,\n",
+ " n_steps=64,\n",
+ ")\n",
+ "learner.learn(1000) # Note: set to 100000 to train a proficient expert"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Then we can evaluate it using the original reward."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from stable_baselines3.common.evaluation import evaluate_policy\n",
+ "\n",
+ "reward, _ = evaluate_policy(learner.policy, venv, 10)\n",
+ "print(reward)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# clean up the videos we made\n",
+ "import shutil\n",
+ "\n",
+ "shutil.rmtree(video_dir)"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/setup.py b/setup.py
index 1aa407456..c62c5f7d5 100644
--- a/setup.py
+++ b/setup.py
@@ -55,6 +55,8 @@
"wandb==0.12.21",
"setuptools_scm~=7.0.5",
"pre-commit>=2.20.0",
+ "types-requests~=2.31.0.1",
+ "requests-mock~=1.11.0",
]
+ PARALLEL_REQUIRE
+ ATARI_REQUIRE
@@ -209,6 +211,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"huggingface_sb3~=3.0",
"optuna>=3.0.1",
"datasets>=2.8.0",
+ "opencv-python", # TODO: specify version
],
tests_require=TESTS_REQUIRE,
extras_require={
diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py
index 14a8fad5b..d4500f8ea 100644
--- a/src/imitation/algorithms/preference_comparisons.py
+++ b/src/imitation/algorithms/preference_comparisons.py
@@ -3,10 +3,16 @@
Trains a reward model and optionally a policy based on preferences
between trajectory fragments.
"""
+
import abc
+import base64
import math
+import os
+import pathlib
import pickle
import re
+import uuid
+from urllib.parse import urljoin
from collections import defaultdict
from typing import (
Any,
@@ -24,8 +30,11 @@
overload,
)
+import cv2
import numpy as np
+import requests
import torch as th
+from moviepy.editor import ImageSequenceClip
from scipy import special
from stable_baselines3.common import base_class, type_aliases, utils, vec_env
from torch import nn
@@ -778,45 +787,276 @@ def variance_estimate(self, rews1: th.Tensor, rews2: th.Tensor) -> float:
return var_estimate
-class PreferenceGatherer(abc.ABC):
- """Base class for gathering preference comparisons between trajectory fragments."""
+class PreferenceQuerent:
+ """Dummy class for querying preferences between trajectory fragments."""
def __init__(
self,
rng: Optional[np.random.Generator] = None,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
) -> None:
+ """Initializes the preference querent.
+
+ Args:
+ rng: random number generator
+ custom_logger: Where to log to; if None (default), creates a new logger.
+ """
+ self.logger = custom_logger or imit_logger.configure()
+ self.rng = rng
+
+ def __call__(
+ self,
+ queries: Sequence[TrajectoryWithRewPair],
+ ) -> Dict[str, TrajectoryWithRewPair]:
+ """Queries the user for their preferences.
+
+ This dummy implementation does nothing because by default the queries are
+ answered by an oracle.
+
+ Args:
+ queries: sequence of pairs of trajectory fragments
+
+ Returns:
+ dictionary with queries and their respective UUIDs
+ """
+ return {str(uuid.uuid4()): query for query in queries}
+
+
+class VideoBasedQuerent(PreferenceQuerent):
+ """Writes videos for each query to the local file system for later use by child querent (and gatherer) classes."""
+
+ def __init__(
+ self,
+ video_output_dir: Union[str, os.PathLike],
+ video_type="webm",
+ video_fps: int = 20,
+ rng: Optional[np.random.Generator] = None,
+ custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
+ ):
+ """Initializes the querent.
+
+ Args:
+ video_output_dir: path to the video clip directory.
+ video_type: specifies the video format, e.g. 'webm' or 'mp4'
+ video_fps: frames per second of the generated videos.
+ rng: random number generator, if applicable.
+ custom_logger: Where to log to; if None (default), creates a new logger.
+ """
+ super().__init__(custom_logger=custom_logger)
+ self.rng = rng
+ self.video_output_dir = video_output_dir
+ self.video_type = video_type
+ self.frames_per_second = video_fps
+
+ # Create video directory
+ os.makedirs(self.video_output_dir, exist_ok=True)
+
+ def __call__(
+ self,
+ queries: Sequence[TrajectoryWithRewPair],
+ ) -> Dict[str, TrajectoryWithRewPair]:
+ identified_queries = super().__call__(queries)
+ for query_id, query in identified_queries.items():
+ self._write_query_videos(query_id, query)
+ return identified_queries
+
+ def _write_query_videos(self, query_id, query):
+ for i, alternative in enumerate(("left", "right")):
+ self._write_fragment_video(
+ fragment=query[i],
+ output_path=self._create_query_video_path(query_id, alternative),
+ )
+
+ def _create_query_video_path(self, query_id: str, alternative: str):
+ return (
+ pathlib.Path(self.video_output_dir)
+ / f"{query_id}-{alternative}.{self.video_type}"
+ )
+
+ def _write_fragment_video(
+ self,
+ fragment: TrajectoryWithRew,
+ output_path: AnyPath,
+ progress_logger: bool = True,
+ ) -> None:
+ """Write fragment video clip."""
+ frames = self._get_frames(fragment)
+ self._write(frames, output_path, progress_logger)
+
+ def _get_frames(
+ self, fragment: TrajectoryWithRew,
+    ) -> List[Union[os.PathLike, np.ndarray]]:
+ if self._rendered_image_of_observation_is_available(fragment):
+ return self._get_frames_for_each_observation(fragment)
+ else:
+ raise ValueError(
+ "No rendered images contained in info dict. "
+ "Please apply `RenderImageWrapper` to your environment.",
+ )
+
+ @staticmethod
+ def _rendered_image_of_observation_is_available(
+ fragment: TrajectoryWithRew,
+ ) -> bool:
+ return fragment.infos is not None and "rendered_img" in fragment.infos[0]
+
+ @staticmethod
+    def _get_frames_for_each_observation(
+        fragment: TrajectoryWithRew,
+    ) -> List[Union[os.PathLike, np.ndarray]]:
+        frames: List[Union[os.PathLike, np.ndarray]] = []
+ for i in range(len(fragment.infos)):
+ frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"]
+ frames.append(frame)
+ return frames
+
+ def _write(
+ self,
+ frames: List[Union[os.PathLike, np.ndarray]],
+ output_path: Union[str, bytes, os.PathLike],
+ progress_logger: bool,
+ ):
+ clip = ImageSequenceClip(
+ frames, fps=self.frames_per_second,
+ ) # accepts list of image paths and numpy arrays
+ output_path = pathlib.Path(output_path)
+ if output_path.suffix == ".gif":
+ clip.write_gif(
+ str(output_path),
+ program="ffmpeg",
+ logger="bar" if progress_logger else None,
+ )
+ else:
+ clip.write_videofile(
+ str(output_path), logger="bar" if progress_logger else None,
+ )
+
+
+class RESTQuerent(VideoBasedQuerent):
+ """Sends queries to a REST web service.
+
+    Each query is sent as a PUT request with a JSON payload of the following
+    form to `collection_service_address`/`query_id`:
+
+    {
+        "uuid": "1234",
+        "left": <base64-encoded video>,
+        "right": <base64-encoded video>
+    }
+
+ """
+
+ def __init__(
+ self,
+ collection_service_address: str,
+ video_output_dir: str,
+ video_fps: int = 20,
+ rng: Optional[np.random.Generator] = None,
+ custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
+ ):
+ """Initializes the querent.
+
+ Args:
+            collection_service_address: Network address of the collection
+                service's REST interface.
+ video_output_dir: path to the video clip directory.
+ video_fps: frames per second of the generated videos.
+ rng: random number generator, if applicable.
+ custom_logger: Where to log to; if None (default), creates a new logger.
+ """
+ super().__init__(
+ video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger,
+ )
+ self.query_endpoint = collection_service_address
+
+ def __call__(
+ self,
+ queries: Sequence[TrajectoryWithRewPair],
+ ) -> Dict[str, TrajectoryWithRewPair]:
+ identified_queries = super().__call__(queries)
+ for query_id, query in identified_queries.items():
+ self._query(query_id)
+ return identified_queries
+
+ def _query(self, query_id: str):
+ video_data = self._load_video_data(query_id)
+ requests.put(
+ urljoin(self.query_endpoint, query_id),
+ json={"uuid": "{}".format(query_id), **video_data},
+ )
+
+    def _load_video_data(self, query_id: str) -> Dict[str, str]:
+ video_data = {}
+ for alternative in ("left", "right"):
+ video_path = self._create_query_video_path(query_id, alternative)
+ if video_path.exists():
+ with open(video_path, "rb") as video_file:
+ video_data[alternative] = base64.b64encode(
+ video_file.read(),
+ ).decode("utf-8")
+ else:
+ raise RuntimeError(
+ f"Video to be loaded does not exist at {video_path}.",
+ )
+ return video_data
+
+
+class PreferenceGatherer(abc.ABC):
+ """Base class for gathering preference comparisons between trajectory fragments."""
+
+    def __init__(
+        self,
+        rng: Optional[np.random.Generator] = None,
+        custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
+    ) -> None:
"""Initializes the preference gatherer.
Args:
rng: random number generator, if applicable.
custom_logger: Where to log to; if None (default), creates a new logger.
"""
# The random seed isn't used here, but it's useful to have this
# as an argument nevertheless because that means we can always
# pass in a seed in training scripts (without worrying about whether
# the PreferenceGatherer we use needs one).
- del rng
+ self.querent = PreferenceQuerent(rng, custom_logger)
self.logger = custom_logger or imit_logger.configure()
+ self.pending_queries: Dict = {}
- @abc.abstractmethod
- def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray:
- """Gathers the probabilities that fragment 1 is preferred in `fragment_pairs`.
-
- Args:
- fragment_pairs: sequence of pairs of trajectory fragments
+ def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]:
+ """Gathers preference probabilities for completed queries.
Returns:
- A numpy array with shape (b, ), where b is the length of the input
- (i.e. batch size). Each item in the array is the probability that
- fragment 1 is preferred over fragment 2 for the corresponding
- pair of fragments.
+            * A list of length b with queries for which preferences have been
+              gathered.
+
+ * A numpy array with shape (b, ).
+ Each item in the array is the probability that fragment 1 is preferred
+ over fragment 2 for the corresponding query in the list above.
Note that for human feedback, these probabilities are simply 0 or 1
(or 0.5 in case of indifference), but synthetic models may yield other
probabilities.
""" # noqa: DAR202
+ gathered_queries = []
+ gathered_preferences = []
+
+ for query_id, query in list(self.pending_queries.items()):
+ preference = self._gather_preference(query_id)
+
+ if preference is not None:
+ # Preference for this query has been provided
+ if 0 <= preference <= 1:
+ gathered_queries.append(query)
+ gathered_preferences.append(preference)
+ # else: fragments were incomparable
+ del self.pending_queries[query_id]
+
+ return gathered_queries, np.array(gathered_preferences, dtype=np.float32)
+
+ @abc.abstractmethod
+    def _gather_preference(self, query_id: str) -> Optional[float]:
+        """Returns the preference for `query_id`, or None if not yet answered."""
+        raise NotImplementedError
+
+ def query(self, queries: Sequence[TrajectoryWithRewPair]) -> None:
+ identified_queries = self.querent(queries)
+ self.pending_queries = {**self.pending_queries, **identified_queries}
+
class SyntheticGatherer(PreferenceGatherer):
"""Computes synthetic preferences using ground-truth environment rewards."""
@@ -865,45 +1105,283 @@ def __init__(
if self.sample and self.rng is None:
raise ValueError("If `sample` is True, then `rng` must be provided.")
- def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray:
- """Computes probability fragment 1 is preferred over fragment 2."""
- returns1, returns2 = self._reward_sums(fragment_pairs)
+ def _gather_preference(self, query_id):
+ query = self.pending_queries[query_id]
+
+        return_1, return_2 = (
+            np.array(
+                rollout.discounted_sum(query[0].rews, self.discount_factor),
+                dtype=np.float32,
+            ),
+            np.array(
+                rollout.discounted_sum(query[1].rews, self.discount_factor),
+                dtype=np.float32,
+            ),
+        )
+
if self.temperature == 0:
- return (np.sign(returns1 - returns2) + 1) / 2
+ return (np.sign(return_1 - return_2) + 1) / 2
- returns1 /= self.temperature
- returns2 /= self.temperature
+ return_1 /= self.temperature
+ return_2 /= self.temperature
# clip the returns to avoid overflows in the softmax below
- returns_diff = np.clip(returns2 - returns1, -self.threshold, self.threshold)
+ returns_diff = np.clip(return_2 - return_1, -self.threshold, self.threshold)
# Instead of computing exp(rews1) / (exp(rews1) + exp(rews2)) directly,
# we divide enumerator and denominator by exp(rews1) to prevent overflows:
- model_probs = 1 / (1 + np.exp(returns_diff))
+ choice_probability = 1 / (1 + np.exp(returns_diff))
# Compute the mean binary entropy. This metric helps estimate
# how good we can expect the performance of the learned reward
# model to be at predicting preferences.
entropy = -(
- special.xlogy(model_probs, model_probs)
- + special.xlogy(1 - model_probs, 1 - model_probs)
+ special.xlogy(choice_probability, choice_probability)
+ + special.xlogy(1 - choice_probability, 1 - choice_probability)
).mean()
self.logger.record("entropy", entropy)
if self.sample:
assert self.rng is not None
- return self.rng.binomial(n=1, p=model_probs).astype(np.float32)
- return model_probs
-
- def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]:
- rews1, rews2 = zip(
- *[
- (
- rollout.discounted_sum(f1.rews, self.discount_factor),
- rollout.discounted_sum(f2.rews, self.discount_factor),
- )
- for f1, f2 in fragment_pairs
- ],
+ return self.rng.binomial(n=1, p=choice_probability)
+
+ return choice_probability
+
+
+class CommandLineGatherer(PreferenceGatherer):
+ """Queries for human preferences using the command line or a notebook."""
+
+ def __init__(
+ self,
+ video_dir: Union[str, os.PathLike],
+ video_width: int = 500,
+ video_height: int = 500,
+ frames_per_second: int = 25,
+ custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
+ rng: Optional[np.random.Generator] = None,
+ ) -> None:
+ """Initialize the human preference gatherer.
+
+ Args:
+ video_dir: directory where videos of the trajectories are saved.
+ video_width: width of the video in pixels.
+ video_height: height of the video in pixels.
+ frames_per_second: frames per second of the video.
+ custom_logger: Where to log to; if None (default), creates a new logger.
+ rng: random number generator
+ """
+ super().__init__(rng=rng, custom_logger=custom_logger)
+ self.querent = VideoBasedQuerent(
+ video_output_dir=video_dir, video_fps=frames_per_second,
+ )
+ self.video_dir = video_dir
+ self.video_width = video_width
+ self.video_height = video_height
+
+    def _gather_preference(self, query_id: str) -> Optional[float]:
+        """Displays the videos of the two fragments and asks for a preference.
+
+        On the command line, a video player window pops up for each fragment.
+        In a notebook, the videos are displayed inline.
+        Either way, the user is asked to enter 1 or 2 to indicate which
+        fragment is preferred.
+
+        Args:
+            query_id: the id of the fragment pair to be displayed.
+
+        Returns:
+            1.0 if the first fragment is preferred, 0.0 if the second is
+            preferred, or None if the user quits.
+
+        Raises:
+            KeyboardInterrupt: if the user presses q while the videos are
+                shown in separate windows.
+        """
+ frag1_video_path = pathlib.Path(self.video_dir, f"{query_id}-left.webm")
+ frag2_video_path = pathlib.Path(self.video_dir, f"{query_id}-right.webm")
+ if self._in_ipython():
+ self._display_videos_in_notebook(frag1_video_path, frag2_video_path)
+
+ pref = input(
+ "Which video is preferred? (1 or 2, or q to quit, or r to replay): ",
+ )
+ while pref not in ["1", "2", "q"]:
+ if pref == "r":
+ self._display_videos_in_notebook(frag1_video_path, frag2_video_path)
+ pref = input("Please enter 1 or 2 or q or r: ")
+
+ if pref == "q":
+ return None
+ elif pref == "1":
+ return 1.
+ elif pref == "2":
+ return 0.
+
+ # should never be hit
+ assert False
+ else:
+ return self._display_in_windows(frag1_video_path, frag2_video_path)
+
+ def _display_in_windows(
+ self,
+ frag1_video_path: pathlib.Path,
+ frag2_video_path: pathlib.Path,
+    ) -> float:
+ """Displays the videos in separate windows.
+
+ The videos are displayed side by side and the user is asked to indicate
+ which one is preferred. The interaction is done in the window rather than
+ in the command line because the command line is not interactive when
+ the video is playing, and it's nice to allow the user to choose a video before
+ the videos are done playing. The downside is that the instructions appear on the
+ command line and the interaction happens in the video window.
+
+ Args:
+ frag1_video_path: path to the video file of the first fragment
+ frag2_video_path: path to the video file of the second fragment
+
+ Returns:
+            1.0 if the first fragment is preferred, 0.0 if the second is preferred.
+
+ Raises:
+ KeyboardInterrupt: if the user presses q to quit.
+ RuntimeError: if the video files cannot be opened.
+ """
+ print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n")
+
+ cap1 = cv2.VideoCapture(str(frag1_video_path))
+ cap2 = cv2.VideoCapture(str(frag2_video_path))
+ cv2.namedWindow("Video 1", cv2.WINDOW_NORMAL)
+ cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL)
+
+ # set window sizes
+ cv2.resizeWindow("Video 1", self.video_width, self.video_height)
+ cv2.resizeWindow("Video 2", self.video_width, self.video_height)
+
+ # move windows side by side
+ cv2.moveWindow("Video 1", 0, 0)
+ cv2.moveWindow("Video 2", self.video_width, 0)
+
+ if not cap1.isOpened():
+ raise RuntimeError(f"Error opening video file {frag1_video_path}.")
+
+ if not cap2.isOpened():
+ raise RuntimeError(f"Error opening video file {frag2_video_path}.")
+
+ ret1, frame1 = cap1.read()
+ ret2, frame2 = cap2.read()
+ while cap1.isOpened() and cap2.isOpened():
+ if ret1 or ret2:
+ cv2.imshow("Video 1", frame1)
+ cv2.imshow("Video 2", frame2)
+ ret1, frame1 = cap1.read()
+ ret2, frame2 = cap2.read()
+
+ key = chr(cv2.waitKey(1) & 0xFF)
+ if key == "q":
+ cv2.destroyAllWindows()
+ raise KeyboardInterrupt
+ elif key == "r":
+ cap1.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ cap2.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ ret1, frame1 = cap1.read()
+ ret2, frame2 = cap2.read()
+ elif key == "1" or key == "2":
+ cv2.destroyAllWindows()
+ return 1.0 if key == "1" else 0.0
+
+ cv2.destroyAllWindows()
+ raise KeyboardInterrupt
+
+ def _display_videos_in_notebook(
+ self,
+ frag1_video_path: pathlib.Path,
+ frag2_video_path: pathlib.Path,
+ ) -> None:
+ """Displays the videos in a notebook.
+
+ Interaction can happen in the notebook while the videos are playing.
+
+ Args:
+ frag1_video_path: path to the video file of the first fragment
+ frag2_video_path: path to the video file of the second fragment
+
+ Raises:
+ RuntimeError: if the video files cannot be opened.
+ """
+ from IPython.display import HTML, Video, clear_output, display
+
+ clear_output(wait=True)
+
+ for i, path in enumerate([frag1_video_path, frag2_video_path]):
+ if not path.exists():
+ raise RuntimeError(f"Video file {path} does not exist.")
+ display(HTML(f"
Video {i+1}
"))
+ display(
+ Video(
+ filename=str(path),
+ height=self.video_height,
+ width=self.video_width,
+ html_attributes="controls autoplay muted",
+ embed=False,
+ ),
+ )
+
+ def _in_ipython(self) -> bool:
+ try:
+ return self._is_running_pytest_test() or get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa
+ except NameError:
+ return False
+
+ @staticmethod
+ def _is_running_pytest_test() -> bool:
+ return "PYTEST_CURRENT_TEST" in os.environ
+
+
+class RESTGatherer(PreferenceGatherer):
+ """Gathers preferences from a REST web service.
+
+    Preferences are fetched via GET request to
+    `collection_service_address`/`query_id`; the service returns the preference
+    as a JSON payload:
+
+ {
+ "label": float
+ }
+
+ The float value ranges from 0.0 (preferring left) to 1.0 (preferring right),
+ with 0.5 indicating indifference. -1.0 indicates that the query was incomparable.
+ """
+
+ def __init__(
+ self,
+ collection_service_address: str,
+ wait_for_user: bool = True,
+ rng: Optional[np.random.Generator] = None,
+ custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
+ querent_kwargs: Optional[Mapping] = None,
+ ) -> None:
+ """Initializes the preference gatherer.
+
+ Args:
+            collection_service_address: Network address of the collection
+                service's REST interface.
+ wait_for_user: Waits for user to input their preferences.
+ rng: random number generator, if applicable.
+ custom_logger: Where to log to; if None (default), creates a new logger.
+ querent_kwargs: Keyword arguments passed to the querent.
+ """
+ super().__init__(rng, custom_logger)
+ querent_kwargs = querent_kwargs if querent_kwargs else {}
+ self.querent = RESTQuerent(
+ collection_service_address=collection_service_address,
+ **querent_kwargs,
)
- return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32)
+ self.query_endpoint = collection_service_address
+
+    def _gather_preference(self, query_id: str) -> Optional[float]:
+ query_url = urljoin(self.query_endpoint, query_id)
+ answered_query = requests.get(query_url).json()
+ return answered_query["label"]
+
+
+def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None:
+ """Removes rendered images of the provided trajectories list."""
+ for traj in trajectories:
+ if traj.infos is not None and "rendered_img" in traj.infos[0]:
+ for info in traj.infos:
+ rendered_img_info = info["rendered_img"]
+ if isinstance(rendered_img_info, (str, bytes, os.PathLike)):
+ os.remove(rendered_img_info)
+ del info["rendered_img"]
+ elif isinstance(rendered_img_info, np.ndarray):
+ del info["rendered_img"]
class PreferenceDataset(data_th.Dataset):
@@ -1516,7 +1994,7 @@ def __init__(
for which preferences will be gathered. These fragments could be random,
or they could be selected more deliberately (active learning).
Default is a random fragmenter.
- preference_gatherer: how to get preferences between trajectory fragments.
+ preference_gatherer: gathers preferences between trajectory fragments.
Default (and currently the only option) is to use synthetic preferences
based on ground-truth rewards. Human preferences could be implemented
here in the future.
@@ -1589,16 +2067,19 @@ def __init__(
reward_trainer,
)
+ # TODO: update messages with preference querent
if self.rng is None and has_any_rng_args_none:
raise ValueError(
"If you don't provide a random state, you must provide your own "
- "seeded fragmenter, preference gatherer, and reward_trainer. "
+ "seeded fragmenter, preference gatherer, "
+ "and reward_trainer. "
"You can initialize a random state with `np.random.default_rng(seed)`.",
)
elif self.rng is not None and not has_any_rng_args_none:
raise ValueError(
"If you provide your own fragmenter, preference gatherer, "
- "and reward trainer, you don't need to provide a random state.",
+ "and reward trainer, "
+ "you don't need to provide a random state.",
)
if reward_trainer is None:
@@ -1688,15 +2169,15 @@ def train(
reward_loss = None
reward_accuracy = None
- for i, num_pairs in enumerate(schedule):
+ for i, num_queries in enumerate(schedule):
##########################
# Gather new preferences #
##########################
num_steps = math.ceil(
- self.transition_oversampling * 2 * num_pairs * self.fragment_length,
+ self.transition_oversampling * 2 * num_queries * self.fragment_length,
)
self.logger.log(
- f"Collecting {2 * num_pairs} fragments ({num_steps} transitions)",
+ f"Collecting {2 * num_queries} fragments ({num_steps} transitions)",
)
trajectories = self.trajectory_generator.sample(num_steps)
# This assumes there are no fragments missing initial timesteps
@@ -1704,12 +2185,24 @@ def train(
horizons = (len(traj) for traj in trajectories if traj.terminal)
self._check_fixed_horizon(horizons)
self.logger.log("Creating fragment pairs")
- fragments = self.fragmenter(trajectories, self.fragment_length, num_pairs)
+
+ queries = self.fragmenter(trajectories, self.fragment_length, num_queries)
+
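+            # Forward the queries to the gatherer's querent (which may, e.g.,
+            # write videos or upload them to a collection service); the
+            # gatherer tracks them as pending until preferences are gathered.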
+ self.preference_gatherer.query(queries)
+
with self.logger.accumulate_means("preferences"):
self.logger.log("Gathering preferences")
- preferences = self.preference_gatherer(fragments)
- self.dataset.push(fragments, preferences)
+ # Gather fragment pairs for which preferences have been provided
+ queries, preferences = self.preference_gatherer.gather()
+
+ # Free up RAM or disk space from keeping rendered images
+ remove_rendered_images(trajectories)
+
+ self.dataset.push(queries, preferences)
self.logger.log(f"Dataset now contains {len(self.dataset)} comparisons")
+ # Skip training if dataset is empty
+ if len(self.dataset) == 0:
+ continue
##########################
# Train the reward model #
diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py
index 94c88111d..1b033c811 100644
--- a/src/imitation/data/wrappers.py
+++ b/src/imitation/data/wrappers.py
@@ -1,8 +1,14 @@
"""Environment wrappers for collecting rollouts."""
+import os
+import shutil
+import tempfile
+import uuid
from typing import List, Optional, Sequence, Tuple
+import cv2
import gymnasium as gym
+import imageio
import numpy as np
import numpy.typing as npt
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper
@@ -10,6 +16,73 @@
from imitation.data import rollout, types
+class RenderImageInfoWrapper(gym.Wrapper):
+ """Saves render images to `info`.
+
+ Can be very memory intensive for large render images.
+    Use `scale_factor` to reduce the size of the stored render images.
+    If you need to preserve the full resolution but run out of memory,
+    activate `use_file_cache` to write rendered images to disk and store
+    only their paths in `info`.
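+
+    Example (see also the tests in `tests/data/test_wrappers.py`)::
+
+        env = gym.make("CartPole-v0", render_mode="rgb_array")
+        env = RenderImageInfoWrapper(env, scale_factor=0.5, use_file_cache=True)
+        env.reset()
+        _, _, _, _, info = env.step(env.action_space.sample())
+        # info["rendered_img"] is an ndarray, or a path to a cached .png file
+        # when `use_file_cache=True`.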
+ """
+
+ def __init__(
+ self,
+ env: gym.Env,
+ scale_factor: float = 1.0,
+ use_file_cache: bool = False,
+ ):
+ """Builds RenderImageInfoWrapper.
+
+ Args:
+ env: Environment to wrap.
+ scale_factor: scales rendered images to be stored.
+ use_file_cache: whether to save rendered images to disk.
+ """
+ assert env.render_mode == "rgb_array", (
+ "The environment must be in render mode 'rgb_array' in order"
+ " to use this wrapper but render_mode is "
+ f"'{env.render_mode}'."
+ )
+ super().__init__(env)
+ self.scale_factor = scale_factor
+ self.use_file_cache = use_file_cache
+ if self.use_file_cache:
+            self.file_cache = tempfile.mkdtemp(
+                prefix="imitation_RenderImageInfoWrapper",
+            )
+
+ def step(self, action):
+ observation, reward, terminated, truncated, info = self.env.step(action)
+
+ rendered_image = self.render()
+ # Scale the render image
+ scaled_size = (
+ int(self.scale_factor * rendered_image.shape[1]),
+ int(self.scale_factor * rendered_image.shape[0]),
+ )
+ scaled_rendered_image = cv2.resize(
+ rendered_image,
+ scaled_size,
+ interpolation=cv2.INTER_AREA,
+ )
+ # Store the render image
+ if not self.use_file_cache:
+ info["rendered_img"] = scaled_rendered_image
+ else:
+ unique_file_path = os.path.join(
+ self.file_cache,
+ str(uuid.uuid4()) + ".png",
+ )
+ imageio.imwrite(unique_file_path, scaled_rendered_image)
+ info["rendered_img"] = unique_file_path
+
+ return observation, reward, terminated, truncated, info
+
+ def close(self) -> None:
+ if self.use_file_cache:
+ shutil.rmtree(self.file_cache)
+ return super().close()
+
+
class BufferingWrapper(VecEnvWrapper):
"""Saves transitions of underlying VecEnv.
diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py
index 4d8531732..33a417cbb 100644
--- a/src/imitation/scripts/config/train_preference_comparisons.py
+++ b/src/imitation/scripts/config/train_preference_comparisons.py
@@ -4,6 +4,7 @@
from torch import nn
from imitation.algorithms import preference_comparisons
+from imitation.data.wrappers import RenderImageInfoWrapper
from imitation.scripts.ingredients import environment
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy_evaluation, reward, rl
@@ -69,6 +70,51 @@ def train_defaults():
query_schedule = "hyperbolic"
+@train_preference_comparisons_ex.named_config
+def synch_human_preferences():
+ gatherer_cls = preference_comparisons.CommandLineGatherer
+ gatherer_kwargs = dict(video_dir="videos")
+ querent_cls = preference_comparisons.PreferenceQuerent
+ querent_kwargs = dict()
+ environment = dict(
+ post_wrappers=dict(
+ RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper(
+ env,
+ **kwargs,
+ ),
+ ),
+ num_vec=2,
+ post_wrappers_kwargs=dict(
+ RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True),
+        ),
+        env_make_kwargs=dict(render_mode="rgb_array"),
+    )
+
+
+@train_preference_comparisons_ex.named_config
+def human_preferences():
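+    # Note: this assumes an external preference-collection web service (e.g. a
+    # PrefCollect-style app) is running at `collection_service_address` below
+    # and implements the PUT/GET protocol expected by RESTQuerent/RESTGatherer.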
+ gatherer_cls = preference_comparisons.RESTGatherer
+ gatherer_kwargs = dict(
+ collection_service_address="http://127.0.0.1:8000",
+ wait_for_user=True,
+ querent_kwargs=dict(
+ video_output_dir="./videos",
+ video_fps=20,
+ ),
+ )
+ environment = dict(
+ post_wrappers=dict(
+ RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper(
+ env,
+ **kwargs,
+ ),
+ ),
+ post_wrappers_kwargs=dict(
+ RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True),
+ ),
+ env_make_kwargs=dict(render_mode="rgb_array"),
+ )
+
+
@train_preference_comparisons_ex.named_config
def cartpole():
environment = dict(gym_id="CartPole-v1")
diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py
index 86e6f8d53..5bb68a9e3 100644
--- a/src/imitation/scripts/eval_policy.py
+++ b/src/imitation/scripts/eval_policy.py
@@ -94,7 +94,11 @@ def eval_policy(
"""
log_dir = logging_ingredient.make_log_dir()
sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
- post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
+ post_wrappers = (
+ {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)}
+ if videos
+ else {}
+ )
render_mode = "rgb_array" if videos else None
with environment.make_venv( # type: ignore[wrong-keyword-args]
post_wrappers=post_wrappers,
diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py
index 40f70d5c2..0c1f52a30 100644
--- a/src/imitation/scripts/ingredients/environment.py
+++ b/src/imitation/scripts/ingredients/environment.py
@@ -1,7 +1,9 @@
"""This ingredient provides a vectorized gym environment."""
import contextlib
-from typing import Any, Generator, Mapping
+import functools
+from typing import Any, Callable, Generator, Mapping
+import gymnasium as gym
import numpy as np
import sacred
from stable_baselines3.common import vec_env
@@ -19,6 +21,8 @@ def config():
max_episode_steps = None # Set to positive int to limit episode horizons
env_make_kwargs = {} # The kwargs passed to `spec.make`.
gym_id = "seals/CartPole-v0" # The environment to train on
+ post_wrappers = {} # Wrappers applied after `spec.make`
+ post_wrappers_kwargs = {} # The kwargs passed to post wrappers
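+    # Example shape (hypothetical wrapper name):
+    #   post_wrappers = {"MyWrapper": lambda env, env_id, **kw: MyWrapper(env, **kw)}
+    #   post_wrappers_kwargs = {"MyWrapper": dict(verbose=True)}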
locals() # quieten flake8
@@ -31,6 +35,8 @@ def make_venv(
parallel: bool,
max_episode_steps: int,
env_make_kwargs: Mapping[str, Any],
+ post_wrappers: Mapping[str, Callable[[gym.Env, int], gym.Env]],
+ post_wrappers_kwargs: Mapping[str, Mapping[str, Any]],
_run: sacred.run.Run,
_rnd: np.random.Generator,
**kwargs,
@@ -46,11 +52,20 @@ def make_venv(
environment to artificially limit the maximum number of timesteps in an
episode.
env_make_kwargs: The kwargs passed to `spec.make` of a gym environment.
+            post_wrappers: Mapping from wrapper name to a wrapper factory
+                applied after environment creation with `spec.make`.
+            post_wrappers_kwargs: Mapping from wrapper name to the kwargs
+                passed to that post wrapper.
kwargs: Passed through to `util.make_vec_env`.
Yields:
The constructed vector environment.
"""
+    # Bind the configured kwargs to each post wrapper.
+ updated_post_wrappers = []
+ for key, post_wrapper in post_wrappers.items():
+ if key in post_wrappers_kwargs:
+ post_wrapper = functools.partial(post_wrapper, **post_wrappers_kwargs[key])
+ updated_post_wrappers.append(post_wrapper)
+
# Note: we create the venv outside the try -- finally block for the case that env
# creation fails.
venv = util.make_vec_env(
@@ -61,6 +76,7 @@ def make_venv(
max_episode_steps=max_episode_steps,
log_dir=_run.config["logging"]["log_dir"] if "logging" in _run.config else None,
env_make_kwargs=env_make_kwargs,
+ post_wrappers=updated_post_wrappers,
**kwargs,
)
try:
diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py
index 1b5dfc028..c1764f6f8 100644
--- a/src/imitation/scripts/train_rl.py
+++ b/src/imitation/scripts/train_rl.py
@@ -98,7 +98,9 @@ def train_rl(
rollout_dir.mkdir(parents=True, exist_ok=True)
policy_dir.mkdir(parents=True, exist_ok=True)
- post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
+ post_wrappers = {
+ "RolloutInfoWrapper": lambda env, idx: wrappers.RolloutInfoWrapper(env),
+ }
with environment.make_venv( # type: ignore[wrong-keyword-args]
post_wrappers=post_wrappers,
) as venv:
diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py
index cf06c00c5..e250da5e3 100644
--- a/src/imitation/util/video_wrapper.py
+++ b/src/imitation/util/video_wrapper.py
@@ -8,6 +8,17 @@
from gymnasium.wrappers.monitoring import video_recorder
+def video_wrapper_factory(log_dir: pathlib.Path, **kwargs):
+ """Returns a function that wraps the environment in a video recorder."""
+
+ def f(env: gym.Env, i: int) -> VideoWrapper:
+ """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`."""
+ directory = log_dir / "videos" / str(i)
+ return VideoWrapper(env, directory=directory, **kwargs)
+
+ return f
+
+
class VideoWrapper(gym.Wrapper):
"""Creates videos from wrapped environment by calling render after each timestep."""
@@ -21,6 +32,7 @@ def __init__(
env: gym.Env,
directory: pathlib.Path,
single_video: bool = True,
+ delete_on_close: bool = True,
):
"""Builds a VideoWrapper.
@@ -32,11 +44,14 @@ def __init__(
Usually a single video file is what is desired. However, if one is
searching for an interesting episode (perhaps by looking at the
metadata), then saving to different files can be useful.
+ delete_on_close: if True, deletes the video file when the environment is
+ closed. If False, the video file is left on disk.
"""
super().__init__(env)
self.episode_id = 0
self.video_recorder = None
self.single_video = single_video
+ self.delete_on_close = delete_on_close
self.directory = directory
self.directory.mkdir(parents=True, exist_ok=True)
@@ -85,4 +100,7 @@ def close(self) -> None:
if self.video_recorder is not None:
self.video_recorder.close()
self.video_recorder = None
+ if self.delete_on_close:
+ for path in self.directory.glob("*.mp4"):
+ path.unlink()
super().close()
diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py
index 60c6dfc18..c5a72fd94 100644
--- a/tests/algorithms/test_preference_comparisons.py
+++ b/tests/algorithms/test_preference_comparisons.py
@@ -1,10 +1,19 @@
"""Tests for the preference comparisons reward learning implementation."""
+import base64
+import binascii
import math
+import os
+import pathlib
import re
-from typing import Any, Sequence
+import shutil
+import tempfile
+import uuid
+from typing import Any, Sequence, Tuple
+from unittest.mock import MagicMock, Mock, patch
import gymnasium as gym
+import imageio
import numpy as np
import pytest
import seals # noqa: F401
@@ -17,8 +26,17 @@
import imitation.testing.reward_nets as testing_reward_nets
from imitation.algorithms import preference_comparisons
+from imitation.algorithms.preference_comparisons import (
+ PreferenceGatherer,
+ PreferenceQuerent,
+ RESTGatherer,
+ RESTQuerent,
+ SyntheticGatherer,
+ VideoBasedQuerent,
+ remove_rendered_images,
+)
from imitation.data import types
-from imitation.data.types import TrajectoryWithRew
+from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets
from imitation.testing import reward_improvement
@@ -73,6 +91,24 @@ def agent_trainer(agent, reward_net, venv, rng):
return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng)
+@pytest.fixture
+def trajectory_with_rew(venv):
+ observations, rewards, _, infos, actions = [], [], [], [], []
+ observations.append(venv.observation_space.sample())
+ for _ in range(2):
+ observations.append(venv.observation_space.sample())
+ actions.append(venv.action_space.sample())
+ rewards.append(0.0)
+ infos.append({})
+ return TrajectoryWithRew(
+ obs=np.array(observations),
+ acts=np.array(actions),
+ rews=np.array(rewards),
+ infos=np.array(infos),
+ terminal=False,
+ )
+
+
def assert_info_arrs_equal(arr1, arr2): # pragma: no cover
def check_possibly_nested_dicts_equal(dict1, dict2):
for key, val1 in dict1.items():
@@ -239,14 +275,14 @@ def test_preference_comparisons_raises(
loss,
rng=rng,
)
+
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
- # no rng, must provide fragmenter, preference gatherer, reward trainer
no_rng_msg = (
".*don't provide.*random state.*provide.*fragmenter"
".*preference gatherer.*reward_trainer.*"
)
- def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
+ def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng):
preference_comparisons.PreferenceComparisons(
agent_trainer,
reward_net,
@@ -261,16 +297,16 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
)
with pytest.raises(ValueError, match=no_rng_msg):
- build_preference_comparsions(gatherer, None, None, rng=None)
+ build_preference_comparisons(gatherer, None, None, rng=None)
with pytest.raises(ValueError, match=no_rng_msg):
- build_preference_comparsions(None, reward_trainer, None, rng=None)
+ build_preference_comparisons(None, reward_trainer, None, rng=None)
with pytest.raises(ValueError, match=no_rng_msg):
- build_preference_comparsions(None, None, random_fragmenter, rng=None)
+ build_preference_comparisons(None, None, random_fragmenter, rng=None)
# This should not raise
- build_preference_comparsions(gatherer, reward_trainer, random_fragmenter, rng=None)
+ build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=None)
# if providing fragmenter, preference gatherer, reward trainer, does not need rng.
with_rng_msg = (
@@ -279,7 +315,7 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
)
with pytest.raises(ValueError, match=with_rng_msg):
- build_preference_comparsions(
+ build_preference_comparisons(
gatherer,
reward_trainer,
random_fragmenter,
@@ -287,10 +323,10 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng):
)
# This should not raise
- build_preference_comparsions(None, None, None, rng=rng)
- build_preference_comparsions(gatherer, None, None, rng=rng)
- build_preference_comparsions(None, reward_trainer, None, rng=rng)
- build_preference_comparsions(None, None, random_fragmenter, rng=rng)
+ build_preference_comparisons(None, None, None, rng=rng)
+ build_preference_comparisons(gatherer, None, None, rng=rng)
+ build_preference_comparisons(None, reward_trainer, None, rng=rng)
+ build_preference_comparisons(None, None, random_fragmenter, rng=rng)
@pytest.mark.parametrize(
@@ -485,7 +521,8 @@ def test_gradient_accumulation(
dataset = preference_comparisons.PreferenceDataset()
trajectory = agent_trainer.sample(num_trajectories)
fragments = random_fragmenter(trajectory, 1, num_trajectories)
- preferences = preference_gatherer(fragments)
+ preference_gatherer.query(fragments)
+ fragments, preferences = preference_gatherer.gather()
dataset.push(fragments, preferences)
seed = rng.integers(2**32)
@@ -529,8 +566,10 @@ def test_synthetic_gatherer_deterministic(
)
trajectories = agent_trainer.sample(10)
fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2)
- preferences1 = gatherer(fragments)
- preferences2 = gatherer(fragments)
+ gatherer.query(fragments)
+ _, preferences1 = gatherer.gather()
+ gatherer.query(fragments)
+ _, preferences2 = gatherer.gather()
assert np.all(preferences1 == preferences2)
@@ -608,7 +647,8 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng):
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
for i in range(6):
fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=1)
- preferences = gatherer(fragments)
+ gatherer.query(fragments)
+ fragments, preferences = gatherer.gather()
assert len(dataset) == min(i, 5)
dataset.push(fragments, preferences)
assert len(dataset) == min(i + 1, 5)
@@ -627,7 +667,8 @@ def test_store_and_load_preference_dataset(
trajectories = agent_trainer.sample(10)
fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2)
gatherer = preference_comparisons.SyntheticGatherer(rng=rng)
- preferences = gatherer(fragments)
+ gatherer.query(fragments)
+ fragments, preferences = gatherer.gather()
dataset.push(fragments, preferences)
path = tmp_path / "preferences.pkl"
@@ -1095,3 +1136,319 @@ def test_that_trainer_improves(
novice_agent_rewards,
trained_agent_rewards,
)
+
+
+# PreferenceQuerent
+def test_returns_query_dict_from_query_sequence_with_correct_length():
+ querent = PreferenceQuerent()
+ query_sequence = [Mock()]
+ query_dict = querent(query_sequence)
+ assert len(query_dict) == len(query_sequence)
+
+
+def test_returned_queries_have_uuid():
+ querent = PreferenceQuerent()
+ query_dict = querent([Mock()])
+
+ try:
+ key = list(query_dict.keys())[0]
+ uuid.UUID(key, version=4)
+ except ValueError:
+        pytest.fail("Returned query key is not a valid UUID.")
+
+
+# RESTQuerent
+def test_sends_put_request_for_each_query(requests_mock):
+ address = "https://test.de"
+ querent = RESTQuerent(collection_service_address=address, video_output_dir="video")
+ querent._load_video_data = MagicMock()
+ query_id = "1234"
+
+ requests_mock.put(f"{address}/{query_id}")
+ querent._query(query_id)
+
+ assert requests_mock.last_request.method == "PUT"
+
+
+@pytest.fixture
+def empty_trajectory_with_rew_and_render_images() -> TrajectoryWithRew:
+ num_frames = 10
+ frame_shape = (200, 200)
+ return types.TrajectoryWithRew(
+ obs=np.zeros((num_frames, *frame_shape, 3), np.uint8),
+ acts=np.zeros((num_frames - 1,), np.uint8),
+ infos=np.array(
+ [
+ {"rendered_img": np.zeros((*frame_shape, 3), np.uint8)}
+ for _ in range(num_frames - 1)
+ ],
+ ),
+ rews=np.zeros((num_frames - 1,)),
+ terminal=True,
+ )
+
+
+def is_base64(data):
+ """Checks if the data is base64 encoded."""
+ try:
+ base64.b64decode(data)
+ return True
+ except binascii.Error:
+ return False
+
+
+def test_load_video_data(empty_trajectory_with_rew_and_render_images):
+ address = "https://test.de"
+ video_dir = "video"
+ querent = RESTQuerent(
+ collection_service_address=address, video_output_dir=video_dir,
+ )
+
+ # Setup query with saved videos
+ query_id = "1234"
+ frames = querent._get_frames_for_each_observation(
+ empty_trajectory_with_rew_and_render_images,
+ )
+ output_path = querent._create_query_video_path(query_id, "left")
+ querent._write(frames, output_path, progress_logger=False)
+ output_path = querent._create_query_video_path(query_id, "right")
+ querent._write(frames, output_path, progress_logger=False)
+
+ # Load videos and check their encoding
+ video_data = querent._load_video_data(query_id)
+ for alternative in ("left", "right"):
+ assert is_base64(video_data[alternative])
+
+ # Check that loading video data of non-existent query fails
+ with pytest.raises(RuntimeError):
+ querent._load_video_data("0")
+
+ shutil.rmtree(video_dir)
+
+
+def test_prefcollectquerent_call_creates_all_videos(
+ empty_trajectory_with_rew_and_render_images,
+):
+ address = "https://test.de"
+ queries = [
+ (
+ empty_trajectory_with_rew_and_render_images,
+ empty_trajectory_with_rew_and_render_images,
+ ),
+ ]
+ querent = RESTQuerent(collection_service_address=address, video_output_dir="video")
+ identified_queries = querent(queries)
+ for query_id, _ in identified_queries.items():
+ file = os.path.join(querent.video_output_dir, query_id + "-{}.webm")
+ for part in ["left", "right"]:
+ assert os.path.isfile(file.format(part))
+ os.remove(file.format(part))
+
+
+@pytest.fixture(
+ params=["obs_only", "with_render_images", "with_render_image_paths"],
+)
+def fragment(request, empty_trajectory_with_rew_and_render_images):
+ obs = empty_trajectory_with_rew_and_render_images.obs
+ infos = empty_trajectory_with_rew_and_render_images.infos
+ if request.param == "with_render_images":
+ infos = np.array(
+ [
+ {"rendered_img": frame}
+ for frame in empty_trajectory_with_rew_and_render_images.obs[1:]
+ ],
+ )
+ elif request.param == "with_render_image_paths":
+ tmp_dir = tempfile.mkdtemp()
+ infos = []
+ for frame in obs[1:]:
+ unique_file_path = str(pathlib.Path(tmp_dir) / (str(uuid.uuid4()) + ".png"))
+ imageio.imwrite(unique_file_path, frame)
+ infos.append({"rendered_img": unique_file_path})
+ infos = np.array(infos)
+ yield types.TrajectoryWithRew(
+ obs=obs,
+ acts=empty_trajectory_with_rew_and_render_images.acts,
+ infos=infos,
+ terminal=True,
+ rews=empty_trajectory_with_rew_and_render_images.rews,
+ )
+ if request.param == "with_render_image_paths":
+ shutil.rmtree(tmp_dir)
+
+
+# utils
+@pytest.mark.parametrize("codec", ["webm", "mp4"])
+def test_write_fragment_video(fragment, codec):
+ output_dir = "video"
+ video_based_querent = VideoBasedQuerent(video_output_dir=output_dir)
+ video_path = pathlib.Path(f"{output_dir}.{codec}")
+ if "rendered_img" not in fragment.infos[0]:
+ with pytest.raises(ValueError):
+ video_based_querent._write_fragment_video(fragment, output_path=video_path)
+ else:
+ video_based_querent._write_fragment_video(fragment, output_path=video_path)
+ assert os.path.isfile(video_path)
+ os.remove(video_path)
+
+
+def test_remove_rendered_images(fragment):
+ trajectories = [fragment]
+ remove_rendered_images(trajectories)
+ assert not any("rendered_img" in info for trajectory in trajectories for info in trajectory.infos)
+
+
+# PreferenceGatherer
+class ConcretePreferenceGatherer(PreferenceGatherer):
+ """A concrete preference gatherer for unit testing purposes only."""
+
+ def _gather_preference(self, query_id: str) -> float:
+ return 0.
+
+
+def test_adds_queries_to_pending_queries():
+ gatherer = ConcretePreferenceGatherer()
+ query = Mock()
+ queries = [query]
+
+ gatherer.query(queries)
+ assert query in list(gatherer.pending_queries.values())
+
+
+def test_clears_pending_queries(trajectory_with_rew):
+ gatherer = SyntheticGatherer(sample=False)
+
+ queries = [(trajectory_with_rew, trajectory_with_rew)]
+ gatherer.query(queries)
+
+ gatherer.gather()
+
+ assert len(gatherer.pending_queries) == 0
+
+
+# RESTGatherer
+def test_returns_none_for_unanswered_query(requests_mock):
+ address = "https://test.de"
+ query_id = "1234"
+ answer = None
+
+ gatherer = RESTGatherer(
+ collection_service_address=address,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+
+ requests_mock.get(
+ f"{address}/{query_id}",
+ json={"query_id": query_id, "label": answer},
+ )
+
+ preference = gatherer._gather_preference(query_id)
+
+ assert preference is answer
+
+
+def test_returns_preference_for_answered_query(requests_mock):
+ address = "https://test.de"
+ query_id = "1234"
+ answer = 1.0
+
+ gatherer = RESTGatherer(
+ collection_service_address=address,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+
+ requests_mock.get(
+ f"{address}/{query_id}",
+ json={"query_id": query_id, "label": answer},
+ )
+
+ preference = gatherer._gather_preference(query_id)
+
+ assert preference == answer
+
+
+def test_keeps_pending_query_for_unanswered_query():
+ gatherer = RESTGatherer(
+ collection_service_address="https://test.de",
+ wait_for_user=False,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+ gatherer._gather_preference = MagicMock(return_value=None)
+ gatherer.pending_queries = {"1234": Mock()}
+
+ pending_queries_pre = gatherer.pending_queries.copy()
+ gatherer.gather()
+
+ assert pending_queries_pre == gatherer.pending_queries
+
+
+def test_deletes_pending_query_for_answered_query():
+ gatherer = RESTGatherer(
+ collection_service_address="https://test.de",
+ wait_for_user=False,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+ preference = 0.5
+ gatherer._gather_preference = MagicMock(return_value=preference)
+ gatherer.pending_queries = {"1234": Mock()}
+
+ gatherer.gather()
+
+ assert len(gatherer.pending_queries) == 0
+
+
+def test_gathers_valid_preference():
+ gatherer = RESTGatherer(
+ collection_service_address="https://test.de",
+ wait_for_user=False,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+ preference = 0.5
+ gatherer._gather_preference = MagicMock(return_value=preference)
+ query = Mock()
+ gatherer.pending_queries = {"1234": query}
+
+ gathered_queries, gathered_preferences = gatherer.gather()
+
+ assert gathered_preferences[0] == preference
+ assert gathered_queries[0] == query
+
+
+def test_ignores_incomparable_answer():
+ gatherer = RESTGatherer(
+ collection_service_address="https://test.de",
+ wait_for_user=False,
+ querent_kwargs={"video_output_dir": "videos"},
+ )
+ gatherer._gather_preference = MagicMock(return_value=-1.0)
+ gatherer.pending_queries = {"1234": Mock()}
+
+ gathered_queries, gathered_preferences = gatherer.gather()
+
+ assert len(gathered_preferences) == 0
+ assert len(gathered_queries) == 0
+
+
+# CommandLineGatherer
+@patch("builtins.input")
+@patch("IPython.display.display")
+def test_command_line_gatherer(mock_display, mock_input, fragment):
+ del mock_display # unused
+ video_dir = "videos"
+ gatherer = preference_comparisons.CommandLineGatherer(
+ video_dir=pathlib.Path(video_dir),
+ )
+
+ # these inputs are designed solely to pass the test. they aren't tested for anything
+ trajectory_pairs = [(fragment, fragment),]
+ gatherer.query(trajectory_pairs)
+
+ # this is the actual test
+ mock_input.return_value = "1"
+ assert gatherer.gather()[1] == np.array([1.0])
+
+ gatherer.query(trajectory_pairs)
+ mock_input.return_value = "2"
+ assert gatherer.gather()[1] == np.array([0.0])
+
+ shutil.rmtree(video_dir)
diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py
index 33677c68f..119ea609b 100644
--- a/tests/data/test_wrappers.py
+++ b/tests/data/test_wrappers.py
@@ -1,14 +1,16 @@
"""Tests for `imitation.data.wrappers`."""
+from pathlib import Path
from typing import List, Sequence, Type
import gymnasium as gym
+import imageio
import numpy as np
import pytest
from stable_baselines3.common.vec_env import DummyVecEnv
from imitation.data import types
-from imitation.data.wrappers import BufferingWrapper
+from imitation.data.wrappers import BufferingWrapper, RenderImageInfoWrapper
class _CountingEnv(gym.Env): # pragma: no cover
@@ -278,3 +280,44 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]):
assert venv.n_transitions == 0
with pytest.raises(RuntimeError, match=".* empty .*"):
venv.pop_transitions()
+
+
+@pytest.mark.parametrize("scale_factor", [0.1, 0.5, 1.0])
+def test_writes_rendered_img_to_info(scale_factor):
+ env = gym.make("CartPole-v0", render_mode="rgb_array")
+ wrapped_env = RenderImageInfoWrapper(env, scale_factor=scale_factor)
+ wrapped_env.reset()
+ rendered_img = wrapped_env.render()
+ _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample())
+ assert "rendered_img" in info
+ assert isinstance(info["rendered_img"], np.ndarray)
+ if scale_factor == 1.0:
+ assert np.allclose(info["rendered_img"], rendered_img)
+ assert int(scale_factor * rendered_img.shape[0]) == info["rendered_img"].shape[0]
+ assert int(scale_factor * rendered_img.shape[1]) == info["rendered_img"].shape[1]
+
+
+def test_raises_assertion_error_if_env_not_in_correct_render_mode():
+ wrong_mode = "human"
+ env = gym.make("CartPole-v0", render_mode=wrong_mode)
+
+ with pytest.raises(
+ AssertionError,
+ match="The environment must be in render mode 'rgb_array' "
+ "in order to use this wrapper "
+ f"but render_mode is '{wrong_mode}'.",
+ ):
+ RenderImageInfoWrapper(env)
+
+
+def test_rendered_img_file_cache():
+ env = gym.make("CartPole-v0", render_mode="rgb_array")
+ wrapped_env = RenderImageInfoWrapper(env, use_file_cache=True)
+ assert Path(wrapped_env.file_cache).exists()
+ wrapped_env.reset()
+ _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample())
+ rendered_img_path = info["rendered_img"]
+ assert Path(rendered_img_path).exists()
+ assert (imageio.imread(rendered_img_path) == wrapped_env.render()).all()
+ wrapped_env.close()
+ assert not Path(wrapped_env.file_cache).exists()
diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py
index ae39116e7..e4f0cade3 100644
--- a/tests/scripts/test_scripts.py
+++ b/tests/scripts/test_scripts.py
@@ -11,6 +11,7 @@
import pathlib
import pickle
import platform
+import re
import shutil
import sys
import tempfile
@@ -22,6 +23,7 @@
import pandas as pd
import pytest
import ray.tune as tune
+import requests_mock
import sacred
import sacred.utils
import stable_baselines3
@@ -168,6 +170,41 @@ def test_train_preference_comparisons_main(tmpdir, preference_comparison_config)
assert isinstance(run.result, dict)
+def response_callback(request, context):
+ return {"query_id": request.path.split("/")[0], "label": 1.0}
+
+
+def test_train_preference_comparisons_with_collected_preferences(tmpdir):
+ address = "http://localhost:8000"
+ config_updates = dict(
+ logging=dict(log_root=tmpdir),
+ environment=dict(env_make_kwargs=dict(render_mode="rgb_array")),
+ gatherer_kwargs=dict(
+ wait_for_user=False,
+ querent_kwargs=dict(video_output_dir=tmpdir),
+ collection_service_address=address,
+ ),
+ )
+
+ with requests_mock.Mocker() as m:
+ request_matcher = re.compile(f"{address}/")
+
+ m.put(url=request_matcher)
+ m.get(
+ url=request_matcher,
+ json=response_callback,
+ )
+
+ run = train_preference_comparisons.train_preference_comparisons_ex.run(
+ named_configs=["cartpole", "human_preferences"]
+ + ALGO_FAST_CONFIGS["preference_comparison"],
+ config_updates=config_updates,
+ )
+
+ assert run.status == "COMPLETED"
+ assert isinstance(run.result, dict)
+
+
@pytest.mark.parametrize(
"env_name",
["seals_cartpole", "mountain_car", "seals_mountain_car"],