diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py index 2531757be..9b87aaea3 100755 --- a/ci/clean_notebooks.py +++ b/ci/clean_notebooks.py @@ -18,6 +18,7 @@ class UncleanNotebookError(Exception): "metadata": {"do": "constant", "value": dict()}, "source": {"do": "keep"}, "id": {"do": "keep"}, + "attachments": {"do": "constant", "value": {}}, } code_structure: Dict[str, Dict[str, Any]] = { @@ -76,7 +77,8 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: if key not in structure[cell["cell_type"]]: if check_only: raise UncleanNotebookError( - f"Notebook {file} has unknown cell key {key}", + f"Notebook {file} has unknown cell key {key} for cell type " + + f"{cell['cell_type']}", ) del cell[key] was_dirty = True @@ -108,7 +110,12 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: def parse_args(): - """Parse command-line arguments.""" + """Parse command-line arguments. + + Returns: + parser: The parser object. + args: The parsed arguments. + """ # if the argument --check has been passed, check if the notebooks are clean # otherwise, clean them in-place parser = argparse.ArgumentParser() @@ -125,7 +132,14 @@ def parse_args(): def get_files(input_paths: List): - """Build list of files to scan from list of paths and files.""" + """Build list of files to scan from list of paths and files. + + Args: + input_paths: List of paths and files to scan. + + Returns: + files: List of files to scan. + """ files = [] for file in input_paths: if file.is_dir(): diff --git a/docs/index.rst b/docs/index.rst index 3b3a9e1be..a0c061a48 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -86,6 +86,7 @@ If you use ``imitation`` in your research project, please cite our paper to help tutorials/4_train_airl tutorials/5_train_preference_comparisons tutorials/5a_train_preference_comparisons_with_cnn + tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback tutorials/6_train_mce tutorials/7_train_density tutorials/8_train_sqil diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb new file mode 100644 index 000000000..5f768784d --- /dev/null +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -0,0 +1,251 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/5_train_preference_comparisons.ipynb)\n", + "# Learning a Reward Function using Preference Comparisons with Synchronous Human Feedback\n", + "\n", + "You can request human feedback via synchronous CLI or Notebook interactions as well. The setup is only slightly different than it would be with a synthetic preference gatherer." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's the starting setup. 
The major differences from the synthetic setup are indicated with comments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "import random\n", + "import tempfile\n", + "from imitation.algorithms import preference_comparisons\n", + "from imitation.rewards.reward_nets import BasicRewardNet\n", + "from imitation.util import video_wrapper\n", + "from imitation.util.networks import RunningNorm\n", + "from imitation.util.util import make_vec_env\n", + "from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n", + "import gym\n", + "from stable_baselines3 import PPO\n", + "import numpy as np\n", + "\n", + "# Add a temporary directory for video recordings of trajectories. Unfortunately Jupyter\n", + "# won't play videos outside the current directory, so we have to put them here. We'll\n", + "# delete them at the end of the script.\n", + "video_dir = tempfile.mkdtemp(dir=\".\", prefix=\"videos_\")\n", + "\n", + "rng = np.random.default_rng(0)\n", + "\n", + "# Add a video wrapper to the environment. This will record videos of the agent's\n", + "# trajectories so we can review them later.\n", + "venv = make_vec_env(\n", + " \"Pendulum-v1\",\n", + " rng=rng,\n", + " post_wrappers={\n", + " \"VideoWrapper\": video_wrapper.video_wrapper_factory(\n", + " pathlib.Path(video_dir), single_video=False\n", + " )\n", + " },\n", + ")\n", + "\n", + "reward_net = BasicRewardNet(\n", + " venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n", + ")\n", + "\n", + "fragmenter = preference_comparisons.RandomFragmenter(\n", + " warning_threshold=0,\n", + " rng=rng,\n", + ")\n", + "\n", + "querent = preference_comparisons.PreferenceQuerent()\n", + "\n", + "# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n", + "# which one is better. 
It will then use the user's feedback to train the reward network.\n", + "gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n", + "\n", + "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", + "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", + " preference_model=preference_model,\n", + " loss=preference_comparisons.CrossEntropyRewardLoss(),\n", + " epochs=3,\n", + " rng=rng,\n", + ")\n", + "\n", + "agent = PPO(\n", + " policy=FeedForward32Policy,\n", + " policy_kwargs=dict(\n", + " features_extractor_class=NormalizeFeaturesExtractor,\n", + " features_extractor_kwargs=dict(normalize_class=RunningNorm),\n", + " ),\n", + " env=venv,\n", + " seed=0,\n", + " n_steps=2048 // venv.num_envs,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + ")\n", + "\n", + "trajectory_generator = preference_comparisons.AgentTrainer(\n", + " algorithm=agent,\n", + " reward_fn=reward_net,\n", + " venv=venv,\n", + " exploration_frac=0.0,\n", + " rng=rng,\n", + ")\n", + "\n", + "pref_comparisons = preference_comparisons.PreferenceComparisons(\n", + " trajectory_generator,\n", + " reward_net,\n", + " num_iterations=5,\n", + " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", + " preference_gatherer=gatherer,\n", + " reward_trainer=reward_trainer,\n", + " fragment_length=100,\n", + " transition_oversampling=1,\n", + " initial_comparison_frac=0.1,\n", + " allow_variable_horizon=False,\n", + " initial_epoch_multiplier=1,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We're going to train with only 20 comparisons to make it faster for you to evaluate. The videos will appear in-line in this notebook for you to watch, and a text input will appear for you to choose one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pref_comparisons.train(\n", + " total_timesteps=5_000, # For good performance this should be 1_000_000\n", + " total_comparisons=20, # For good performance this should be 5_000\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From this point onward, this notebook is the same as [the synthetic gatherer notebook](5_train_preference_comparisons.ipynb).\n", + "\n", + "After we trained the reward network using the preference comparisons algorithm, we can wrap our environment with that learned reward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n", + "\n", + "\n", + "learned_reward_venv = RewardVecEnvWrapper(venv, reward_net.predict)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can train an agent, that only sees those learned reward." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO\n", + "from stable_baselines3.ppo import MlpPolicy\n", + "\n", + "learner = PPO(\n", + " policy=MlpPolicy,\n", + " env=learned_reward_venv,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "learner.learn(1000) # Note: set to 100000 to train a proficient expert" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we can evaluate it using the original reward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", + "reward, _ = evaluate_policy(learner.policy, venv, 10)\n", + "print(reward)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# clean up the videos we made\n", + "import shutil\n", + "\n", + "shutil.rmtree(video_dir)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 1aa407456..c62c5f7d5 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,8 @@ "wandb==0.12.21", "setuptools_scm~=7.0.5", "pre-commit>=2.20.0", + "types-requests~=2.31.0.1", + "requests-mock~=1.11.0", ] + PARALLEL_REQUIRE + ATARI_REQUIRE @@ -209,6 +211,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "huggingface_sb3~=3.0", "optuna>=3.0.1", "datasets>=2.8.0", + "opencv-python", # TODO: specify version ], tests_require=TESTS_REQUIRE, extras_require={ diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 14a8fad5b..d4500f8ea 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -3,10 +3,16 @@ Trains a reward model and optionally a policy based on preferences between trajectory fragments. 
""" + import abc +import base64 import math +import os +import pathlib import pickle import re +import uuid +from urllib.parse import urljoin from collections import defaultdict from typing import ( Any, @@ -24,8 +30,11 @@ overload, ) +import cv2 import numpy as np +import requests import torch as th +from moviepy.editor import ImageSequenceClip from scipy import special from stable_baselines3.common import base_class, type_aliases, utils, vec_env from torch import nn @@ -778,45 +787,276 @@ def variance_estimate(self, rews1: th.Tensor, rews2: th.Tensor) -> float: return var_estimate -class PreferenceGatherer(abc.ABC): - """Base class for gathering preference comparisons between trajectory fragments.""" +class PreferenceQuerent: + """Dummy class for querying preferences between trajectory fragments.""" def __init__( self, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: + """Initializes the preference querent. + + Args: + rng: random number generator + custom_logger: Where to log to; if None (default), creates a new logger. + """ + self.logger = custom_logger or imit_logger.configure() + self.rng = rng + + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + """Queries the user for their preferences. + + This dummy implementation does nothing because by default the queries are + answered by an oracle. + + Args: + queries: sequence of pairs of trajectory fragments + + Returns: + dictionary with queries and their respective UUIDs + """ + return {str(uuid.uuid4()): query for query in queries} + + +class VideoBasedQuerent(PreferenceQuerent): + """Writes videos for each query to the local file system for later use by child querent (and gatherer) classes.""" + + def __init__( + self, + video_output_dir: Union[str, os.PathLike], + video_type="webm", + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ): + """Initializes the querent. + + Args: + video_output_dir: path to the video clip directory. + video_type: specifies the video format, e.g. 'webm' or 'mp4' + video_fps: frames per second of the generated videos. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. 
+ """ + super().__init__(custom_logger=custom_logger) + self.rng = rng + self.video_output_dir = video_output_dir + self.video_type = video_type + self.frames_per_second = video_fps + + # Create video directory + os.makedirs(self.video_output_dir, exist_ok=True) + + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + identified_queries = super().__call__(queries) + for query_id, query in identified_queries.items(): + self._write_query_videos(query_id, query) + return identified_queries + + def _write_query_videos(self, query_id, query): + for i, alternative in enumerate(("left", "right")): + self._write_fragment_video( + fragment=query[i], + output_path=self._create_query_video_path(query_id, alternative), + ) + + def _create_query_video_path(self, query_id: str, alternative: str): + return ( + pathlib.Path(self.video_output_dir) + / f"{query_id}-{alternative}.{self.video_type}" + ) + + def _write_fragment_video( + self, + fragment: TrajectoryWithRew, + output_path: AnyPath, + progress_logger: bool = True, + ) -> None: + """Write fragment video clip.""" + frames = self._get_frames(fragment) + self._write(frames, output_path, progress_logger) + + def _get_frames( + self, fragment: TrajectoryWithRew, + ) -> list[Union[os.PathLike, np.ndarray]]: + if self._rendered_image_of_observation_is_available(fragment): + return self._get_frames_for_each_observation(fragment) + else: + raise ValueError( + "No rendered images contained in info dict. " + "Please apply `RenderImageWrapper` to your environment.", + ) + + @staticmethod + def _rendered_image_of_observation_is_available( + fragment: TrajectoryWithRew, + ) -> bool: + return fragment.infos is not None and "rendered_img" in fragment.infos[0] + + @staticmethod + def _get_frames_for_each_observation(fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: + frames: list[Union[os.PathLike, np.ndarray]] = [] + for i in range(len(fragment.infos)): + frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] + frames.append(frame) + return frames + + def _write( + self, + frames: List[Union[os.PathLike, np.ndarray]], + output_path: Union[str, bytes, os.PathLike], + progress_logger: bool, + ): + clip = ImageSequenceClip( + frames, fps=self.frames_per_second, + ) # accepts list of image paths and numpy arrays + output_path = pathlib.Path(output_path) + if output_path.suffix == ".gif": + clip.write_gif( + str(output_path), + program="ffmpeg", + logger="bar" if progress_logger else None, + ) + else: + clip.write_videofile( + str(output_path), logger="bar" if progress_logger else None, + ) + + +class RESTQuerent(VideoBasedQuerent): + """Sends queries to a REST web service. + + The queries are sent via PUT request as json payload to `collection_service_address`/`query_id` + in the following form: + + { + "uuid": "1234", + "left": b64 encoded video, + "right" b64 encoded video + } + + """ + + def __init__( + self, + collection_service_address: str, + video_output_dir: str, + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ): + """Initializes the querent. + + Args: + collection_service_address: Network address of the collection service's REST interface. + video_output_dir: path to the video clip directory. + video_fps: frames per second of the generated videos. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. 
+ """ + super().__init__( + video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger, + ) + self.query_endpoint = collection_service_address + + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + identified_queries = super().__call__(queries) + for query_id, query in identified_queries.items(): + self._query(query_id) + return identified_queries + + def _query(self, query_id: str): + video_data = self._load_video_data(query_id) + requests.put( + urljoin(self.query_endpoint, query_id), + json={"uuid": "{}".format(query_id), **video_data}, + ) + + def _load_video_data(self, query_id: str) -> dict[str, str]: + video_data = {} + for alternative in ("left", "right"): + video_path = self._create_query_video_path(query_id, alternative) + if video_path.exists(): + with open(video_path, "rb") as video_file: + video_data[alternative] = base64.b64encode( + video_file.read(), + ).decode("utf-8") + else: + raise RuntimeError( + f"Video to be loaded does not exist at {video_path}.", + ) + return video_data + + +class PreferenceGatherer(abc.ABC): + """Base class for gathering preference comparisons between trajectory fragments.""" + + def __init__(self, rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None) -> None: """Initializes the preference gatherer. Args: rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. + querent_kwargs: Keyword arguments passed to the querent. """ # The random seed isn't used here, but it's useful to have this # as an argument nevertheless because that means we can always # pass in a seed in training scripts (without worrying about whether # the PreferenceGatherer we use needs one). - del rng + self.querent = PreferenceQuerent(rng, custom_logger) self.logger = custom_logger or imit_logger.configure() + self.pending_queries: Dict = {} - @abc.abstractmethod - def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: - """Gathers the probabilities that fragment 1 is preferred in `fragment_pairs`. - - Args: - fragment_pairs: sequence of pairs of trajectory fragments + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + """Gathers preference probabilities for completed queries. Returns: - A numpy array with shape (b, ), where b is the length of the input - (i.e. batch size). Each item in the array is the probability that - fragment 1 is preferred over fragment 2 for the corresponding - pair of fragments. + * A list of length b with queries for which preferences have been gathered. + + * A numpy array with shape (b, ). + Each item in the array is the probability that fragment 1 is preferred + over fragment 2 for the corresponding query in the list above. Note that for human feedback, these probabilities are simply 0 or 1 (or 0.5 in case of indifference), but synthetic models may yield other probabilities. 
""" # noqa: DAR202 + gathered_queries = [] + gathered_preferences = [] + + for query_id, query in list(self.pending_queries.items()): + preference = self._gather_preference(query_id) + + if preference is not None: + # Preference for this query has been provided + if 0 <= preference <= 1: + gathered_queries.append(query) + gathered_preferences.append(preference) + # else: fragments were incomparable + del self.pending_queries[query_id] + + return gathered_queries, np.array(gathered_preferences, dtype=np.float32) + + @abc.abstractmethod + def _gather_preference(self, query_id: str) -> float: + raise NotImplementedError + + def query(self, queries: Sequence[TrajectoryWithRewPair]) -> None: + identified_queries = self.querent(queries) + self.pending_queries = {**self.pending_queries, **identified_queries} + class SyntheticGatherer(PreferenceGatherer): """Computes synthetic preferences using ground-truth environment rewards.""" @@ -865,45 +1105,283 @@ def __init__( if self.sample and self.rng is None: raise ValueError("If `sample` is True, then `rng` must be provided.") - def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: - """Computes probability fragment 1 is preferred over fragment 2.""" - returns1, returns2 = self._reward_sums(fragment_pairs) + def _gather_preference(self, query_id): + query = self.pending_queries[query_id] + + return_1, return_2 = ( + np.array(rollout.discounted_sum(query[0].rews, self.discount_factor), dtype=np.float32), + np.array(rollout.discounted_sum(query[1].rews, self.discount_factor), dtype=np.float32) + ) + if self.temperature == 0: - return (np.sign(returns1 - returns2) + 1) / 2 + return (np.sign(return_1 - return_2) + 1) / 2 - returns1 /= self.temperature - returns2 /= self.temperature + return_1 /= self.temperature + return_2 /= self.temperature # clip the returns to avoid overflows in the softmax below - returns_diff = np.clip(returns2 - returns1, -self.threshold, self.threshold) + returns_diff = np.clip(return_2 - return_1, -self.threshold, self.threshold) # Instead of computing exp(rews1) / (exp(rews1) + exp(rews2)) directly, # we divide enumerator and denominator by exp(rews1) to prevent overflows: - model_probs = 1 / (1 + np.exp(returns_diff)) + choice_probability = 1 / (1 + np.exp(returns_diff)) # Compute the mean binary entropy. This metric helps estimate # how good we can expect the performance of the learned reward # model to be at predicting preferences. 
entropy = -( - special.xlogy(model_probs, model_probs) - + special.xlogy(1 - model_probs, 1 - model_probs) + special.xlogy(choice_probability, choice_probability) + + special.xlogy(1 - choice_probability, 1 - choice_probability) ).mean() self.logger.record("entropy", entropy) if self.sample: assert self.rng is not None - return self.rng.binomial(n=1, p=model_probs).astype(np.float32) - return model_probs - - def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: - rews1, rews2 = zip( - *[ - ( - rollout.discounted_sum(f1.rews, self.discount_factor), - rollout.discounted_sum(f2.rews, self.discount_factor), - ) - for f1, f2 in fragment_pairs - ], + return self.rng.binomial(n=1, p=choice_probability) + + return choice_probability + + +class CommandLineGatherer(PreferenceGatherer): + """Queries for human preferences using the command line or a notebook.""" + + def __init__( + self, + video_dir: Union[str, os.PathLike], + video_width: int = 500, + video_height: int = 500, + frames_per_second: int = 25, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + rng: Optional[np.random.Generator] = None, + ) -> None: + """Initialize the human preference gatherer. + + Args: + video_dir: directory where videos of the trajectories are saved. + video_width: width of the video in pixels. + video_height: height of the video in pixels. + frames_per_second: frames per second of the video. + custom_logger: Where to log to; if None (default), creates a new logger. + rng: random number generator + """ + super().__init__(rng=rng, custom_logger=custom_logger) + self.querent = VideoBasedQuerent( + video_output_dir=video_dir, video_fps=frames_per_second, + ) + self.video_dir = video_dir + self.video_width = video_width + self.video_height = video_height + + def _gather_preference(self, query_id): + """Displays the videos of the two fragments. + If in the command line, it will pop out a video player for each fragment. + If in a notebook, it will display the videos. + Either way, it will request 1 or 2 to indicate which is preferred. + + Args: + query_id: the id of the fragment pair to be displayed. + + Returns: + True if the first fragment is preferred, False if not. + + Raises: + KeyboardInterrupt: if the user presses q to quit. + """ + frag1_video_path = pathlib.Path(self.video_dir, f"{query_id}-left.webm") + frag2_video_path = pathlib.Path(self.video_dir, f"{query_id}-right.webm") + if self._in_ipython(): + self._display_videos_in_notebook(frag1_video_path, frag2_video_path) + + pref = input( + "Which video is preferred? (1 or 2, or q to quit, or r to replay): ", + ) + while pref not in ["1", "2", "q"]: + if pref == "r": + self._display_videos_in_notebook(frag1_video_path, frag2_video_path) + pref = input("Please enter 1 or 2 or q or r: ") + + if pref == "q": + return None + elif pref == "1": + return 1. + elif pref == "2": + return 0. + + # should never be hit + assert False + else: + return self._display_in_windows(frag1_video_path, frag2_video_path) + + def _display_in_windows( + self, + frag1_video_path: pathlib.Path, + frag2_video_path: pathlib.Path, + ) -> bool: + """Displays the videos in separate windows. + + The videos are displayed side by side and the user is asked to indicate + which one is preferred. The interaction is done in the window rather than + in the command line because the command line is not interactive when + the video is playing, and it's nice to allow the user to choose a video before + the videos are done playing. 
The downside is that the instructions appear on the + command line and the interaction happens in the video window. + + Args: + frag1_video_path: path to the video file of the first fragment + frag2_video_path: path to the video file of the second fragment + + Returns: + True if the first fragment is preferred, False if not. + + Raises: + KeyboardInterrupt: if the user presses q to quit. + RuntimeError: if the video files cannot be opened. + """ + print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n") + + cap1 = cv2.VideoCapture(str(frag1_video_path)) + cap2 = cv2.VideoCapture(str(frag2_video_path)) + cv2.namedWindow("Video 1", cv2.WINDOW_NORMAL) + cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL) + + # set window sizes + cv2.resizeWindow("Video 1", self.video_width, self.video_height) + cv2.resizeWindow("Video 2", self.video_width, self.video_height) + + # move windows side by side + cv2.moveWindow("Video 1", 0, 0) + cv2.moveWindow("Video 2", self.video_width, 0) + + if not cap1.isOpened(): + raise RuntimeError(f"Error opening video file {frag1_video_path}.") + + if not cap2.isOpened(): + raise RuntimeError(f"Error opening video file {frag2_video_path}.") + + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + while cap1.isOpened() and cap2.isOpened(): + if ret1 or ret2: + cv2.imshow("Video 1", frame1) + cv2.imshow("Video 2", frame2) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + + key = chr(cv2.waitKey(1) & 0xFF) + if key == "q": + cv2.destroyAllWindows() + raise KeyboardInterrupt + elif key == "r": + cap1.set(cv2.CAP_PROP_POS_FRAMES, 0) + cap2.set(cv2.CAP_PROP_POS_FRAMES, 0) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + elif key == "1" or key == "2": + cv2.destroyAllWindows() + return 1.0 if key == "1" else 0.0 + + cv2.destroyAllWindows() + raise KeyboardInterrupt + + def _display_videos_in_notebook( + self, + frag1_video_path: pathlib.Path, + frag2_video_path: pathlib.Path, + ) -> None: + """Displays the videos in a notebook. + + Interaction can happen in the notebook while the videos are playing. + + Args: + frag1_video_path: path to the video file of the first fragment + frag2_video_path: path to the video file of the second fragment + + Raises: + RuntimeError: if the video files cannot be opened. + """ + from IPython.display import HTML, Video, clear_output, display + + clear_output(wait=True) + + for i, path in enumerate([frag1_video_path, frag2_video_path]): + if not path.exists(): + raise RuntimeError(f"Video file {path} does not exist.") + display(HTML(f"

<h2>Video {i+1}</h2>

")) + display( + Video( + filename=str(path), + height=self.video_height, + width=self.video_width, + html_attributes="controls autoplay muted", + embed=False, + ), + ) + + def _in_ipython(self) -> bool: + try: + return self._is_running_pytest_test() or get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa + except NameError: + return False + + @staticmethod + def _is_running_pytest_test() -> bool: + return "PYTEST_CURRENT_TEST" in os.environ + + +class RESTGatherer(PreferenceGatherer): + """Gathers preferences from a REST web service. + + The queries are gathered via GET request to `collection_service_address`/`query_id` + and returns a preference as json payload: + + { + "label": float + } + + The float value ranges from 0.0 (preferring left) to 1.0 (preferring right), + with 0.5 indicating indifference. -1.0 indicates that the query was incomparable. + """ + + def __init__( + self, + collection_service_address: str, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, + ) -> None: + """Initializes the preference gatherer. + + Args: + collection_service_address: Network address of the collection service's REST interface. + wait_for_user: Waits for user to input their preferences. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + querent_kwargs: Keyword arguments passed to the querent. + """ + super().__init__(rng, custom_logger) + querent_kwargs = querent_kwargs if querent_kwargs else {} + self.querent = RESTQuerent( + collection_service_address=collection_service_address, + **querent_kwargs, ) - return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32) + self.query_endpoint = collection_service_address + + def _gather_preference(self, query_id: str) -> float: + query_url = urljoin(self.query_endpoint, query_id) + answered_query = requests.get(query_url).json() + return answered_query["label"] + + +def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: + """Removes rendered images of the provided trajectories list.""" + for traj in trajectories: + if traj.infos is not None and "rendered_img" in traj.infos[0]: + for info in traj.infos: + rendered_img_info = info["rendered_img"] + if isinstance(rendered_img_info, (str, bytes, os.PathLike)): + os.remove(rendered_img_info) + del info["rendered_img"] + elif isinstance(rendered_img_info, np.ndarray): + del info["rendered_img"] class PreferenceDataset(data_th.Dataset): @@ -1516,7 +1994,7 @@ def __init__( for which preferences will be gathered. These fragments could be random, or they could be selected more deliberately (active learning). Default is a random fragmenter. - preference_gatherer: how to get preferences between trajectory fragments. + preference_gatherer: gathers preferences between trajectory fragments. Default (and currently the only option) is to use synthetic preferences based on ground-truth rewards. Human preferences could be implemented here in the future. @@ -1589,16 +2067,19 @@ def __init__( reward_trainer, ) + # TODO: update messages with preference querent if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " - "seeded fragmenter, preference gatherer, and reward_trainer. " + "seeded fragmenter, preference gatherer, " + "and reward_trainer. 
" "You can initialize a random state with `np.random.default_rng(seed)`.", ) elif self.rng is not None and not has_any_rng_args_none: raise ValueError( "If you provide your own fragmenter, preference gatherer, " - "and reward trainer, you don't need to provide a random state.", + "and reward trainer, " + "you don't need to provide a random state.", ) if reward_trainer is None: @@ -1688,15 +2169,15 @@ def train( reward_loss = None reward_accuracy = None - for i, num_pairs in enumerate(schedule): + for i, num_queries in enumerate(schedule): ########################## # Gather new preferences # ########################## num_steps = math.ceil( - self.transition_oversampling * 2 * num_pairs * self.fragment_length, + self.transition_oversampling * 2 * num_queries * self.fragment_length, ) self.logger.log( - f"Collecting {2 * num_pairs} fragments ({num_steps} transitions)", + f"Collecting {2 * num_queries} fragments ({num_steps} transitions)", ) trajectories = self.trajectory_generator.sample(num_steps) # This assumes there are no fragments missing initial timesteps @@ -1704,12 +2185,24 @@ def train( horizons = (len(traj) for traj in trajectories if traj.terminal) self._check_fixed_horizon(horizons) self.logger.log("Creating fragment pairs") - fragments = self.fragmenter(trajectories, self.fragment_length, num_pairs) + + queries = self.fragmenter(trajectories, self.fragment_length, num_queries) + + self.preference_gatherer.query(queries) + with self.logger.accumulate_means("preferences"): self.logger.log("Gathering preferences") - preferences = self.preference_gatherer(fragments) - self.dataset.push(fragments, preferences) + # Gather fragment pairs for which preferences have been provided + queries, preferences = self.preference_gatherer.gather() + + # Free up RAM or disk space from keeping rendered images + remove_rendered_images(trajectories) + + self.dataset.push(queries, preferences) self.logger.log(f"Dataset now contains {len(self.dataset)} comparisons") + # Skip training if dataset is empty + if len(self.dataset) == 0: + continue ########################## # Train the reward model # diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 94c88111d..1b033c811 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -1,8 +1,14 @@ """Environment wrappers for collecting rollouts.""" +import os +import shutil +import tempfile +import uuid from typing import List, Optional, Sequence, Tuple +import cv2 import gymnasium as gym +import imageio import numpy as np import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper @@ -10,6 +16,73 @@ from imitation.data import rollout, types +class RenderImageInfoWrapper(gym.Wrapper): + """Saves render images to `info`. + + Can be very memory intensive for large render images. + Use `scale_factor` to reduce render image size. + If you need to preserve the resolution and memory + runs out, you can activate `use_file_cache` to save + rendered images and instead put their path into `info`. + """ + + def __init__( + self, + env: gym.Env, + scale_factor: float = 1.0, + use_file_cache: bool = False, + ): + """Builds RenderImageInfoWrapper. + + Args: + env: Environment to wrap. + scale_factor: scales rendered images to be stored. + use_file_cache: whether to save rendered images to disk. + """ + assert env.render_mode == "rgb_array", ( + "The environment must be in render mode 'rgb_array' in order" + " to use this wrapper but render_mode is " + f"'{env.render_mode}'." 
+ ) + super().__init__(env) + self.scale_factor = scale_factor + self.use_file_cache = use_file_cache + if self.use_file_cache: + self.file_cache = tempfile.mkdtemp("imitation_RenderImageInfoWrapper") + + def step(self, action): + observation, reward, terminated, truncated, info = self.env.step(action) + + rendered_image = self.render() + # Scale the render image + scaled_size = ( + int(self.scale_factor * rendered_image.shape[1]), + int(self.scale_factor * rendered_image.shape[0]), + ) + scaled_rendered_image = cv2.resize( + rendered_image, + scaled_size, + interpolation=cv2.INTER_AREA, + ) + # Store the render image + if not self.use_file_cache: + info["rendered_img"] = scaled_rendered_image + else: + unique_file_path = os.path.join( + self.file_cache, + str(uuid.uuid4()) + ".png", + ) + imageio.imwrite(unique_file_path, scaled_rendered_image) + info["rendered_img"] = unique_file_path + + return observation, reward, terminated, truncated, info + + def close(self) -> None: + if self.use_file_cache: + shutil.rmtree(self.file_cache) + return super().close() + + class BufferingWrapper(VecEnvWrapper): """Saves transitions of underlying VecEnv. diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 4d8531732..33a417cbb 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -4,6 +4,7 @@ from torch import nn from imitation.algorithms import preference_comparisons +from imitation.data.wrappers import RenderImageInfoWrapper from imitation.scripts.ingredients import environment from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward, rl @@ -69,6 +70,51 @@ def train_defaults(): query_schedule = "hyperbolic" +@train_preference_comparisons_ex.named_config +def synch_human_preferences(): + gatherer_cls = preference_comparisons.CommandLineGatherer + gatherer_kwargs = dict(video_dir="videos") + querent_cls = preference_comparisons.PreferenceQuerent + querent_kwargs = dict() + environment = dict( + post_wrappers=dict( + RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( + env, + **kwargs, + ), + ), + num_vec=2, + post_wrappers_kwargs=dict( + RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), + ), + ) + + +@train_preference_comparisons_ex.named_config +def human_preferences(): + gatherer_cls = preference_comparisons.RESTGatherer + gatherer_kwargs = dict( + collection_service_address="http://127.0.0.1:8000", + wait_for_user=True, + querent_kwargs=dict( + video_output_dir="./videos", + video_fps=20, + ), + ) + environment = dict( + post_wrappers=dict( + RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( + env, + **kwargs, + ), + ), + post_wrappers_kwargs=dict( + RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), + ), + env_make_kwargs=dict(render_mode="rgb_array"), + ) + + @train_preference_comparisons_ex.named_config def cartpole(): environment = dict(gym_id="CartPole-v1") diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 86e6f8d53..5bb68a9e3 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -94,7 +94,11 @@ def eval_policy( """ log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) - post_wrappers = 
[video_wrapper_factory(log_dir, **video_kwargs)] if videos else None + post_wrappers = ( + {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} + if videos + else {} + ) render_mode = "rgb_array" if videos else None with environment.make_venv( # type: ignore[wrong-keyword-args] post_wrappers=post_wrappers, diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py index 40f70d5c2..0c1f52a30 100644 --- a/src/imitation/scripts/ingredients/environment.py +++ b/src/imitation/scripts/ingredients/environment.py @@ -1,7 +1,9 @@ """This ingredient provides a vectorized gym environment.""" import contextlib -from typing import Any, Generator, Mapping +import functools +from typing import Any, Callable, Generator, Mapping +import gymnasium as gym import numpy as np import sacred from stable_baselines3.common import vec_env @@ -19,6 +21,8 @@ def config(): max_episode_steps = None # Set to positive int to limit episode horizons env_make_kwargs = {} # The kwargs passed to `spec.make`. gym_id = "seals/CartPole-v0" # The environment to train on + post_wrappers = {} # Wrappers applied after `spec.make` + post_wrappers_kwargs = {} # The kwargs passed to post wrappers locals() # quieten flake8 @@ -31,6 +35,8 @@ def make_venv( parallel: bool, max_episode_steps: int, env_make_kwargs: Mapping[str, Any], + post_wrappers: Mapping[str, Callable[[gym.Env, int], gym.Env]], + post_wrappers_kwargs: Mapping[str, Mapping[str, Any]], _run: sacred.run.Run, _rnd: np.random.Generator, **kwargs, @@ -46,11 +52,20 @@ def make_venv( environment to artificially limit the maximum number of timesteps in an episode. env_make_kwargs: The kwargs passed to `spec.make` of a gym environment. + post_wrappers: The wrappers applied after environment creation with `spec.make`. + post_wrappers_kwargs: List of kwargs passed to the respective post wrappers. kwargs: Passed through to `util.make_vec_env`. Yields: The constructed vector environment. """ + # Update env_fns for post wrappers with kwargs + updated_post_wrappers = [] + for key, post_wrapper in post_wrappers.items(): + if key in post_wrappers_kwargs: + post_wrapper = functools.partial(post_wrapper, **post_wrappers_kwargs[key]) + updated_post_wrappers.append(post_wrapper) + # Note: we create the venv outside the try -- finally block for the case that env # creation fails. 
venv = util.make_vec_env( @@ -61,6 +76,7 @@ def make_venv( max_episode_steps=max_episode_steps, log_dir=_run.config["logging"]["log_dir"] if "logging" in _run.config else None, env_make_kwargs=env_make_kwargs, + post_wrappers=updated_post_wrappers, **kwargs, ) try: diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 1b5dfc028..c1764f6f8 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -98,7 +98,9 @@ def train_rl( rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) - post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] + post_wrappers = { + "RolloutInfoWrapper": lambda env, idx: wrappers.RolloutInfoWrapper(env), + } with environment.make_venv( # type: ignore[wrong-keyword-args] post_wrappers=post_wrappers, ) as venv: diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index cf06c00c5..e250da5e3 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -8,6 +8,17 @@ from gymnasium.wrappers.monitoring import video_recorder +def video_wrapper_factory(log_dir: pathlib.Path, **kwargs): + """Returns a function that wraps the environment in a video recorder.""" + + def f(env: gym.Env, i: int) -> VideoWrapper: + """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" + directory = log_dir / "videos" / str(i) + return VideoWrapper(env, directory=directory, **kwargs) + + return f + + class VideoWrapper(gym.Wrapper): """Creates videos from wrapped environment by calling render after each timestep.""" @@ -21,6 +32,7 @@ def __init__( env: gym.Env, directory: pathlib.Path, single_video: bool = True, + delete_on_close: bool = True, ): """Builds a VideoWrapper. @@ -32,11 +44,14 @@ def __init__( Usually a single video file is what is desired. However, if one is searching for an interesting episode (perhaps by looking at the metadata), then saving to different files can be useful. + delete_on_close: if True, deletes the video file when the environment is + closed. If False, the video file is left on disk. 
""" super().__init__(env) self.episode_id = 0 self.video_recorder = None self.single_video = single_video + self.delete_on_close = delete_on_close self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) @@ -85,4 +100,7 @@ def close(self) -> None: if self.video_recorder is not None: self.video_recorder.close() self.video_recorder = None + if self.delete_on_close: + for path in self.directory.glob("*.mp4"): + path.unlink() super().close() diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 60c6dfc18..c5a72fd94 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,10 +1,19 @@ """Tests for the preference comparisons reward learning implementation.""" +import base64 +import binascii import math +import os +import pathlib import re -from typing import Any, Sequence +import shutil +import tempfile +import uuid +from typing import Any, Sequence, Tuple +from unittest.mock import MagicMock, Mock, patch import gymnasium as gym +import imageio import numpy as np import pytest import seals # noqa: F401 @@ -17,8 +26,17 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons +from imitation.algorithms.preference_comparisons import ( + PreferenceGatherer, + PreferenceQuerent, + RESTGatherer, + RESTQuerent, + SyntheticGatherer, + VideoBasedQuerent, + remove_rendered_images, +) from imitation.data import types -from imitation.data.types import TrajectoryWithRew +from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair from imitation.regularization import regularizers, updaters from imitation.rewards import reward_nets from imitation.testing import reward_improvement @@ -73,6 +91,24 @@ def agent_trainer(agent, reward_net, venv, rng): return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng) +@pytest.fixture +def trajectory_with_rew(venv): + observations, rewards, _, infos, actions = [], [], [], [], [] + observations.append(venv.observation_space.sample()) + for _ in range(2): + observations.append(venv.observation_space.sample()) + actions.append(venv.action_space.sample()) + rewards.append(0.0) + infos.append({}) + return TrajectoryWithRew( + obs=np.array(observations), + acts=np.array(actions), + rews=np.array(rewards), + infos=np.array(infos), + terminal=False, + ) + + def assert_info_arrs_equal(arr1, arr2): # pragma: no cover def check_possibly_nested_dicts_equal(dict1, dict2): for key, val1 in dict1.items(): @@ -239,14 +275,14 @@ def test_preference_comparisons_raises( loss, rng=rng, ) + gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - # no rng, must provide fragmenter, preference gatherer, reward trainer no_rng_msg = ( ".*don't provide.*random state.*provide.*fragmenter" ".*preference gatherer.*reward_trainer.*" ) - def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): + def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): preference_comparisons.PreferenceComparisons( agent_trainer, reward_net, @@ -261,16 +297,16 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(gatherer, None, None, rng=None) + build_preference_comparisons(gatherer, None, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(None, reward_trainer, None, rng=None) + 
build_preference_comparisons(None, reward_trainer, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(None, None, random_fragmenter, rng=None) + build_preference_comparisons(None, None, random_fragmenter, rng=None) # This should not raise - build_preference_comparsions(gatherer, reward_trainer, random_fragmenter, rng=None) + build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=None) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. with_rng_msg = ( @@ -279,7 +315,7 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=with_rng_msg): - build_preference_comparsions( + build_preference_comparisons( gatherer, reward_trainer, random_fragmenter, @@ -287,10 +323,10 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) # This should not raise - build_preference_comparsions(None, None, None, rng=rng) - build_preference_comparsions(gatherer, None, None, rng=rng) - build_preference_comparsions(None, reward_trainer, None, rng=rng) - build_preference_comparsions(None, None, random_fragmenter, rng=rng) + build_preference_comparisons(None, None, None, rng=rng) + build_preference_comparisons(gatherer, None, None, rng=rng) + build_preference_comparisons(None, reward_trainer, None, rng=rng) + build_preference_comparisons(None, None, random_fragmenter, rng=rng) @pytest.mark.parametrize( @@ -485,7 +521,8 @@ def test_gradient_accumulation( dataset = preference_comparisons.PreferenceDataset() trajectory = agent_trainer.sample(num_trajectories) fragments = random_fragmenter(trajectory, 1, num_trajectories) - preferences = preference_gatherer(fragments) + preference_gatherer.query(fragments) + fragments, preferences = preference_gatherer.gather() dataset.push(fragments, preferences) seed = rng.integers(2**32) @@ -529,8 +566,10 @@ def test_synthetic_gatherer_deterministic( ) trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) - preferences1 = gatherer(fragments) - preferences2 = gatherer(fragments) + gatherer.query(fragments) + _, preferences1 = gatherer.gather() + gatherer.query(fragments) + _, preferences2 = gatherer.gather() assert np.all(preferences1 == preferences2) @@ -608,7 +647,8 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): gatherer = preference_comparisons.SyntheticGatherer(rng=rng) for i in range(6): fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=1) - preferences = gatherer(fragments) + gatherer.query(fragments) + fragments, preferences = gatherer.gather() assert len(dataset) == min(i, 5) dataset.push(fragments, preferences) assert len(dataset) == min(i + 1, 5) @@ -627,7 +667,8 @@ def test_store_and_load_preference_dataset( trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - preferences = gatherer(fragments) + gatherer.query(fragments) + fragments, preferences = gatherer.gather() dataset.push(fragments, preferences) path = tmp_path / "preferences.pkl" @@ -1095,3 +1136,319 @@ def test_that_trainer_improves( novice_agent_rewards, trained_agent_rewards, ) + + +# PreferenceQuerent +def test_returns_query_dict_from_query_sequence_with_correct_length(): + querent = PreferenceQuerent() + query_sequence = [Mock()] + query_dict = querent(query_sequence) + assert 
len(query_dict) == len(query_sequence) + + +def test_returned_queries_have_uuid(): + querent = PreferenceQuerent() + query_dict = querent([Mock()]) + + try: + key = list(query_dict.keys())[0] + uuid.UUID(key, version=4) + except ValueError: + pytest.fail() + + +# PrefCollectQuerent +def test_sends_put_request_for_each_query(requests_mock): + address = "https://test.de" + querent = RESTQuerent(collection_service_address=address, video_output_dir="video") + querent._load_video_data = MagicMock() + query_id = "1234" + + requests_mock.put(f"{address}/{query_id}") + querent._query(query_id) + + assert requests_mock.last_request.method == "PUT" + + +@pytest.fixture +def empty_trajectory_with_rew_and_render_images() -> TrajectoryWithRew: + num_frames = 10 + frame_shape = (200, 200) + return types.TrajectoryWithRew( + obs=np.zeros((num_frames, *frame_shape, 3), np.uint8), + acts=np.zeros((num_frames - 1,), np.uint8), + infos=np.array( + [ + {"rendered_img": np.zeros((*frame_shape, 3), np.uint8)} + for _ in range(num_frames - 1) + ], + ), + rews=np.zeros((num_frames - 1,)), + terminal=True, + ) + + +def is_base64(data): + """Checks if the data is base64 encoded.""" + try: + base64.b64decode(data) + return True + except binascii.Error: + return False + + +def test_load_video_data(empty_trajectory_with_rew_and_render_images): + address = "https://test.de" + video_dir = "video" + querent = RESTQuerent( + collection_service_address=address, video_output_dir=video_dir, + ) + + # Setup query with saved videos + query_id = "1234" + frames = querent._get_frames_for_each_observation( + empty_trajectory_with_rew_and_render_images, + ) + output_path = querent._create_query_video_path(query_id, "left") + querent._write(frames, output_path, progress_logger=False) + output_path = querent._create_query_video_path(query_id, "right") + querent._write(frames, output_path, progress_logger=False) + + # Load videos and check their encoding + video_data = querent._load_video_data(query_id) + for alternative in ("left", "right"): + assert is_base64(video_data[alternative]) + + # Check that loading video data of non-existent query fails + with pytest.raises(RuntimeError): + querent._load_video_data("0") + + shutil.rmtree(video_dir) + + +def test_prefcollectquerent_call_creates_all_videos( + empty_trajectory_with_rew_and_render_images, +): + address = "https://test.de" + queries = [ + ( + empty_trajectory_with_rew_and_render_images, + empty_trajectory_with_rew_and_render_images, + ), + ] + querent = RESTQuerent(collection_service_address=address, video_output_dir="video") + identified_queries = querent(queries) + for query_id, _ in identified_queries.items(): + file = os.path.join(querent.video_output_dir, query_id + "-{}.webm") + for part in ["left", "right"]: + assert os.path.isfile(file.format(part)) + os.remove(file.format(part)) + + +@pytest.fixture( + params=["obs_only", "with_render_images", "with_render_image_paths"], +) +def fragment(request, empty_trajectory_with_rew_and_render_images): + obs = empty_trajectory_with_rew_and_render_images.obs + infos = empty_trajectory_with_rew_and_render_images.infos + if request.param == "with_render_images": + infos = np.array( + [ + {"rendered_img": frame} + for frame in empty_trajectory_with_rew_and_render_images.obs[1:] + ], + ) + elif request.param == "with_render_image_paths": + tmp_dir = tempfile.mkdtemp() + infos = [] + for frame in obs[1:]: + unique_file_path = str(pathlib.Path(tmp_dir) / (str(uuid.uuid4()) + ".png")) + imageio.imwrite(unique_file_path, frame) + 
infos.append({"rendered_img": unique_file_path}) + infos = np.array(infos) + yield types.TrajectoryWithRew( + obs=obs, + acts=empty_trajectory_with_rew_and_render_images.acts, + infos=infos, + terminal=True, + rews=empty_trajectory_with_rew_and_render_images.rews, + ) + if request.param == "with_render_image_paths": + shutil.rmtree(tmp_dir) + + +# utils +@pytest.mark.parametrize("codec", ["webm", "mp4"]) +def test_write_fragment_video(fragment, codec): + output_dir = "video" + video_based_querent = VideoBasedQuerent(video_output_dir=output_dir) + video_path = pathlib.Path(f"{output_dir}.{codec}") + if "rendered_img" not in fragment.infos[0]: + with pytest.raises(ValueError): + video_based_querent._write_fragment_video(fragment, output_path=video_path) + else: + video_based_querent._write_fragment_video(fragment, output_path=video_path) + assert os.path.isfile(video_path) + os.remove(video_path) + + +def test_remove_rendered_images(fragment): + trajectories = [fragment] + remove_rendered_images(trajectories) + assert not any("rendered_img" in info for trajectory in trajectories for info in trajectory.infos) + + +# PreferenceGatherer +class ConcretePreferenceGatherer(PreferenceGatherer): + """A concrete preference gatherer for unit testing purposes only.""" + + def _gather_preference(self, query_id: str) -> float: + return 0. + + +def test_adds_queries_to_pending_queries(): + gatherer = ConcretePreferenceGatherer() + query = Mock() + queries = [query] + + gatherer.query(queries) + assert query in list(gatherer.pending_queries.values()) + + +def test_clears_pending_queries(trajectory_with_rew): + gatherer = SyntheticGatherer(sample=False) + + queries = [(trajectory_with_rew, trajectory_with_rew)] + gatherer.query(queries) + + gatherer.gather() + + assert len(gatherer.pending_queries) == 0 + + +# PrefCollectGatherer +def test_returns_none_for_unanswered_query(requests_mock): + address = "https://test.de" + query_id = "1234" + answer = None + + gatherer = RESTGatherer( + collection_service_address=address, + querent_kwargs={"video_output_dir": "videos"}, + ) + + requests_mock.get( + f"{address}/{query_id}", + json={"query_id": query_id, "label": answer}, + ) + + preference = gatherer._gather_preference(query_id) + + assert preference is answer + + +def test_returns_preference_for_answered_query(requests_mock): + address = "https://test.de" + query_id = "1234" + answer = 1.0 + + gatherer = RESTGatherer( + collection_service_address=address, + querent_kwargs={"video_output_dir": "videos"}, + ) + + requests_mock.get( + f"{address}/{query_id}", + json={"query_id": query_id, "label": answer}, + ) + + preference = gatherer._gather_preference(query_id) + + assert preference == answer + + +def test_keeps_pending_query_for_unanswered_query(): + gatherer = RESTGatherer( + collection_service_address="https://test.de", + wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, + ) + gatherer._gather_preference = MagicMock(return_value=None) + gatherer.pending_queries = {"1234": Mock()} + + pending_queries_pre = gatherer.pending_queries.copy() + gatherer.gather() + + assert pending_queries_pre == gatherer.pending_queries + + +def test_deletes_pending_query_for_answered_query(): + gatherer = RESTGatherer( + collection_service_address="https://test.de", + wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, + ) + preference = 0.5 + gatherer._gather_preference = MagicMock(return_value=preference) + gatherer.pending_queries = {"1234": Mock()} + + gatherer.gather() + + assert 
len(gatherer.pending_queries) == 0 + + +def test_gathers_valid_preference(): + gatherer = RESTGatherer( + collection_service_address="https://test.de", + wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, + ) + preference = 0.5 + gatherer._gather_preference = MagicMock(return_value=preference) + query = Mock() + gatherer.pending_queries = {"1234": query} + + gathered_queries, gathered_preferences = gatherer.gather() + + assert gathered_preferences[0] == preference + assert gathered_queries[0] == query + + +def test_ignores_incomparable_answer(): + gatherer = RESTGatherer( + collection_service_address="https://test.de", + wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, + ) + gatherer._gather_preference = MagicMock(return_value=-1.0) + gatherer.pending_queries = {"1234": Mock()} + + gathered_queries, gathered_preferences = gatherer.gather() + + assert len(gathered_preferences) == 0 + assert len(gathered_queries) == 0 + + +# SynchronousHumanGatherer +@patch("builtins.input") +@patch("IPython.display.display") +def test_command_line_gatherer(mock_display, mock_input, fragment): + del mock_display # unused + video_dir = "videos" + gatherer = preference_comparisons.CommandLineGatherer( + video_dir=pathlib.Path(video_dir), + ) + + # these inputs are designed solely to pass the test. they aren't tested for anything + trajectory_pairs = [(fragment, fragment),] + gatherer.query(trajectory_pairs) + + # this is the actual test + mock_input.return_value = "1" + assert gatherer.gather()[1] == np.array([1.0]) + + gatherer.query(trajectory_pairs) + mock_input.return_value = "2" + assert gatherer.gather()[1] == np.array([0.0]) + + shutil.rmtree(video_dir) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 33677c68f..119ea609b 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,14 +1,16 @@ """Tests for `imitation.data.wrappers`.""" +from pathlib import Path from typing import List, Sequence, Type import gymnasium as gym +import imageio import numpy as np import pytest from stable_baselines3.common.vec_env import DummyVecEnv from imitation.data import types -from imitation.data.wrappers import BufferingWrapper +from imitation.data.wrappers import BufferingWrapper, RenderImageInfoWrapper class _CountingEnv(gym.Env): # pragma: no cover @@ -278,3 +280,44 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): assert venv.n_transitions == 0 with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() + + +@pytest.mark.parametrize("scale_factor", [0.1, 0.5, 1.0]) +def test_writes_rendered_img_to_info(scale_factor): + env = gym.make("CartPole-v0", render_mode="rgb_array") + wrapped_env = RenderImageInfoWrapper(env, scale_factor=scale_factor) + wrapped_env.reset() + rendered_img = wrapped_env.render() + _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) + assert "rendered_img" in info + assert isinstance(info["rendered_img"], np.ndarray) + if scale_factor == 1.0: + assert np.allclose(info["rendered_img"], rendered_img) + assert int(scale_factor * rendered_img.shape[0]) == info["rendered_img"].shape[0] + assert int(scale_factor * rendered_img.shape[1]) == info["rendered_img"].shape[1] + + +def test_raises_assertion_error_if_env_not_in_correct_render_mode(): + wrong_mode = "human" + env = gym.make("CartPole-v0", render_mode=wrong_mode) + + with pytest.raises( + AssertionError, + match="The environment must be in render mode 'rgb_array' " + "in order to use this wrapper " + f"but 
render_mode is '{wrong_mode}'.", + ): + RenderImageInfoWrapper(env) + + +def test_rendered_img_file_cache(): + env = gym.make("CartPole-v0", render_mode="rgb_array") + wrapped_env = RenderImageInfoWrapper(env, use_file_cache=True) + assert Path(wrapped_env.file_cache).exists() + wrapped_env.reset() + _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) + rendered_img_path = info["rendered_img"] + assert Path(rendered_img_path).exists() + assert (imageio.imread(rendered_img_path) == wrapped_env.render()).all() + wrapped_env.close() + assert not Path(wrapped_env.file_cache).exists() diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index ae39116e7..e4f0cade3 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -11,6 +11,7 @@ import pathlib import pickle import platform +import re import shutil import sys import tempfile @@ -22,6 +23,7 @@ import pandas as pd import pytest import ray.tune as tune +import requests_mock import sacred import sacred.utils import stable_baselines3 @@ -168,6 +170,41 @@ def test_train_preference_comparisons_main(tmpdir, preference_comparison_config) assert isinstance(run.result, dict) +def response_callback(request, context): + return {"query_id": request.path.split("/")[0], "label": 1.0} + + +def test_train_preference_comparisons_with_collected_preferences(tmpdir): + address = "http://localhost:8000" + config_updates = dict( + logging=dict(log_root=tmpdir), + environment=dict(env_make_kwargs=dict(render_mode="rgb_array")), + gatherer_kwargs=dict( + wait_for_user=False, + querent_kwargs=dict(video_output_dir=tmpdir), + collection_service_address=address, + ), + ) + + with requests_mock.Mocker() as m: + request_matcher = re.compile(f"{address}/") + + m.put(url=request_matcher) + m.get( + url=request_matcher, + json=response_callback, + ) + + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole", "human_preferences"] + + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + @pytest.mark.parametrize( "env_name", ["seals_cartpole", "mountain_car", "seals_mountain_car"],