From 63e9112faf92d195b4b80169d213c3ad17c49c2e Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 15 Dec 2022 17:27:57 +0100 Subject: [PATCH 001/143] Adds RenderImageInfoWrapper --- src/imitation/data/wrappers.py | 75 ++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 09ad42247..e3aaf10f2 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -1,7 +1,12 @@ """Environment wrappers for collecting rollouts.""" +import os +import shutil +import tempfile +import uuid from typing import List, Optional, Sequence, Tuple +import cv2 import gym import numpy as np import numpy.typing as npt @@ -10,6 +15,76 @@ from imitation.data import rollout, types +class RenderImageInfoWrapper(gym.Wrapper): + """Saves render images to `info`. + + Can be very memory intensive for large render images. + Use `scale_factor` to reduce render image size. + If you need to preserve the resolution and memory + runs out, you can activate `ues_file_cache` to save + render images and instead put their path into `info`. + """ + + def __init__( + self, + env: gym.Env, + scale_factor: float = 1., + use_file_cache: bool = False, + ): + """Builds RenderImageInfoWrapper. + + Args: + env: Environment to wrap. + """ + super().__init__(env) + self.scale_factor = scale_factor + self.use_file_cache = use_file_cache + if self.use_file_cache: + self.file_cache = tempfile.mkdtemp("imitation_RenderImageInfoWrapper") + + self._active = True + + def set_render_image_active(self, active: bool): + self._active = active + + def step(self, action): + obs, rew, done, info = self.env.step(action) + + if self._active: + rendered_image = self.render(mode="rgb_array") + # Scale the render image + scaled_size = ( + int(self.scale_factor * rendered_image.shape[0]), + int(self.scale_factor * rendered_image.shape[1]), + ) + scaled_rendered_image = cv2.resize( + rendered_image, + scaled_size, + interpolation=cv2.INTER_AREA, + ) + # Store the render image + if not self.use_file_cache: + info["rendered_img"] = scaled_rendered_image + else: + unique_file_path = os.path.join( + self.file_cache, + str(uuid.uuid4()) + ".npy", + ) + np.save(unique_file_path, scaled_rendered_image) + info["rendered_img"] = unique_file_path + + # Do not show window of classic control envs + if self.env.viewer is not None and self.env.viewer.window.visible: + self.env.viewer.window.set_visible(False) + + return obs, rew, done, info + + def close(self) -> None: + if self.use_file_cache: + shutil.rmtree(self.file_cache) + return super().close() + + class BufferingWrapper(VecEnvWrapper): """Saves transitions of underlying VecEnv. From ed5362cecb5e8761d4c2ba19554a62742aa34308 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 15 Dec 2022 17:29:16 +0100 Subject: [PATCH 002/143] Add PrefCollectGatherer --- .../algorithms/preference_comparisons.py | 163 +++++++++++++++++- 1 file changed, 162 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 413cd979a..12d6e023e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -5,8 +5,10 @@ """ import abc import math +import os import pickle import re +import uuid from collections import defaultdict from typing import ( Any, @@ -24,7 +26,9 @@ overload, ) +import cv2 import numpy as np +import requests import torch as th from scipy import special from stable_baselines3.common import base_class, type_aliases, utils, vec_env @@ -906,6 +910,153 @@ def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32) +class PrefCollectGatherer(PreferenceGatherer): + """Gathers preferences from PrefCollect interface.""" + def __init__( + self, + pref_collect_address: str, + video_output_dir: AnyPath, + video_fps: str = 20, + wait_for_user: bool = True, + random_preferences: bool = False, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ) -> None: + """Initializes the preference gatherer. + + Args: + pref_collect_address: Network address to PrefCollect instance. + video_output_dir: Path to where fragment videos are saved. + video_fps: Frames per second of the fragment videos. + random_preferences: Whether to gather random preferences (for debugging). + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + """ + super().__init__(custom_logger) + self.rng = rng + self.random_preferences = random_preferences + self.query_endpoint = pref_collect_address + "/preferences/query/" + self.video_output_dir = video_output_dir + self.frames_per_second = video_fps + self.pending_queries = {} + self.wait_for_user = wait_for_user + + # Create video directory + os.makedirs(self.video_output_dir, exist_ok=True) + + def __call__( + self, fragment_pairs: Sequence[TrajectoryPair] + ) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: + + if self.random_preferences: + return fragment_pairs, self.rng.choice([0.0, 1.0, 0.5], size=len(fragment_pairs)) + + # Generate UUID for each query (fragment pair) + new_queries = {str(uuid.uuid4()): query for query in fragment_pairs} + + # Save fragment videos and submit queries + for query_id, query in new_queries.items(): + self._write_fragment_video(query[0], name=f"{query_id}-left") + self._write_fragment_video(query[1], name=f"{query_id}-right") + requests.put( + self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} + ) + + if self.wait_for_user: + self.logger.log("Waiting for user to provide preferences. Press enter to continue.") + input() + + # Gather preferences for pending queries + self.pending_queries = {**self.pending_queries, **new_queries} + + gathered_queries = [] + gathered_preferences = [] + + for query_id, query in list(self.pending_queries.items()): + preference = self._gather_preference(query_id) + + if preference is not None: + # Preference for this query has been provided + if 0 <= preference <= 1: + gathered_queries.append(query) + gathered_preferences.append(preference) + # else: fragments were incomparable + del self.pending_queries[query_id] + + return gathered_queries, np.array(gathered_preferences, dtype=np.float32) + + def _write_fragment_video(self, fragment, name: str) -> None: + + output_file_name = os.path.join(self.video_output_dir, f'{name}.webm') + frame_shape = self._get_frame_shape(fragment) + video_writer = cv2.VideoWriter( + output_file_name, + cv2.VideoWriter_fourcc(*'VP90'), + self.frames_per_second, + frame_shape, + ) + + # Make videos from rendered observations if available + if "rendered_img" in fragment.infos[0]: + frames = [] + for i in range(len(fragment.infos)): + frame_info = fragment.infos[i]["rendered_img"] + # If path is provided load cached image + if isinstance(frame_info, AnyPath.__args__): + frame = np.load(frame_info) + elif isinstance(frame_info, np.ndarray): + frame = frame_info + frames.append(frame) + else: + frames = frames.obs + + for frame in frames: + # Transform to RGB frame if necessary + if frame.shape[-1] < 3: + missing_channels = 3 - frame.shape[-1] + frame = np.concatenate( + [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 + ) + video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + + video_writer.release() + + @staticmethod + def _get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: + if "rendered_img" in fragment.infos[0]: + rendered_img_info = fragment.infos[0]["rendered_img"] + # If path is provided load cached image + if isinstance(rendered_img_info, AnyPath.__args__): + single_frame = np.load(rendered_img_info) + else: + single_frame = rendered_img_info + else: + single_frame = np.array(fragment.obs[0]) + # Check whether obervations are image-like + if len(single_frame.shape) < 2: + raise ValueError("Observation must be an image, " + f"but shape {single_frame.shape} has too few dimensions!") + # Swap dimensions, because matrix and image dims are swapped + return single_frame.shape[1], single_frame.shape[0] + + def _gather_preference(self, query_id: str) -> float: + answered_query = requests.get(self.query_endpoint + query_id).json() + return answered_query["label"] + + def remove_rendered_images(self, trajectories: Sequence[TrajectoryWithRew]) -> None: + """Removes rendered images of the provided trajectories list.""" + for traj in trajectories: + for info in traj.infos: + try: + rendered_img_info = info["rendered_img"] + if isinstance(rendered_img_info, AnyPath.__args__): + os.remove(rendered_img_info) + elif isinstance(rendered_img_info, np.ndarray): + del info["rendered_img"] + except KeyError: + pass + + class PreferenceDataset(data_th.Dataset): """A PyTorch Dataset for preference comparisons. @@ -1707,9 +1858,19 @@ def train( fragments = self.fragmenter(trajectories, self.fragment_length, num_pairs) with self.logger.accumulate_means("preferences"): self.logger.log("Gathering preferences") - preferences = self.preference_gatherer(fragments) + if isinstance(self.preference_gatherer, PrefCollectGatherer): + # Gather fragment pairs for which preferences have been provided + fragments, preferences = self.preference_gatherer(fragments) + # Free up RAM or disk space from keeping rendered images + self.preference_gatherer.remove_rendered_images(trajectories) + else: + preferences = self.preference_gatherer(fragments) + self.dataset.push(fragments, preferences) self.logger.log(f"Dataset now contains {len(self.dataset)} comparisons") + # Skip training if dataset is empty + if len(self.dataset) == 0: + continue ########################## # Train the reward model # From 9359ad43dec479be361282ffef4fdee46e97f254 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 15 Dec 2022 17:30:25 +0100 Subject: [PATCH 003/143] Adds configuration for post wrappers --- src/imitation/scripts/common/common.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 72d44f2f4..2240e5d83 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -3,8 +3,9 @@ import contextlib import logging import pathlib -from typing import Any, Generator, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, Generator, Mapping, Sequence, Tuple, Union +import gym import numpy as np import sacred from stable_baselines3.common import vec_env @@ -36,6 +37,8 @@ def config(): parallel = True # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps = None # Set to positive int to limit episode horizons env_make_kwargs = {} # The kwargs passed to `spec.make`. + post_wrappers = [] # Wrappers applied after `spec.make` + post_wrappers_kwargs = [] # The kwargs passed to post wrappers locals() # quieten flake8 @@ -142,6 +145,8 @@ def make_venv( log_dir: str, max_episode_steps: int, env_make_kwargs: Mapping[str, Any], + post_wrappers: Mapping[str, Callable[[gym.Env, int], gym.Env]], + post_wrappers_kwargs: Mapping[str, Mapping[str, Any]], **kwargs, ) -> Generator[vec_env.VecEnv, None, None]: """Builds the vector environment. @@ -156,12 +161,20 @@ def make_venv( episode. log_dir: Logs episode return statistics to a subdirectory 'monitor`. env_make_kwargs: The kwargs passed to `spec.make` of a gym environment. + post_wrappers: The wrappers applied after environment creation with `spec.make`. + post_wrappers_kwargs: List of kwargs passed to the respective post wrappers. kwargs: Passed through to `util.make_vec_env`. Yields: The constructed vector environment. """ rng = make_rng() + # Update env_fns for post wrappers with kwargs + updated_post_wrappers = [] + for key, post_wrapper in post_wrappers.items(): + def updated_post_wrapper(env, env_id): + return post_wrapper(env, env_id, **post_wrappers_kwargs[key]) + updated_post_wrappers.append(updated_post_wrapper) # Note: we create the venv outside the try -- finally block for the case that env # creation fails. venv = util.make_vec_env( @@ -172,6 +185,7 @@ def make_venv( max_episode_steps=max_episode_steps, log_dir=log_dir, env_make_kwargs=env_make_kwargs, + post_wrappers=updated_post_wrappers, **kwargs, ) try: From eec88e08df56095815e26ac28f1d48c81b71d0b9 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 15 Dec 2022 17:31:18 +0100 Subject: [PATCH 004/143] Adds named_config for human preferences --- .../config/train_preference_comparisons.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ba4e9483c..4b7387442 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -3,6 +3,7 @@ import sacred from imitation.algorithms import preference_comparisons +from imitation.data.wrappers import RenderImageInfoWrapper from imitation.scripts.common import common, reward, rl, train train_preference_comparisons_ex = sacred.Experiment( @@ -61,6 +62,27 @@ def train_defaults(): query_schedule = "hyperbolic" +@train_preference_comparisons_ex.named_config +def human_preferences(): + gatherer_cls = preference_comparisons.PrefCollectGatherer + gatherer_kwargs = dict( + pref_collect_address="http://127.0.0.1:8000", + video_output_dir="../pref-collect/videofiles", + video_fps=20, + wait_for_user=True, + random_preferences=False, + ) + common = dict( + post_wrappers=dict( + RenderImageInfoWrapper=lambda env, env_id, **kwargs: + RenderImageInfoWrapper(env, **kwargs), + ), + post_wrappers_kwargs=dict( + RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), + ), + ) + + @train_preference_comparisons_ex.named_config def cartpole(): common = dict(env_name="CartPole-v1") From 346efde443bb7f4b94095397f14f0667d495a53f Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 15 Dec 2022 17:38:37 +0100 Subject: [PATCH 005/143] Add experiment script for human preferences --- experiments/human_preferences.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 experiments/human_preferences.sh diff --git a/experiments/human_preferences.sh b/experiments/human_preferences.sh new file mode 100644 index 000000000..d1351aefb --- /dev/null +++ b/experiments/human_preferences.sh @@ -0,0 +1,11 @@ +python -m imitation.scripts.train_preference_comparisons \ + with \ + pendulum \ + human_preferences \ + total_comparisons=5000 \ + total_timesteps=1000000 \ + gatherer_kwargs.pref_collect_address=127.0.0.1:8000 \ + gatherer_kwargs.video_output_dir=../pref-collect/videofiles \ + gatherer_kwargs.wait_for_user=True \ + common.post_wrappers_kwargs.RenderImageInfoWrapper.scale_factor=0.5 \ + common.post_wrappers_kwargs.RenderImageInfoWrapper.use_file_cache=True \ \ No newline at end of file From e6a44002dee1dae663dedef9a018eadf67d6e0e6 Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 16 Dec 2022 15:16:39 +0100 Subject: [PATCH 006/143] Fix post wrappers config --- src/imitation/scripts/common/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 2240e5d83..db7d57e30 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -37,8 +37,8 @@ def config(): parallel = True # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps = None # Set to positive int to limit episode horizons env_make_kwargs = {} # The kwargs passed to `spec.make`. - post_wrappers = [] # Wrappers applied after `spec.make` - post_wrappers_kwargs = [] # The kwargs passed to post wrappers + post_wrappers = {} # Wrappers applied after `spec.make` + post_wrappers_kwargs = {} # The kwargs passed to post wrappers locals() # quieten flake8 From 4fa0726dd4087a1fb67012378ec7b5708492caed Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 30 Dec 2022 12:36:55 +0100 Subject: [PATCH 007/143] Fix post wrappers config --- src/imitation/scripts/common/common.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 2240e5d83..82a3d65a4 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -1,6 +1,7 @@ """Common configuration elements for scripts.""" import contextlib +import functools import logging import pathlib from typing import Any, Callable, Generator, Mapping, Sequence, Tuple, Union @@ -37,8 +38,8 @@ def config(): parallel = True # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps = None # Set to positive int to limit episode horizons env_make_kwargs = {} # The kwargs passed to `spec.make`. - post_wrappers = [] # Wrappers applied after `spec.make` - post_wrappers_kwargs = [] # The kwargs passed to post wrappers + post_wrappers = {} # Wrappers applied after `spec.make` + post_wrappers_kwargs = {} # The kwargs passed to post wrappers locals() # quieten flake8 @@ -172,9 +173,9 @@ def make_venv( # Update env_fns for post wrappers with kwargs updated_post_wrappers = [] for key, post_wrapper in post_wrappers.items(): - def updated_post_wrapper(env, env_id): - return post_wrapper(env, env_id, **post_wrappers_kwargs[key]) - updated_post_wrappers.append(updated_post_wrapper) + if key in post_wrappers_kwargs: + post_wrapper = functools.partial(post_wrapper, **post_wrappers_kwargs[key]) + updated_post_wrappers.append(post_wrapper) # Note: we create the venv outside the try -- finally block for the case that env # creation fails. venv = util.make_vec_env( From ae0087ffd99f1e0e3e6626727407361d35442a4e Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 30 Dec 2022 12:37:44 +0100 Subject: [PATCH 008/143] Fix PrefCollect address --- experiments/human_preferences.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/human_preferences.sh b/experiments/human_preferences.sh index d1351aefb..1690d731e 100644 --- a/experiments/human_preferences.sh +++ b/experiments/human_preferences.sh @@ -4,7 +4,7 @@ python -m imitation.scripts.train_preference_comparisons \ human_preferences \ total_comparisons=5000 \ total_timesteps=1000000 \ - gatherer_kwargs.pref_collect_address=127.0.0.1:8000 \ + gatherer_kwargs.pref_collect_address=http://127.0.0.1:8000 \ gatherer_kwargs.video_output_dir=../pref-collect/videofiles \ gatherer_kwargs.wait_for_user=True \ common.post_wrappers_kwargs.RenderImageInfoWrapper.scale_factor=0.5 \ From 4b3968b4f1d751b7e5be66507366a574367fdb5e Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 15 Feb 2023 15:34:03 +0100 Subject: [PATCH 009/143] Extract preference querying into querent class --- .../algorithms/preference_comparisons.py | 99 +++++++++++++++---- 1 file changed, 80 insertions(+), 19 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 413cd979a..1790e6844 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -7,6 +7,7 @@ import math import pickle import re +import uuid from collections import defaultdict from typing import ( Any, @@ -778,6 +779,36 @@ def variance_estimate(self, rews1: th.Tensor, rews2: th.Tensor) -> float: return var_estimate +class PreferenceQuerent: + """Dummy class for querying preferences between trajectory fragments.""" + + def __init__( + self, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ) -> None: + """Initializes the preference querent. + + Args: + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + """ + del rng + self.logger = custom_logger or imit_logger.configure() + + def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + """Queries the user for their preferences. + This dummy implementation does nothing because by default the queries are answered by an oracle. + + Args: + queries: sequence of pairs of trajectory fragments + + Returns: + dictionary with queries and their respective UUIDs + """ + return {str(uuid.uuid4()): query for query in queries} + + class PreferenceGatherer(abc.ABC): """Base class for gathering preference comparisons between trajectory fragments.""" @@ -798,15 +829,14 @@ def __init__( # the PreferenceGatherer we use needs one). del rng self.logger = custom_logger or imit_logger.configure() + self.pending_queries = {} @abc.abstractmethod - def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: - """Gathers the probabilities that fragment 1 is preferred in `fragment_pairs`. - - Args: - fragment_pairs: sequence of pairs of trajectory fragments + def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + """Gathers the probabilities that fragment 1 is preferred in `queries`. Returns: + TODO return value A numpy array with shape (b, ), where b is the length of the input (i.e. batch size). Each item in the array is the probability that fragment 1 is preferred over fragment 2 for the corresponding @@ -817,6 +847,14 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra probabilities. """ # noqa: DAR202 + def add(self, new_queries: Dict[str, Sequence[TrajectoryWithRewPair]]) -> None: + """Adds queries to pending queries. + + Args: + new_queries: pairs of trajectory fragments + """ + self.pending_queries = {**self.pending_queries, **new_queries} + class SyntheticGatherer(PreferenceGatherer): """Computes synthetic preferences using ground-truth environment rewards.""" @@ -865,9 +903,10 @@ def __init__( if self.sample and self.rng is None: raise ValueError("If `sample` is True, then `rng` must be provided.") - def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Computes probability fragment 1 is preferred over fragment 2.""" - returns1, returns2 = self._reward_sums(fragment_pairs) + returns1, returns2 = self._reward_sums(self.pending_queries.values()) + if self.temperature == 0: return (np.sign(returns1 - returns2) + 1) / 2 @@ -878,20 +917,24 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra returns_diff = np.clip(returns2 - returns1, -self.threshold, self.threshold) # Instead of computing exp(rews1) / (exp(rews1) + exp(rews2)) directly, # we divide enumerator and denominator by exp(rews1) to prevent overflows: - model_probs = 1 / (1 + np.exp(returns_diff)) + choice_probs = 1 / (1 + np.exp(returns_diff)) # Compute the mean binary entropy. This metric helps estimate # how good we can expect the performance of the learned reward # model to be at predicting preferences. entropy = -( - special.xlogy(model_probs, model_probs) - + special.xlogy(1 - model_probs, 1 - model_probs) + special.xlogy(choice_probs, choice_probs) + + special.xlogy(1 - choice_probs, 1 - choice_probs) ).mean() self.logger.record("entropy", entropy) + # Clear pending queries because the oracle has answered all + queries = list(self.pending_queries.values()) + self.pending_queries.clear() + if self.sample: assert self.rng is not None - return self.rng.binomial(n=1, p=model_probs).astype(np.float32) - return model_probs + return queries, self.rng.binomial(n=1, p=choice_probs).astype(np.float32) + return queries, choice_probs def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: rews1, rews2 = zip( @@ -1488,6 +1531,7 @@ def __init__( reward_model: reward_nets.RewardNet, num_iterations: int, fragmenter: Optional[Fragmenter] = None, + preference_querent: Optional[PreferenceQuerent] = None, preference_gatherer: Optional[PreferenceGatherer] = None, reward_trainer: Optional[RewardTrainer] = None, comparison_queue_size: Optional[int] = None, @@ -1516,7 +1560,8 @@ def __init__( for which preferences will be gathered. These fragments could be random, or they could be selected more deliberately (active learning). Default is a random fragmenter. - preference_gatherer: how to get preferences between trajectory fragments. + preference_querent: queries preferences between trajectory fragments. + preference_gatherer: gathers preferences between trajectory fragments. Default (and currently the only option) is to use synthetic preferences based on ground-truth rewards. Human preferences could be implemented here in the future. @@ -1628,6 +1673,15 @@ def __init__( rng=self.rng, ) self.fragmenter.logger = self.logger + if preference_querent: + self.preference_querent = preference_querent + else: + # TODO add querent to train script + #assert self.rng is not None + self.preference_querent = PreferenceQuerent( + custom_logger=self.logger, + rng=self.rng, + ) if preference_gatherer: self.preference_gatherer = preference_gatherer else: @@ -1688,15 +1742,15 @@ def train( reward_loss = None reward_accuracy = None - for i, num_pairs in enumerate(schedule): + for i, num_queries in enumerate(schedule): ########################## # Gather new preferences # ########################## num_steps = math.ceil( - self.transition_oversampling * 2 * num_pairs * self.fragment_length, + self.transition_oversampling * 2 * num_queries * self.fragment_length, ) self.logger.log( - f"Collecting {2 * num_pairs} fragments ({num_steps} transitions)", + f"Collecting {2 * num_queries} fragments ({num_steps} transitions)", ) trajectories = self.trajectory_generator.sample(num_steps) # This assumes there are no fragments missing initial timesteps @@ -1704,11 +1758,18 @@ def train( horizons = (len(traj) for traj in trajectories if traj.terminal) self._check_fixed_horizon(horizons) self.logger.log("Creating fragment pairs") - fragments = self.fragmenter(trajectories, self.fragment_length, num_pairs) + + queries = self.fragmenter(trajectories, self.fragment_length, num_queries) + + identified_queries = self.preference_querent(queries) + self.preference_gatherer.add(identified_queries) + with self.logger.accumulate_means("preferences"): self.logger.log("Gathering preferences") - preferences = self.preference_gatherer(fragments) - self.dataset.push(fragments, preferences) + # Gather fragment pairs (queries) for which preferences have been provided + queries, preferences = self.preference_gatherer() + + self.dataset.push(queries, preferences) self.logger.log(f"Dataset now contains {len(self.dataset)} comparisons") ########################## From 8ccaa71a4cd37d81c4362367dd7f111d82a780fc Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 1 Mar 2023 15:04:29 +0100 Subject: [PATCH 010/143] Add querent and gatherer for human preferences Co-authored-by: Robert Klassert Co-authored-by: Marvin Schweizer --- .../algorithms/preference_comparisons.py | 209 +++++++++--------- .../config/train_preference_comparisons.py | 7 +- .../scripts/train_preference_comparisons.py | 10 + 3 files changed, 123 insertions(+), 103 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 66edc1ba3..85be24760 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -812,6 +812,94 @@ def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequen return {str(uuid.uuid4()): query for query in queries} +class PrefCollectQuerent(PreferenceQuerent): + """Sends queries to the PrefCollect interface.""" + + def __init__( + self, + pref_collect_address: str, + video_output_dir: AnyPath, + video_fps: str = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ): + super().__init__(custom_logger) + self.rng = rng + self.query_endpoint = pref_collect_address + "/preferences/query/" + self.video_output_dir = video_output_dir + self.frames_per_second = video_fps + + # Create video directory + os.makedirs(self.video_output_dir, exist_ok=True) + + def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + identified_queries = super().__call__(queries) + + # Save fragment videos and submit queries + for query_id, query in identified_queries.items(): + self._write_fragment_video(query[0], name=f"{query_id}-left") + self._write_fragment_video(query[1], name=f"{query_id}-right") + requests.put( + self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} + ) + + return identified_queries + + def _write_fragment_video(self, fragment, name: str) -> None: + + output_file_name = os.path.join(self.video_output_dir, f'{name}.webm') + frame_shape = self._get_frame_shape(fragment) + video_writer = cv2.VideoWriter( + output_file_name, + cv2.VideoWriter_fourcc(*'VP90'), + self.frames_per_second, + frame_shape, + ) + + # Make videos from rendered observations if available + if "rendered_img" in fragment.infos[0]: + frames = [] + for i in range(len(fragment.infos)): + frame_info = fragment.infos[i]["rendered_img"] + # If path is provided load cached image + if isinstance(frame_info, AnyPath.__args__): + frame = np.load(frame_info) + elif isinstance(frame_info, np.ndarray): + frame = frame_info + frames.append(frame) + else: + frames = frames.obs + + for frame in frames: + # Transform to RGB frame if necessary + if frame.shape[-1] < 3: + missing_channels = 3 - frame.shape[-1] + frame = np.concatenate( + [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 + ) + video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + + video_writer.release() + + @staticmethod + def _get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: + if "rendered_img" in fragment.infos[0]: + rendered_img_info = fragment.infos[0]["rendered_img"] + # If path is provided load cached image + if isinstance(rendered_img_info, AnyPath.__args__): + single_frame = np.load(rendered_img_info) + else: + single_frame = rendered_img_info + else: + single_frame = np.array(fragment.obs[0]) + # Check whether obervations are image-like + if len(single_frame.shape) < 2: + raise ValueError("Observation must be an image, " + f"but shape {single_frame.shape} has too few dimensions!") + # Swap dimensions, because matrix and image dims are swapped + return single_frame.shape[1], single_frame.shape[0] + + class PreferenceGatherer(abc.ABC): """Base class for gathering preference comparisons between trajectory fragments.""" @@ -957,10 +1045,7 @@ class PrefCollectGatherer(PreferenceGatherer): def __init__( self, pref_collect_address: str, - video_output_dir: AnyPath, - video_fps: str = 20, wait_for_user: bool = True, - random_preferences: bool = False, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: @@ -968,49 +1053,22 @@ def __init__( Args: pref_collect_address: Network address to PrefCollect instance. - video_output_dir: Path to where fragment videos are saved. - video_fps: Frames per second of the fragment videos. - random_preferences: Whether to gather random preferences (for debugging). + wait_for_user: Waits for user to input their preferences. rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(custom_logger) self.rng = rng - self.random_preferences = random_preferences self.query_endpoint = pref_collect_address + "/preferences/query/" - self.video_output_dir = video_output_dir - self.frames_per_second = video_fps self.pending_queries = {} self.wait_for_user = wait_for_user - # Create video directory - os.makedirs(self.video_output_dir, exist_ok=True) - - def __call__( - self, fragment_pairs: Sequence[TrajectoryPair] - ) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: - - if self.random_preferences: - return fragment_pairs, self.rng.choice([0.0, 1.0, 0.5], size=len(fragment_pairs)) - - # Generate UUID for each query (fragment pair) - new_queries = {str(uuid.uuid4()): query for query in fragment_pairs} - - # Save fragment videos and submit queries - for query_id, query in new_queries.items(): - self._write_fragment_video(query[0], name=f"{query_id}-left") - self._write_fragment_video(query[1], name=f"{query_id}-right") - requests.put( - self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} - ) + def __call__(self) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: if self.wait_for_user: self.logger.log("Waiting for user to provide preferences. Press enter to continue.") input() - # Gather preferences for pending queries - self.pending_queries = {**self.pending_queries, **new_queries} - gathered_queries = [] gathered_preferences = [] @@ -1027,76 +1085,23 @@ def __call__( return gathered_queries, np.array(gathered_preferences, dtype=np.float32) - def _write_fragment_video(self, fragment, name: str) -> None: - - output_file_name = os.path.join(self.video_output_dir, f'{name}.webm') - frame_shape = self._get_frame_shape(fragment) - video_writer = cv2.VideoWriter( - output_file_name, - cv2.VideoWriter_fourcc(*'VP90'), - self.frames_per_second, - frame_shape, - ) - - # Make videos from rendered observations if available - if "rendered_img" in fragment.infos[0]: - frames = [] - for i in range(len(fragment.infos)): - frame_info = fragment.infos[i]["rendered_img"] - # If path is provided load cached image - if isinstance(frame_info, AnyPath.__args__): - frame = np.load(frame_info) - elif isinstance(frame_info, np.ndarray): - frame = frame_info - frames.append(frame) - else: - frames = frames.obs - - for frame in frames: - # Transform to RGB frame if necessary - if frame.shape[-1] < 3: - missing_channels = 3 - frame.shape[-1] - frame = np.concatenate( - [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 - ) - video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - - video_writer.release() - - @staticmethod - def _get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: - if "rendered_img" in fragment.infos[0]: - rendered_img_info = fragment.infos[0]["rendered_img"] - # If path is provided load cached image - if isinstance(rendered_img_info, AnyPath.__args__): - single_frame = np.load(rendered_img_info) - else: - single_frame = rendered_img_info - else: - single_frame = np.array(fragment.obs[0]) - # Check whether obervations are image-like - if len(single_frame.shape) < 2: - raise ValueError("Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!") - # Swap dimensions, because matrix and image dims are swapped - return single_frame.shape[1], single_frame.shape[0] - def _gather_preference(self, query_id: str) -> float: answered_query = requests.get(self.query_endpoint + query_id).json() return answered_query["label"] - - def remove_rendered_images(self, trajectories: Sequence[TrajectoryWithRew]) -> None: - """Removes rendered images of the provided trajectories list.""" - for traj in trajectories: - for info in traj.infos: - try: - rendered_img_info = info["rendered_img"] - if isinstance(rendered_img_info, AnyPath.__args__): - os.remove(rendered_img_info) - elif isinstance(rendered_img_info, np.ndarray): - del info["rendered_img"] - except KeyError: - pass + + +def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: + """Removes rendered images of the provided trajectories list.""" + for traj in trajectories: + for info in traj.infos: + try: + rendered_img_info = info["rendered_img"] + if isinstance(rendered_img_info, AnyPath.__args__): + os.remove(rendered_img_info) + elif isinstance(rendered_img_info, np.ndarray): + del info["rendered_img"] + except KeyError: + pass class PreferenceDataset(data_th.Dataset): @@ -1826,8 +1831,7 @@ def __init__( if preference_querent: self.preference_querent = preference_querent else: - # TODO add querent to train script - #assert self.rng is not None + assert self.rng is not None self.preference_querent = PreferenceQuerent( custom_logger=self.logger, rng=self.rng, @@ -1919,6 +1923,9 @@ def train( # Gather fragment pairs (queries) for which preferences have been provided queries, preferences = self.preference_gatherer() + # Free up RAM or disk space from keeping rendered images + remove_rendered_images(trajectories) + self.dataset.push(queries, preferences) self.logger.log(f"Dataset now contains {len(self.dataset)} comparisons") # Skip training if dataset is empty diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 152a0927e..df5629670 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -69,11 +69,14 @@ def train_defaults(): def human_preferences(): gatherer_cls = preference_comparisons.PrefCollectGatherer gatherer_kwargs = dict( + pref_collect_address="http://127.0.0.1:8000", + wait_for_user=True, + ) + querent_cls = preference_comparisons.PrefCollectQuerent + querent_kwargs = dict( pref_collect_address="http://127.0.0.1:8000", video_output_dir="../pref-collect/videofiles", video_fps=20, - wait_for_user=True, - random_preferences=False, ) environment = dict( post_wrappers=dict( diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 64e2c8e98..122b14255 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -77,6 +77,8 @@ def train_preference_comparisons( reward_trainer_kwargs: Mapping[str, Any], gatherer_cls: Type[preference_comparisons.PreferenceGatherer], gatherer_kwargs: Mapping[str, Any], + querent_cls: Type[preference_comparisons.PreferenceQuerent], + querent_kwargs: Mapping[str, Any], active_selection: bool, active_selection_oversampling: int, uncertainty_on: str, @@ -121,6 +123,8 @@ def train_preference_comparisons( reward_trainer_kwargs: passed to BasicRewardTrainer or EnsembleRewardTrainer gatherer_cls: type of PreferenceGatherer to use (defaults to SyntheticGatherer) gatherer_kwargs: passed to the PreferenceGatherer specified by gatherer_cls + querent_cls: type of PreferenceQuerent to use (defaults to PreferenceQuerent) + querent_kwargs: passed to the PreferenceQuerent specified by querent_cls active_selection: use active selection fragmenter instead of random fragmenter active_selection_oversampling: factor by which to oversample random fragments from the base fragmenter of active selection. @@ -235,6 +239,11 @@ def train_preference_comparisons( rng=_rnd, custom_logger=custom_logger, ) + querent = querent_cls( + **querent_kwargs, + rng=_rnd, + custom_logger=custom_logger, + ) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -251,6 +260,7 @@ def train_preference_comparisons( num_iterations=num_iterations, fragmenter=fragmenter, preference_gatherer=gatherer, + preference_querent=querent, reward_trainer=reward_trainer, comparison_queue_size=comparison_queue_size, fragment_length=fragment_length, From 3883a9f04d189a0e583a9fc2b37dc230a88efa20 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 9 Mar 2023 17:00:29 +0100 Subject: [PATCH 011/143] Add PreferenceQuerent tests, one for PrefCollectQuerent --- setup.py | 1 + .../algorithms/preference_comparisons.py | 9 +- .../algorithms/test_preference_comparisons.py | 272 ++++++++++-------- 3 files changed, 160 insertions(+), 122 deletions(-) diff --git a/setup.py b/setup.py index 2b00dc26b..31d375e9e 100644 --- a/setup.py +++ b/setup.py @@ -208,6 +208,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "sacred>=0.8.4", "tensorboard>=1.14", "huggingface_sb3>=2.2.1", + "opencv-python", # TODO: specify version ], tests_require=TESTS_REQUIRE, extras_require={ diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 85be24760..228452e96 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -839,12 +839,15 @@ def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequen for query_id, query in identified_queries.items(): self._write_fragment_video(query[0], name=f"{query_id}-left") self._write_fragment_video(query[1], name=f"{query_id}-right") - requests.put( - self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} - ) + self._query(query_id) return identified_queries + def _query(self, query_id): + requests.put( + self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} + ) + def _write_fragment_video(self, fragment, name: str) -> None: output_file_name = os.path.join(self.video_output_dir, f'{name}.webm') diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 12727c1c9..c6513c0d3 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -2,7 +2,9 @@ import math import re +import uuid from typing import Any, Sequence +from unittest.mock import Mock import gym import numpy as np @@ -11,6 +13,7 @@ import stable_baselines3 import torch as th from gym import spaces +from imitation.algorithms.preference_comparisons import PreferenceQuerent, PrefCollectQuerent from stable_baselines3.common import evaluation from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.vec_env import DummyVecEnv @@ -88,8 +91,8 @@ def check_possibly_nested_dicts_equal(dict1, dict2): def _check_trajs_equal( - trajs1: Sequence[types.TrajectoryWithRew], - trajs2: Sequence[types.TrajectoryWithRew], + trajs1: Sequence[types.TrajectoryWithRew], + trajs2: Sequence[types.TrajectoryWithRew], ): assert len(trajs1) == len(trajs2) for traj1, traj2 in zip(trajs1, trajs2): @@ -113,8 +116,8 @@ def test_mismatched_spaces(venv, agent, rng): other_venv.action_space, ) with pytest.raises( - ValueError, - match="spaces do not match", + ValueError, + match="spaces do not match", ): preference_comparisons.AgentTrainer( agent, @@ -125,8 +128,8 @@ def test_mismatched_spaces(venv, agent, rng): def test_trajectory_dataset_seeding( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - num_samples: int = 400, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + num_samples: int = 400, ): dataset1 = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -153,9 +156,9 @@ def test_trajectory_dataset_seeding( # CartPole max episode length is 200 @pytest.mark.parametrize("num_steps", [0, 199, 200, 201, 400]) def test_trajectory_dataset_len( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - num_steps: int, - rng, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + num_steps: int, + rng, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -169,8 +172,8 @@ def test_trajectory_dataset_len( def test_trajectory_dataset_too_long( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - rng, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -181,9 +184,9 @@ def test_trajectory_dataset_too_long( def test_trajectory_dataset_not_static( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - rng, - num_steps: int = 400, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, + num_steps: int = 400, ): """Tests sample() doesn't always return the same value.""" dataset = preference_comparisons.TrajectoryDataset( @@ -207,27 +210,27 @@ def test_transitions_left_in_buffer(agent_trainer): # with transitions. agent_trainer.buffering_wrapper.n_transitions = 2 with pytest.raises( - RuntimeError, - match=re.escape( - "There are 2 transitions left in the buffer. " - "Call AgentTrainer.sample() first to clear them.", - ), + RuntimeError, + match=re.escape( + "There are 2 transitions left in the buffer. " + "Call AgentTrainer.sample() first to clear them.", + ), ): agent_trainer.train(steps=1) @pytest.mark.parametrize( "schedule", - ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], + ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t ** 3)], ) def test_preference_comparisons_raises( - agent_trainer, - reward_net, - random_fragmenter, - preference_model, - custom_logger, - schedule, - rng, + agent_trainer, + reward_net, + random_fragmenter, + preference_model, + custom_logger, + schedule, + rng, ): loss = preference_comparisons.CrossEntropyRewardLoss() reward_trainer = preference_comparisons.BasicRewardTrainer( @@ -291,15 +294,15 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): @pytest.mark.parametrize( "schedule", - ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], + ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t ** 3)], ) def test_trainer_no_crash( - agent_trainer, - reward_net, - random_fragmenter, - custom_logger, - schedule, - rng, + agent_trainer, + reward_net, + random_fragmenter, + custom_logger, + schedule, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -331,8 +334,8 @@ def test_reward_ensemble_trainer_raises_type_error(venv, rng): loss = preference_comparisons.CrossEntropyRewardLoss() with pytest.raises( - TypeError, - match=r"PreferenceModel of a RewardEnsemble expected by EnsembleTrainer.", + TypeError, + match=r"PreferenceModel of a RewardEnsemble expected by EnsembleTrainer.", ): preference_comparisons.EnsembleTrainer( preference_model, @@ -342,11 +345,11 @@ def test_reward_ensemble_trainer_raises_type_error(venv, rng): def test_correct_reward_trainer_used_by_default( - agent_trainer, - reward_net, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + reward_net, + random_fragmenter, + custom_logger, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -373,11 +376,11 @@ def test_correct_reward_trainer_used_by_default( def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): reward_net = testing_reward_nets.make_ensemble( venv.observation_space, @@ -389,8 +392,8 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( r"AddSTDRewardWrapper but found NormalizedRewardNet." ) with pytest.raises( - ValueError, - match=rgx, + ValueError, + match=rgx, ): preference_comparisons.PreferenceComparisons( agent_trainer, @@ -405,11 +408,11 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( def test_discount_rate_no_crash( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) @@ -441,10 +444,10 @@ def test_discount_rate_no_crash( def create_reward_trainer( - venv, - seed: int, - batch_size: int, - **kwargs: Any, + venv, + seed: int, + batch_size: int, + **kwargs: Any, ): th.manual_seed(seed) reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) @@ -462,11 +465,11 @@ def create_reward_trainer( def test_gradient_accumulation( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): # Test that training steps on the same dataset with different minibatch sizes # result in the same reward network. @@ -484,7 +487,7 @@ def test_gradient_accumulation( preferences = preference_gatherer(fragments) dataset.push(fragments, preferences) - seed = rng.integers(2**32) + seed = rng.integers(2 ** 32) reward_trainer1, reward_net1 = create_reward_trainer(venv, seed, batch_size) reward_trainer2, reward_net2 = create_reward_trainer( venv, @@ -495,7 +498,7 @@ def test_gradient_accumulation( for step in range(8): print("Step", step) - seed = rng.integers(2**32) + seed = rng.integers(2 ** 32) th.manual_seed(seed) reward_trainer1.train(dataset) @@ -515,9 +518,9 @@ def test_gradient_accumulation( def test_synthetic_gatherer_deterministic( - agent_trainer, - random_fragmenter, - rng, + agent_trainer, + random_fragmenter, + rng, ): gatherer = preference_comparisons.SyntheticGatherer( temperature=0, @@ -531,12 +534,12 @@ def test_synthetic_gatherer_deterministic( def test_synthetic_gatherer_raises( - agent_trainer, - random_fragmenter, + agent_trainer, + random_fragmenter, ): with pytest.raises( - ValueError, - match="If `sample` is True, then `rng` must be provided", + ValueError, + match="If `sample` is True, then `rng` must be provided", ): preference_comparisons.SyntheticGatherer( temperature=0, @@ -574,8 +577,8 @@ def test_fragments_too_short_error(agent_trainer): warning_threshold=0, ) with pytest.raises( - ValueError, - match="No trajectories are long enough for the desired fragment length.", + ValueError, + match="No trajectories are long enough for the desired fragment length.", ): # the only important bit is that fragment_length is higher than # we'll ever reach @@ -597,6 +600,7 @@ def test_preference_dataset_errors(agent_trainer, random_fragmenter): dataset.push(fragments, preferences) +# TODO: update test def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): dataset = preference_comparisons.PreferenceDataset(max_size=5) trajectories = agent_trainer.sample(10) @@ -614,10 +618,10 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): def test_store_and_load_preference_dataset( - agent_trainer, - random_fragmenter, - tmp_path, - rng, + agent_trainer, + random_fragmenter, + tmp_path, + rng, ): dataset = preference_comparisons.PreferenceDataset() trajectories = agent_trainer.sample(10) @@ -639,12 +643,12 @@ def test_store_and_load_preference_dataset( def test_exploration_no_crash( - agent, - reward_net, - venv, - random_fragmenter, - custom_logger, - rng, + agent, + reward_net, + venv, + random_fragmenter, + custom_logger, + rng, ): agent_trainer = preference_comparisons.AgentTrainer( agent, @@ -668,12 +672,12 @@ def test_exploration_no_crash( @pytest.mark.parametrize("uncertainty_on", UNCERTAINTY_ON) def test_active_fragmenter_discount_rate_no_crash( - agent_trainer, - venv, - random_fragmenter, - uncertainty_on, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + uncertainty_on, + custom_logger, + rng, ): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.RewardEnsemble( @@ -736,13 +740,13 @@ def interval_param_scaler() -> updaters.IntervalParamScaler: def test_reward_trainer_regularization_no_crash( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - preference_model, - interval_param_scaler, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -776,13 +780,13 @@ def test_reward_trainer_regularization_no_crash( def test_reward_trainer_regularization_raises( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - preference_model, - interval_param_scaler, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -813,8 +817,8 @@ def test_reward_trainer_regularization_raises( rng=rng, ) with pytest.raises( - ValueError, - match="Not enough data samples to split " "into training and validation.*", + ValueError, + match="Not enough data samples to split " "into training and validation.*", ): main_trainer.train(100, 10) @@ -849,8 +853,8 @@ def preference_model(venv) -> preference_comparisons.PreferenceModel: def test_active_fragmenter_uncertainty_on_not_supported_error( - ensemble_preference_model, - random_fragmenter, + ensemble_preference_model, + random_fragmenter, ): re_match = r".* not supported\.\n\s+`uncertainty_on` should be from .*" with pytest.raises(ValueError, match=re_match): @@ -874,12 +878,12 @@ def test_active_fragmenter_uncertainty_on_not_supported_error( def test_active_selection_raises_error_when_initialized_without_an_ensemble( - preference_model, - random_fragmenter, + preference_model, + random_fragmenter, ): with pytest.raises( - ValueError, - match=r"PreferenceModel not wrapped over an ensemble.*", + ValueError, + match=r"PreferenceModel not wrapped over an ensemble.*", ): preference_comparisons.ActiveSelectionFragmenter( preference_model=preference_model, @@ -1032,12 +1036,12 @@ def ensemble_reward_trainer(venv, rng): [basic_reward_trainer, ensemble_reward_trainer], ) def test_that_trainer_improves( - action_is_reward_venv, - action_is_reward_agent, - action_is_reward_trainer_func, - random_fragmenter, - custom_logger, - rng, + action_is_reward_venv, + action_is_reward_agent, + action_is_reward_trainer_func, + random_fragmenter, + custom_logger, + rng, ): """Tests that training improves performance of the reward network and agent.""" action_is_reward_trainer = action_is_reward_trainer_func(action_is_reward_venv, rng) @@ -1075,8 +1079,8 @@ def test_that_trainer_improves( later_reward_network_stats = main_trainer.train(1000, 20) assert ( - first_reward_network_stats["reward_loss"] - > later_reward_network_stats["reward_loss"] + first_reward_network_stats["reward_loss"] + > later_reward_network_stats["reward_loss"] ) # The agent should have also improved @@ -1088,3 +1092,33 @@ def test_that_trainer_improves( ) assert np.mean(trained_agent_rewards) > np.mean(novice_agent_rewards) + + +def test_returns_query_dict_from_query_sequence_with_correct_length(): + querent = PreferenceQuerent() + query_sequence = [Mock()] + query_dict = querent(query_sequence) + assert len(query_dict) == len(query_sequence) + + +def test_returned_queries_have_uuid(): + querent = PreferenceQuerent() + query_dict = querent([Mock()]) + + try: + key = list(query_dict.keys())[0] + uuid.UUID(key, version=4) + except ValueError: + pytest.fail() + + +def test_sends_put_request_for_each_query(requests_mock): + address = "https://test.de" + querent = PrefCollectQuerent(pref_collect_address=address, video_output_dir="video") + query_id = "1234" + + requests_mock.put(f"{address}/preferences/query/{query_id}") + querent._query(query_id) + + assert requests_mock.last_request.method == "PUT" + assert requests_mock.last_request.text == f'{{"uuid": "{query_id}"}}' From 05f1da7a7e026ef4b83ff86bc64296e65aa6c2e7 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 16 Mar 2023 17:19:24 +0100 Subject: [PATCH 012/143] Correct method signature --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 228452e96..13dd203b0 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -941,7 +941,7 @@ def __call__(self) -> Tuple[np.ndarray, np.ndarray]: probabilities. """ # noqa: DAR202 - def add(self, new_queries: Dict[str, Sequence[TrajectoryWithRewPair]]) -> None: + def add(self, new_queries: Dict[str, TrajectoryWithRewPair]) -> None: """Adds queries to pending queries. Args: From 7ba60c384dee416f90abc47125dc1a3e9e233963 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 16 Mar 2023 17:20:20 +0100 Subject: [PATCH 013/143] Test PreferenceGatherer and partially test SyntheticGatherer --- .../algorithms/test_preference_comparisons.py | 48 ++++++++++++++++++- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index c6513c0d3..04daf70c9 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -3,7 +3,7 @@ import math import re import uuid -from typing import Any, Sequence +from typing import Any, Sequence, Tuple from unittest.mock import Mock import gym @@ -13,13 +13,14 @@ import stable_baselines3 import torch as th from gym import spaces -from imitation.algorithms.preference_comparisons import PreferenceQuerent, PrefCollectQuerent from stable_baselines3.common import evaluation from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.vec_env import DummyVecEnv import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons +from imitation.algorithms.preference_comparisons import PreferenceQuerent, PrefCollectQuerent, PreferenceGatherer, \ + SyntheticGatherer from imitation.data import types from imitation.data.types import TrajectoryWithRew from imitation.regularization import regularizers, updaters @@ -75,6 +76,23 @@ def agent_trainer(agent, reward_net, venv, rng): return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng) +# TODO: trajectory_with_rew fixture already exists in data.test_types, should be moved to a conftest.py +@pytest.fixture +def trajectory_with_rew(venv): + observations, rewards, dones, infos, actions = [], [], [], [], [] + observations.append(venv.observation_space.sample()) + for _ in range(2): + observations.append(venv.observation_space.sample()) + actions.append(venv.action_space.sample()) + rewards.append(0.0) + infos.append({}) + return TrajectoryWithRew(obs=np.array(observations), + acts=np.array(actions), + rews=np.array(rewards), + infos=np.array(infos), + terminal=False) + + def assert_info_arrs_equal(arr1, arr2): # pragma: no cover def check_possibly_nested_dicts_equal(dict1, dict2): for key, val1 in dict1.items(): @@ -1122,3 +1140,29 @@ def test_sends_put_request_for_each_query(requests_mock): assert requests_mock.last_request.method == "PUT" assert requests_mock.last_request.text == f'{{"uuid": "{query_id}"}}' + + +class ConcretePreferenceGatherer(PreferenceGatherer): + + def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + pass + + +def test_adds_queries_to_pending_queries(): + gatherer = ConcretePreferenceGatherer() + query_id = "id" + queries = {query_id: Mock()} + + gatherer.add(new_queries=queries) + assert query_id in list(gatherer.pending_queries.keys()) + + +def test_clears_pending_queries(trajectory_with_rew): + gatherer = SyntheticGatherer(sample=False) + + queries = {"id": (trajectory_with_rew, trajectory_with_rew)} + gatherer.add(new_queries=queries) + + gatherer() + + assert len(gatherer.pending_queries) == 0 From 65acb037349cc15e6684fd660cda896659755e0d Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Mon, 17 Apr 2023 17:30:46 +0200 Subject: [PATCH 014/143] Fix bug --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 13dd203b0..29f88880c 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -871,7 +871,7 @@ def _write_fragment_video(self, fragment, name: str) -> None: frame = frame_info frames.append(frame) else: - frames = frames.obs + frames = fragment.obs for frame in frames: # Transform to RGB frame if necessary From 2185a1e7a1e95f3bd0a7b7d53a32cd16ca57fc63 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Mon, 17 Apr 2023 18:06:33 +0200 Subject: [PATCH 015/143] Add gather preferences tests --- .../algorithms/preference_comparisons.py | 3 +- .../algorithms/test_preference_comparisons.py | 31 ++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 29f88880c..0acd2aef0 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1068,8 +1068,9 @@ def __init__( def __call__(self) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: + # TODO: create user-independent (automated) waiting policy if self.wait_for_user: - self.logger.log("Waiting for user to provide preferences. Press enter to continue.") + print("Waiting for user to provide preferences. Press enter to continue.") input() gathered_queries = [] diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 04daf70c9..c1fc2bc9e 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -20,7 +20,7 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons from imitation.algorithms.preference_comparisons import PreferenceQuerent, PrefCollectQuerent, PreferenceGatherer, \ - SyntheticGatherer + SyntheticGatherer, PrefCollectGatherer from imitation.data import types from imitation.data.types import TrajectoryWithRew from imitation.regularization import regularizers, updaters @@ -1166,3 +1166,32 @@ def test_clears_pending_queries(trajectory_with_rew): gatherer() assert len(gatherer.pending_queries) == 0 + + +def test_returns_none_for_unanswered_query(requests_mock): + address = "https://test.de" + query_id = "1234" + answer = None + + gatherer = PrefCollectGatherer(pref_collect_address=address) + + requests_mock.get(f"{address}/preferences/query/{query_id}", json={"query_id": query_id, "label": answer}) + + preference = gatherer._gather_preference(query_id) + + assert preference is answer + + +def test_returns_preference_for_answered_query(requests_mock): + address = "https://test.de" + query_id = "1234" + answer = 1.0 + + gatherer = PrefCollectGatherer(pref_collect_address=address) + + requests_mock.get(f"{address}/preferences/query/{query_id}", json={"query_id": query_id, "label": answer}) + + preference = gatherer._gather_preference(query_id) + + assert preference == answer + From f7b22f24e78fa0e5a02b16164e3a38777f7cf6aa Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 28 Apr 2023 16:13:31 +0200 Subject: [PATCH 016/143] Add pref collect gatherer tests --- .../algorithms/test_preference_comparisons.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index c1fc2bc9e..b40f8b6c7 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -4,7 +4,7 @@ import re import uuid from typing import Any, Sequence, Tuple -from unittest.mock import Mock +from unittest.mock import Mock, MagicMock import gym import numpy as np @@ -1195,3 +1195,26 @@ def test_returns_preference_for_answered_query(requests_mock): assert preference == answer + +def test_keeps_pending_query_for_unanswered_query(): + gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer._gather_preference = MagicMock(return_value=None) + gatherer.pending_queries = {"1234": Mock()} + + pending_queries_pre = gatherer.pending_queries.copy() + gatherer() + + assert pending_queries_pre == gatherer.pending_queries + + +def test_delete_pending_query_for_answered_query(): + gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer._gather_preferences = MagicMock(return_value=None) + + pending_queries_pre = gatherer.pending_queries.copy() + gatherer() + + assert pending_queries_pre == gatherer.pending_queries + + + From c0e884d1a4f2308c7dd6aebc62428ab0bb186c93 Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 28 Apr 2023 17:22:26 +0200 Subject: [PATCH 017/143] Add PrefCollectGatherer tests --- .../algorithms/test_preference_comparisons.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index b40f8b6c7..e83c9fafa 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1207,14 +1207,37 @@ def test_keeps_pending_query_for_unanswered_query(): assert pending_queries_pre == gatherer.pending_queries -def test_delete_pending_query_for_answered_query(): +def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) - gatherer._gather_preferences = MagicMock(return_value=None) + preference = 0.5 + gatherer._gather_preference = MagicMock(return_value=preference) + gatherer.pending_queries = {"1234": Mock()} - pending_queries_pre = gatherer.pending_queries.copy() gatherer() - assert pending_queries_pre == gatherer.pending_queries + assert len(gatherer.pending_queries) == 0 + + +def test_gathers_valid_preference(): + gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + preference = 0.5 + gatherer._gather_preference = MagicMock(return_value=preference) + query = Mock() + gatherer.pending_queries = {"1234": query} + + gathered_queries, gathered_preferences = gatherer() + + assert gathered_preferences[0] == preference + assert gathered_queries[0] == query +def test_ignores_incomparable_answer(): + gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + # incomparable preference value = -1 + gatherer._gather_preference = MagicMock(return_value=-1.) + gatherer.pending_queries = {"1234": Mock()} + + gathered_queries, gathered_preferences = gatherer() + assert len(gathered_preferences) == 0 + assert len(gathered_queries) == 0 From fa587e5fea5e71a3aedf96f4f3bc551dd52eb6f5 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 5 May 2023 17:14:50 +0200 Subject: [PATCH 018/143] Add todos --- .../algorithms/preference_comparisons.py | 1 + .../algorithms/test_preference_comparisons.py | 31 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 0acd2aef0..32e6e5c24 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1793,6 +1793,7 @@ def __init__( reward_trainer, ) + # TODO: update messages with preference querent if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index b40f8b6c7..2e73e0a66 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -256,14 +256,15 @@ def test_preference_comparisons_raises( loss, rng=rng, ) + # TODO: also instantiate querent gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - # no rng, must provide fragmenter, preference gatherer, reward trainer + # no rng, must provide fragmenter, preference gatherer, preference_querent, reward trainer no_rng_msg = ( ".*don't provide.*random state.*provide.*fragmenter" ".*preference gatherer.*reward_trainer.*" ) - def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): + def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): preference_comparisons.PreferenceComparisons( agent_trainer, reward_net, @@ -278,16 +279,17 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(gatherer, None, None, rng=None) + build_preference_comparisons(gatherer, None, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(None, reward_trainer, None, rng=None) + build_preference_comparisons(None, reward_trainer, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparsions(None, None, random_fragmenter, rng=None) + build_preference_comparisons(None, None, random_fragmenter, rng=None) + # TODO: this raises because querent not passed and has to be instantiated # This should not raise - build_preference_comparsions(gatherer, reward_trainer, random_fragmenter, rng=None) + build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=None) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. with_rng_msg = ( @@ -296,7 +298,7 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=with_rng_msg): - build_preference_comparsions( + build_preference_comparisons( gatherer, reward_trainer, random_fragmenter, @@ -304,10 +306,10 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): ) # This should not raise - build_preference_comparsions(None, None, None, rng=rng) - build_preference_comparsions(gatherer, None, None, rng=rng) - build_preference_comparsions(None, reward_trainer, None, rng=rng) - build_preference_comparsions(None, None, random_fragmenter, rng=rng) + build_preference_comparisons(None, None, None, rng=rng) + build_preference_comparisons(gatherer, None, None, rng=rng) + build_preference_comparisons(None, reward_trainer, None, rng=rng) + build_preference_comparisons(None, None, random_fragmenter, rng=rng) @pytest.mark.parametrize( @@ -495,6 +497,7 @@ def test_gradient_accumulation( minibatch_size = 3 num_trajectories = 5 + # TODO: create querent preference_gatherer = preference_comparisons.SyntheticGatherer( custom_logger=custom_logger, rng=rng, @@ -502,6 +505,7 @@ def test_gradient_accumulation( dataset = preference_comparisons.PreferenceDataset() trajectory = agent_trainer.sample(num_trajectories) fragments = random_fragmenter(trajectory, 1, num_trajectories) + # TODO: call querent with fragments (= queries) preferences = preference_gatherer(fragments) dataset.push(fragments, preferences) @@ -534,7 +538,7 @@ def test_gradient_accumulation( for p1, p2 in zip(reward_net1.parameters(), reward_net2.parameters()): th.testing.assert_close(p1, p2, atol=atol, rtol=rtol) - +# TODO: fix test (same as test_gradient_accumulation) def test_synthetic_gatherer_deterministic( agent_trainer, random_fragmenter, @@ -618,7 +622,7 @@ def test_preference_dataset_errors(agent_trainer, random_fragmenter): dataset.push(fragments, preferences) -# TODO: update test +# TODO: fix test (same as test_gradient_accumulation) def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): dataset = preference_comparisons.PreferenceDataset(max_size=5) trajectories = agent_trainer.sample(10) @@ -635,6 +639,7 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): assert len(dataset) == 5 +# TODO: fix test (same as test_gradient_accumulation) def test_store_and_load_preference_dataset( agent_trainer, random_fragmenter, From 9208b89c29320c19b25cecfc692812863603f4e5 Mon Sep 17 00:00:00 2001 From: timbauman Date: Wed, 10 May 2023 17:14:29 -0700 Subject: [PATCH 019/143] Add support for a simple preference UI --- .../algorithms/preference_comparisons.py | 154 ++++++++++++++++++ .../config/train_preference_comparisons.py | 4 + src/imitation/scripts/eval_policy.py | 17 +- .../scripts/train_preference_comparisons.py | 14 +- src/imitation/util/video_wrapper.py | 28 +++- 5 files changed, 200 insertions(+), 17 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 413cd979a..75220a095 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -5,6 +5,7 @@ """ import abc import math +import pathlib import pickle import re from collections import defaultdict @@ -24,6 +25,7 @@ overload, ) +import cv2 import numpy as np import torch as th from scipy import special @@ -906,6 +908,158 @@ def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32) +class SynchronousCLIGatherer(PreferenceGatherer): + """Queries for human preferences using the command line interface.""" + + def __init__( + self, + video_dir: pathlib.Path, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ) -> None: + """Initialize the human preference gatherer. + + Args: + video_dir: directory where videos of the trajectories are saved. + custom_logger: Where to log to; if None (default), creates a new logger. + """ + super().__init__(custom_logger=custom_logger) + self.video_dir = video_dir + + def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: + """Displays each pair of fragments and asks for a preference. + + It iteratively requests user feedback for each pair of fragments. If in the + command line, it will pop out a video player for each fragment. If in a + notebook, it will display the videos. Either way, it will request 1 or 2 to + indicate which is preferred. + + Args: + fragment_pairs: sequence of pairs of trajectory fragments + + Returns: + A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape + (b, ), where b is the length of the input + """ + + preferences = np.zeros(len(fragment_pairs), dtype=np.float32) + for i, (frag1, frag2) in enumerate(fragment_pairs): + if self._display_videos(frag1, frag2): + preferences[i] = 1 + return preferences + + def _display_videos( + self, frag1: TrajectoryWithRew, frag2: TrajectoryWithRew + ) -> bool: + """Displays the videos of the two fragments. + + Args: + frag1: first fragment + frag2: second fragment + """ + # display the videos + frag1_video_path = frag1.infos[0]["video_path"] + frag2_video_path = frag2.infos[0]["video_path"] + if self._in_ipython(): + self._display_videos_in_notebook(frag1_video_path, frag2_video_path) + + pref = input( + "Which video is preferred? (1 or 2, or q to quit, or r to replay): " + ) + while pref not in ["1", "2", "q"]: + if pref == "r": + self._display_videos_in_notebook(frag1_video_path, frag2_video_path) + pref = input("Please enter 1 or 2 or q or r: ") + + if pref == "q": + raise KeyboardInterrupt + elif pref == "1": + return True + elif pref == "2": + return False + + # should never be hit + raise ValueError(f"Unexpected input {pref}") + else: + print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n") + cap1 = cv2.VideoCapture(str(frag1_video_path)) + cap2 = cv2.VideoCapture(str(frag2_video_path)) + cv2.namedWindow("Video 1", cv2.WINDOW_NORMAL) + cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL) + + # set window sizes + cv2.resizeWindow("Video 1", 500, 500) + cv2.resizeWindow("Video 2", 500, 500) + + # move windows side by side + cv2.moveWindow("Video 1", 0, 0) + cv2.moveWindow("Video 2", 500, 0) + + if not cap1.isOpened(): + raise RuntimeError(f"Error opening video file {frag1_video_path}.") + + if not cap2.isOpened(): + raise RuntimeError(f"Error opening video file {frag2_video_path}.") + + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + while cap1.isOpened() and cap2.isOpened(): + if ret1 or ret2: + cv2.imshow("Video 1", frame1) + cv2.imshow("Video 2", frame2) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + + key = chr(cv2.waitKey(1) & 0xFF) + if key == "q": + raise KeyboardInterrupt + elif key == "r": + cap1.set(cv2.CAP_PROP_POS_FRAMES, 0) + cap2.set(cv2.CAP_PROP_POS_FRAMES, 0) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + elif key == "1" or key == "2": + cap1.release() + cap2.release() + cv2.destroyAllWindows() + return key == "1" + + cap1.release() + cap2.release() + cv2.destroyAllWindows() + raise KeyboardInterrupt + + def _display_videos_in_notebook( + self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path + ) -> None: + from IPython.display import HTML, Video, clear_output, display + + display(HTML("

Video 1

")) + display( + Video( + filename=str(frag1_video_path), + height=500, + width=500, + html_attributes="controls autoplay muted", + ) + ) + display(HTML("

Video 2

")) + display( + Video( + filename=str(frag2_video_path), + height=500, + width=500, + html_attributes="controls autoplay muted", + ) + ) + clear_output(wait=True) + + def _in_ipython(self) -> bool: + try: + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" + except NameError: + return False + + class PreferenceDataset(data_th.Dataset): """A PyTorch Dataset for preference comparisons. diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 28890bf33..487e6bd23 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -63,6 +63,10 @@ def train_defaults(): checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only) query_schedule = "hyperbolic" + # If set, save trajectory videos to this directory. Must be present if gather_cls is + # SynchronousCLIGatherer + video_log_dir = None + @train_preference_comparisons_ex.named_config def cartpole(): diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index fb1efb5c9..166fb53e9 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -41,17 +41,6 @@ def step_wait(self): return ob -def video_wrapper_factory(log_dir: pathlib.Path, **kwargs): - """Returns a function that wraps the environment in a video recorder.""" - - def f(env: gym.Env, i: int) -> video_wrapper.VideoWrapper: - """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" - directory = log_dir / "videos" / str(i) - return video_wrapper.VideoWrapper(env, directory=directory, **kwargs) - - return f - - @eval_policy_ex.main def eval_policy( eval_n_timesteps: Optional[int], @@ -94,7 +83,11 @@ def eval_policy( """ log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) - post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None + post_wrappers = ( + [video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)] + if videos + else None + ) with environment.make_venv(post_wrappers=post_wrappers) as venv: if render: venv = InteractiveRender(venv, render_fps) diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 79ee4c136..96c255b41 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -23,6 +23,7 @@ from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward from imitation.scripts.ingredients import rl as rl_common +from imitation.util import video_wrapper def save_model( @@ -84,6 +85,7 @@ def train_preference_comparisons( allow_variable_horizon: bool, checkpoint_interval: int, query_schedule: Union[str, type_aliases.Schedule], + video_log_dir: Optional[str], _rnd: np.random.Generator, ) -> Mapping[str, Any]: """Train a reward model using preference comparisons. @@ -144,6 +146,7 @@ def train_preference_comparisons( be allocated to each iteration. "hyperbolic" and "inverse_quadratic" apportion fewer queries to later iterations when the policy is assumed to be better and more stable. + video_log_dir: If set, save videos to this directory. _rnd: Random number generator provided by Sacred. Returns: @@ -166,7 +169,16 @@ def train_preference_comparisons( custom_logger, log_dir = logging_ingredient.setup_logging() - with environment.make_venv() as venv: + wrappers = [] + if video_log_dir is not None: + wrappers.append( + video_wrapper.video_wrapper_factory( + log_dir=pathlib.Path(video_log_dir), + single_video=False, + ), + ) + + with environment.make_venv(post_wrappers=wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index a59641aa1..ee1402fc6 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -7,6 +7,17 @@ from gym.wrappers.monitoring import video_recorder +def video_wrapper_factory(log_dir: pathlib.Path, **kwargs): + """Returns a function that wraps the environment in a video recorder.""" + + def f(env: gym.Env, i: int) -> VideoWrapper: + """Wraps `env` in a recorder saving videos to `{log_dir}/videos/{i}`.""" + directory = log_dir / "videos" / str(i) + return VideoWrapper(env, directory=directory, **kwargs) + + return f + + class VideoWrapper(gym.Wrapper): """Creates videos from wrapped environment by calling render after each timestep.""" @@ -61,18 +72,27 @@ def _reset_video_recorder(self) -> None: metadata={"episode_id": self.episode_id}, ) - def reset(self): + def reset(self, **kwargs): + new_obs = super().reset(**kwargs) self._reset_video_recorder() self.episode_id += 1 - return self.env.reset() + return new_obs def step(self, action): - res = self.env.step(action) + obs, rew, done, info = self.env.step(action) self.video_recorder.capture_frame() - return res + # is it crazy to save the video path at every step? + info["video_path"] = self.get_current_video_path() + return obs, rew, done, info def close(self) -> None: if self.video_recorder is not None: self.video_recorder.close() self.video_recorder = None super().close() + + def get_current_video_path(self) -> Optional[pathlib.Path]: + """Returns the path to the current video file, or None if no video is active.""" + if self.video_recorder is None: + return None + return pathlib.Path(self.video_recorder.path) From 97c21713beac3f0c7afeb9f287ea431ba063e09e Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 09:53:28 -0700 Subject: [PATCH 020/143] style --- .../algorithms/preference_comparisons.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 75220a095..de89ab45c 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -940,7 +940,6 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape (b, ), where b is the length of the input """ - preferences = np.zeros(len(fragment_pairs), dtype=np.float32) for i, (frag1, frag2) in enumerate(fragment_pairs): if self._display_videos(frag1, frag2): @@ -948,22 +947,34 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra return preferences def _display_videos( - self, frag1: TrajectoryWithRew, frag2: TrajectoryWithRew + self, frag1: TrajectoryWithRew, frag2: TrajectoryWithRew, ) -> bool: """Displays the videos of the two fragments. Args: frag1: first fragment frag2: second fragment + + Returns: + True if the first fragment is preferred, False if not. + + Raises: + KeyboardInterrupt: if the user presses q to quit. + RuntimeError: if the video files cannot be opened. + ValueError: if the trajectory infos are not set. """ # display the videos + if frag1.infos is None or frag2.infos is None: + raise ValueError( + "TrajectoryWithRew.infos must be set to display videos.", + ) frag1_video_path = frag1.infos[0]["video_path"] frag2_video_path = frag2.infos[0]["video_path"] if self._in_ipython(): self._display_videos_in_notebook(frag1_video_path, frag2_video_path) pref = input( - "Which video is preferred? (1 or 2, or q to quit, or r to replay): " + "Which video is preferred? (1 or 2, or q to quit, or r to replay): ", ) while pref not in ["1", "2", "q"]: if pref == "r": @@ -978,7 +989,7 @@ def _display_videos( return False # should never be hit - raise ValueError(f"Unexpected input {pref}") + assert False else: print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n") cap1 = cv2.VideoCapture(str(frag1_video_path)) @@ -1029,7 +1040,7 @@ def _display_videos( raise KeyboardInterrupt def _display_videos_in_notebook( - self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path + self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path, ) -> None: from IPython.display import HTML, Video, clear_output, display @@ -1055,7 +1066,7 @@ def _display_videos_in_notebook( def _in_ipython(self) -> bool: try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[attr-defined] except NameError: return False From 5182acc1aa565e423f7a269690bbd5bbc0dba68c Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 09:53:51 -0700 Subject: [PATCH 021/143] style --- src/imitation/algorithms/preference_comparisons.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index de89ab45c..ba32b7476 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -947,17 +947,19 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra return preferences def _display_videos( - self, frag1: TrajectoryWithRew, frag2: TrajectoryWithRew, + self, + frag1: TrajectoryWithRew, + frag2: TrajectoryWithRew, ) -> bool: """Displays the videos of the two fragments. Args: frag1: first fragment frag2: second fragment - + Returns: True if the first fragment is preferred, False if not. - + Raises: KeyboardInterrupt: if the user presses q to quit. RuntimeError: if the video files cannot be opened. @@ -1040,7 +1042,9 @@ def _display_videos( raise KeyboardInterrupt def _display_videos_in_notebook( - self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path, + self, + frag1_video_path: pathlib.Path, + frag2_video_path: pathlib.Path, ) -> None: from IPython.display import HTML, Video, clear_output, display @@ -1066,7 +1070,7 @@ def _display_videos_in_notebook( def _in_ipython(self) -> bool: try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[attr-defined] + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[attr-defined] except NameError: return False From f176bdc5c25536b21f28514dcfee0b0f77705776 Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 11:02:57 -0700 Subject: [PATCH 022/143] comments --- .../algorithms/preference_comparisons.py | 168 ++++++++++-------- .../scripts/train_preference_comparisons.py | 17 +- src/imitation/util/video_wrapper.py | 17 +- 3 files changed, 117 insertions(+), 85 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ba32b7476..310ed2948 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -938,15 +938,15 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra Returns: A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape - (b, ), where b is the length of the input + (b, ), where b is the length of `fragment_pairs` """ preferences = np.zeros(len(fragment_pairs), dtype=np.float32) for i, (frag1, frag2) in enumerate(fragment_pairs): - if self._display_videos(frag1, frag2): + if self._display_videos_and_gather_preference(frag1, frag2): preferences[i] = 1 return preferences - def _display_videos( + def _display_videos_and_gather_preference( self, frag1: TrajectoryWithRew, frag2: TrajectoryWithRew, @@ -965,7 +965,6 @@ def _display_videos( RuntimeError: if the video files cannot be opened. ValueError: if the trajectory infos are not set. """ - # display the videos if frag1.infos is None or frag2.infos is None: raise ValueError( "TrajectoryWithRew.infos must be set to display videos.", @@ -993,84 +992,113 @@ def _display_videos( # should never be hit assert False else: - print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n") - cap1 = cv2.VideoCapture(str(frag1_video_path)) - cap2 = cv2.VideoCapture(str(frag2_video_path)) - cv2.namedWindow("Video 1", cv2.WINDOW_NORMAL) - cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL) - - # set window sizes - cv2.resizeWindow("Video 1", 500, 500) - cv2.resizeWindow("Video 2", 500, 500) - - # move windows side by side - cv2.moveWindow("Video 1", 0, 0) - cv2.moveWindow("Video 2", 500, 0) - - if not cap1.isOpened(): - raise RuntimeError(f"Error opening video file {frag1_video_path}.") - - if not cap2.isOpened(): - raise RuntimeError(f"Error opening video file {frag2_video_path}.") - - ret1, frame1 = cap1.read() - ret2, frame2 = cap2.read() - while cap1.isOpened() and cap2.isOpened(): - if ret1 or ret2: - cv2.imshow("Video 1", frame1) - cv2.imshow("Video 2", frame2) - ret1, frame1 = cap1.read() - ret2, frame2 = cap2.read() - - key = chr(cv2.waitKey(1) & 0xFF) - if key == "q": - raise KeyboardInterrupt - elif key == "r": - cap1.set(cv2.CAP_PROP_POS_FRAMES, 0) - cap2.set(cv2.CAP_PROP_POS_FRAMES, 0) - ret1, frame1 = cap1.read() - ret2, frame2 = cap2.read() - elif key == "1" or key == "2": - cap1.release() - cap2.release() - cv2.destroyAllWindows() - return key == "1" - - cap1.release() - cap2.release() - cv2.destroyAllWindows() - raise KeyboardInterrupt + return self._display_in_windows(frag1_video_path, frag2_video_path) + + def _display_in_windows( + self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path + ) -> bool: + """Displays the videos in separate windows. + + The videos are displayed side by side and the user is asked to indicate + which one is preferred. The interaction is done in the window rather than + in the command line because the command line is not interactive when + the video is playing, and it's nice to allow the user to choose a video before + the videos are done playing. The downside is that the instructions appear on the + command line and the interaction happens in the video window. + + Args: + frag1_video_path: path to the video file of the first fragment + frag2_video_path: path to the video file of the second fragment + + Returns: + True if the first fragment is preferred, False if not. + + Raises: + KeyboardInterrupt: if the user presses q to quit. + RuntimeError: if the video files cannot be opened. + """ + print("Which video is preferred? (1 or 2, or q to quit, or r to replay):\n") + + cap1 = cv2.VideoCapture(str(frag1_video_path)) + cap2 = cv2.VideoCapture(str(frag2_video_path)) + cv2.namedWindow("Video 1", cv2.WINDOW_NORMAL) + cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL) + + # set window sizes + cv2.resizeWindow("Video 1", 500, 500) + cv2.resizeWindow("Video 2", 500, 500) + + # move windows side by side + cv2.moveWindow("Video 1", 0, 0) + cv2.moveWindow("Video 2", 500, 0) + + if not cap1.isOpened(): + raise RuntimeError(f"Error opening video file {frag1_video_path}.") + + if not cap2.isOpened(): + raise RuntimeError(f"Error opening video file {frag2_video_path}.") + + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + while cap1.isOpened() and cap2.isOpened(): + if ret1 or ret2: + cv2.imshow("Video 1", frame1) + cv2.imshow("Video 2", frame2) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + + key = chr(cv2.waitKey(1) & 0xFF) + if key == "q": + cv2.destroyAllWindows() + raise KeyboardInterrupt + elif key == "r": + cap1.set(cv2.CAP_PROP_POS_FRAMES, 0) + cap2.set(cv2.CAP_PROP_POS_FRAMES, 0) + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + elif key == "1" or key == "2": + cv2.destroyAllWindows() + return key == "1" + + cv2.destroyAllWindows() + raise KeyboardInterrupt def _display_videos_in_notebook( self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path, ) -> None: + """Displays the videos in a notebook. + + Interaction can happen in the notebook while the videos are playing. + + Args: + frag1_video_path: path to the video file of the first fragment + frag2_video_path: path to the video file of the second fragment + + Raises: + RuntimeError: if the video files cannot be opened. + """ from IPython.display import HTML, Video, clear_output, display - display(HTML("

Video 1

")) - display( - Video( - filename=str(frag1_video_path), - height=500, - width=500, - html_attributes="controls autoplay muted", - ) - ) - display(HTML("

Video 2

")) - display( - Video( - filename=str(frag2_video_path), - height=500, - width=500, - html_attributes="controls autoplay muted", - ) - ) clear_output(wait=True) + for i, path in enumerate([frag1_video_path, frag2_video_path]): + if not path.exists(): + raise RuntimeError(f"Video file {path} does not exist.") + display(HTML(f"

Video {i}

")) + display( + Video( + filename=str(path), + height=500, + width=500, + html_attributes="controls autoplay muted", + ) + ) + def _in_ipython(self) -> bool: try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[attr-defined] + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore # noqa except NameError: return False diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 96c255b41..e0933035a 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -169,16 +169,17 @@ def train_preference_comparisons( custom_logger, log_dir = logging_ingredient.setup_logging() - wrappers = [] - if video_log_dir is not None: - wrappers.append( + post_wrappers = ( + [ video_wrapper.video_wrapper_factory( - log_dir=pathlib.Path(video_log_dir), - single_video=False, - ), - ) + pathlib.Path(video_log_dir), single_video=False + ) + ] + if video_log_dir + else None + ) - with environment.make_venv(post_wrappers=wrappers) as venv: + with environment.make_venv(post_wrappers=post_wrappers) as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index ee1402fc6..f46c3430a 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -31,6 +31,7 @@ def __init__( env: gym.Env, directory: pathlib.Path, single_video: bool = True, + delete_on_close: bool = True, ): """Builds a VideoWrapper. @@ -42,11 +43,15 @@ def __init__( Usually a single video file is what is desired. However, if one is searching for an interesting episode (perhaps by looking at the metadata), then saving to different files can be useful. + delete_on_close: if True, deletes the video file when the environment is + closed. If False, the video file is left on disk. """ super().__init__(env) self.episode_id = 0 self.video_recorder = None self.single_video = single_video + self.delete_on_close = delete_on_close + self.current_video_path: Optional[pathlib.Path] = None self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) @@ -71,6 +76,7 @@ def _reset_video_recorder(self) -> None: base_path=str(self.directory / f"video.{self.episode_id:06}"), metadata={"episode_id": self.episode_id}, ) + self.current_video_path = pathlib.Path(self.video_recorder.path) def reset(self, **kwargs): new_obs = super().reset(**kwargs) @@ -82,17 +88,14 @@ def step(self, action): obs, rew, done, info = self.env.step(action) self.video_recorder.capture_frame() # is it crazy to save the video path at every step? - info["video_path"] = self.get_current_video_path() + info["video_path"] = self.current_video_path return obs, rew, done, info def close(self) -> None: if self.video_recorder is not None: self.video_recorder.close() self.video_recorder = None + if self.delete_on_close: + for path in self.directory.glob("*.mp4"): + path.unlink() super().close() - - def get_current_video_path(self) -> Optional[pathlib.Path]: - """Returns the path to the current video file, or None if no video is active.""" - if self.video_recorder is None: - return None - return pathlib.Path(self.video_recorder.path) From 2540ba456e0af693e47c89c4ff1c108c91fbc2db Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 14:02:16 -0700 Subject: [PATCH 023/143] add notebook --- ci/clean_notebooks.py | 4 +- ...sons_with_synchronous_human_feedback.ipynb | 246 ++++++++++++++++++ .../algorithms/preference_comparisons.py | 21 +- src/imitation/util/video_wrapper.py | 1 - 4 files changed, 261 insertions(+), 11 deletions(-) create mode 100644 docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py index b9e394a20..cb54b659d 100755 --- a/ci/clean_notebooks.py +++ b/ci/clean_notebooks.py @@ -27,6 +27,7 @@ class UncleanNotebookError(Exception): "outputs": {"do": "constant", "value": list()}, "execution_count": {"do": "constant", "value": None}, "id": {"do": "keep"}, + "attachments": {"do": "constant", "value": None}, } structure: Dict[str, Dict[str, Dict[str, Any]]] = { @@ -63,7 +64,6 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: print(f"Checking {file}") for cell in nb.cells: - # Remove empty cells if cell["cell_type"] == "code" and not cell["source"]: if check_only: @@ -77,7 +77,7 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: if key not in structure[cell["cell_type"]]: if check_only: raise UncleanNotebookError( - f"Notebook {file} has unknown cell key {key}", + f"Notebook {file} has unknown cell key {key} for cell type {cell['cell_type']}", ) del cell[key] was_dirty = True diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb new file mode 100644 index 000000000..40d9c18ce --- /dev/null +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/5_train_preference_comparisons.ipynb)\n", + "# Learning a Reward Function using Preference Comparisons with Synchronous Human Feedback\n", + "\n", + "You can request human feedback via synchronous CLI or Notebook interactions as well. The setup is only slightly different than it would be with a synthetic preference gatherer." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's the starting setup. The major differences from the synthetic setup are indicated with comments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "import random\n", + "import tempfile\n", + "from imitation.algorithms import preference_comparisons\n", + "from imitation.rewards.reward_nets import BasicRewardNet\n", + "from imitation.util import video_wrapper\n", + "from imitation.util.networks import RunningNorm\n", + "from imitation.util.util import make_vec_env\n", + "from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n", + "import gym\n", + "from stable_baselines3 import PPO\n", + "import numpy as np\n", + "\n", + "# Add a temporary directory for video recordings of trajectories. Unfortunately Jupyter\n", + "# won't play videos outside the current directory, so we have to put them here. We'll\n", + "# delete them at the end of the script.\n", + "video_dir = tempfile.mkdtemp(dir=\".\", prefix=\"videos_\")\n", + "\n", + "rng = np.random.default_rng(0)\n", + "\n", + "# Add a video wrapper to the environment. This will record videos of the agent's\n", + "# trajectories so we can review them later.\n", + "venv = make_vec_env(\n", + " \"Pendulum-v1\",\n", + " rng=rng,\n", + " post_wrappers=[\n", + " video_wrapper.video_wrapper_factory(pathlib.Path(video_dir), single_video=False)\n", + " ],\n", + ")\n", + "\n", + "reward_net = BasicRewardNet(\n", + " venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n", + ")\n", + "\n", + "fragmenter = preference_comparisons.RandomFragmenter(\n", + " warning_threshold=0,\n", + " rng=rng,\n", + ")\n", + "\n", + "# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n", + "# which one is better. It will then use the user's feedback to train the reward network.\n", + "gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n", + "\n", + "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", + "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", + " preference_model=preference_model,\n", + " loss=preference_comparisons.CrossEntropyRewardLoss(),\n", + " epochs=3,\n", + " rng=rng,\n", + ")\n", + "\n", + "agent = PPO(\n", + " policy=FeedForward32Policy,\n", + " policy_kwargs=dict(\n", + " features_extractor_class=NormalizeFeaturesExtractor,\n", + " features_extractor_kwargs=dict(normalize_class=RunningNorm),\n", + " ),\n", + " env=venv,\n", + " seed=0,\n", + " n_steps=2048 // venv.num_envs,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + ")\n", + "\n", + "trajectory_generator = preference_comparisons.AgentTrainer(\n", + " algorithm=agent,\n", + " reward_fn=reward_net,\n", + " venv=venv,\n", + " exploration_frac=0.0,\n", + " rng=rng,\n", + ")\n", + "\n", + "pref_comparisons = preference_comparisons.PreferenceComparisons(\n", + " trajectory_generator,\n", + " reward_net,\n", + " num_iterations=5,\n", + " fragmenter=fragmenter,\n", + " preference_gatherer=gatherer,\n", + " reward_trainer=reward_trainer,\n", + " fragment_length=100,\n", + " transition_oversampling=1,\n", + " initial_comparison_frac=0.1,\n", + " allow_variable_horizon=False,\n", + " initial_epoch_multiplier=1,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We're going to train with only 20 comparisons to make it faster for you to evaluate. The videos will appear in-line in this notebook for you to watch, and a text input will appear for you to choose one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pref_comparisons.train(\n", + " total_timesteps=5_000, # For good performance this should be 1_000_000\n", + " total_comparisons=20, # For good performance this should be 5_000\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From this point onward, this notebook is the same as [the synthetic gatherer notebook](5_train_preference_comparisons.ipynb).\n", + "\n", + "After we trained the reward network using the preference comparisons algorithm, we can wrap our environment with that learned reward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n", + "\n", + "\n", + "learned_reward_venv = RewardVecEnvWrapper(venv, reward_net.predict)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can train an agent, that only sees those learned reward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO\n", + "from stable_baselines3.ppo import MlpPolicy\n", + "\n", + "learner = PPO(\n", + " policy=MlpPolicy,\n", + " env=learned_reward_venv,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "learner.learn(1000) # Note: set to 100000 to train a proficient expert" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we can evaluate it using the original reward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", + "reward, _ = evaluate_policy(learner.policy, venv, 10)\n", + "print(reward)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# clean up the videos we made\n", + "import shutil\n", + "\n", + "shutil.rmtree(video_dir)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 310ed2948..0d017ef29 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -908,12 +908,14 @@ def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32) -class SynchronousCLIGatherer(PreferenceGatherer): - """Queries for human preferences using the command line interface.""" +class SynchronousHumanGatherer(PreferenceGatherer): + """Queries for human preferences using the command line or a notebook.""" def __init__( self, video_dir: pathlib.Path, + video_width: int = 500, + video_height: int = 500, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: """Initialize the human preference gatherer. @@ -924,6 +926,8 @@ def __init__( """ super().__init__(custom_logger=custom_logger) self.video_dir = video_dir + self.video_width = video_width + self.video_height = video_height def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: """Displays each pair of fragments and asks for a preference. @@ -1025,12 +1029,12 @@ def _display_in_windows( cv2.namedWindow("Video 2", cv2.WINDOW_NORMAL) # set window sizes - cv2.resizeWindow("Video 1", 500, 500) - cv2.resizeWindow("Video 2", 500, 500) + cv2.resizeWindow("Video 1", self.video_width, self.video_height) + cv2.resizeWindow("Video 2", self.video_width, self.video_height) # move windows side by side cv2.moveWindow("Video 1", 0, 0) - cv2.moveWindow("Video 2", 500, 0) + cv2.moveWindow("Video 2", self.video_width, 0) if not cap1.isOpened(): raise RuntimeError(f"Error opening video file {frag1_video_path}.") @@ -1086,13 +1090,14 @@ def _display_videos_in_notebook( for i, path in enumerate([frag1_video_path, frag2_video_path]): if not path.exists(): raise RuntimeError(f"Video file {path} does not exist.") - display(HTML(f"

Video {i}

")) + display(HTML(f"

Video {i+1}

")) display( Video( filename=str(path), - height=500, - width=500, + height=self.video_height, + width=self.video_width, html_attributes="controls autoplay muted", + embed=False, ) ) diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index f46c3430a..fb58323d8 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -87,7 +87,6 @@ def reset(self, **kwargs): def step(self, action): obs, rew, done, info = self.env.step(action) self.video_recorder.capture_frame() - # is it crazy to save the video path at every step? info["video_path"] = self.current_video_path return obs, rew, done, info From 3ff95791450dafa8802b35889538048aff919bec Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 14:33:47 -0700 Subject: [PATCH 024/143] tutorial --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index 0db52c23b..eecc79981 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -72,6 +72,7 @@ If you use ``imitation`` in your research project, please cite our paper to help tutorials/4_train_airl tutorials/5_train_preference_comparisons tutorials/5a_train_preference_comparisons_with_cnn + tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback tutorials/6_train_mce tutorials/7_train_density From ce5efa75a3d638b03c0d910a37caec24b2cc4918 Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 15:00:47 -0700 Subject: [PATCH 025/143] lint --- ci/clean_notebooks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py index cb54b659d..432cea649 100755 --- a/ci/clean_notebooks.py +++ b/ci/clean_notebooks.py @@ -18,6 +18,7 @@ class UncleanNotebookError(Exception): "metadata": {"do": "constant", "value": dict()}, "source": {"do": "keep"}, "id": {"do": "keep"}, + "attachments": {"do": "constant", "value": None}, } code_structure: Dict[str, Dict[str, Any]] = { @@ -27,7 +28,6 @@ class UncleanNotebookError(Exception): "outputs": {"do": "constant", "value": list()}, "execution_count": {"do": "constant", "value": None}, "id": {"do": "keep"}, - "attachments": {"do": "constant", "value": None}, } structure: Dict[str, Dict[str, Dict[str, Any]]] = { @@ -77,7 +77,8 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: if key not in structure[cell["cell_type"]]: if check_only: raise UncleanNotebookError( - f"Notebook {file} has unknown cell key {key} for cell type {cell['cell_type']}", + f"Notebook {file} has unknown cell key {key} for cell type " + + f"{cell['cell_type']}", ) del cell[key] was_dirty = True From 2ba2994acc8fd435adb91542d307907177f5b540 Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 16:44:26 -0700 Subject: [PATCH 026/143] lint --- ci/clean_notebooks.py | 18 +++++++++++++++--- setup.cfg | 4 ++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py index 432cea649..9b87aaea3 100755 --- a/ci/clean_notebooks.py +++ b/ci/clean_notebooks.py @@ -18,7 +18,7 @@ class UncleanNotebookError(Exception): "metadata": {"do": "constant", "value": dict()}, "source": {"do": "keep"}, "id": {"do": "keep"}, - "attachments": {"do": "constant", "value": None}, + "attachments": {"do": "constant", "value": {}}, } code_structure: Dict[str, Dict[str, Any]] = { @@ -110,7 +110,12 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None: def parse_args(): - """Parse command-line arguments.""" + """Parse command-line arguments. + + Returns: + parser: The parser object. + args: The parsed arguments. + """ # if the argument --check has been passed, check if the notebooks are clean # otherwise, clean them in-place parser = argparse.ArgumentParser() @@ -127,7 +132,14 @@ def parse_args(): def get_files(input_paths: List): - """Build list of files to scan from list of paths and files.""" + """Build list of files to scan from list of paths and files. + + Args: + input_paths: List of paths and files to scan. + + Returns: + files: List of files to scan. + """ files = [] for file in input_paths: if file.is_dir(): diff --git a/setup.cfg b/setup.cfg index 979c3ca46..7520679bc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,8 @@ per-file-ignores = ../src/imitation/scripts/config/*.py:F841 src/imitation/envs/examples/airl_envs/*.py:D -[darglint] -strictness=short +# [darglint] +# strictness=short [isort] known_first_party=imitation From b3fd565580455859d7a28406955d8e63dc516ee7 Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 18:22:27 -0700 Subject: [PATCH 027/143] more lint --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 0d017ef29..92a8ead8a 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1103,7 +1103,7 @@ def _display_videos_in_notebook( def _in_ipython(self) -> bool: try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore # noqa + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa except NameError: return False From 002146f0943d0e90e7f554134751e0eb9edff03e Mon Sep 17 00:00:00 2001 From: timbauman Date: Thu, 11 May 2023 21:30:53 -0700 Subject: [PATCH 028/143] oops --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7520679bc..979c3ca46 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,8 @@ per-file-ignores = ../src/imitation/scripts/config/*.py:F841 src/imitation/envs/examples/airl_envs/*.py:D -# [darglint] -# strictness=short +[darglint] +strictness=short [isort] known_first_party=imitation From 9bb1f0090bbe8c42f988042ec282a42b7c10b815 Mon Sep 17 00:00:00 2001 From: timbauman Date: Fri, 12 May 2023 18:20:08 -0700 Subject: [PATCH 029/143] add test --- .../algorithms/preference_comparisons.py | 6 ++- .../algorithms/test_preference_comparisons.py | 53 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 92a8ead8a..1dbe93471 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -5,6 +5,7 @@ """ import abc import math +import os import pathlib import pickle import re @@ -1103,10 +1104,13 @@ def _display_videos_in_notebook( def _in_ipython(self) -> bool: try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa + return self.is_running_pytest_test() or get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa except NameError: return False + def is_running_pytest_test(self) -> bool: + return "PYTEST_CURRENT_TEST" in os.environ + class PreferenceDataset(data_th.Dataset): """A PyTorch Dataset for preference comparisons. diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 5f237812f..e16478449 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,8 +1,10 @@ """Tests for the preference comparisons reward learning implementation.""" import math +import pathlib import re from typing import Any, Sequence +from unittest.mock import patch import gym import numpy as np @@ -290,6 +292,57 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): build_preference_comparsions(None, None, random_fragmenter, rng=rng) +@patch("builtins.input") +@patch("IPython.display.display") +def test_synchronous_human_gatherer(mock_display, mock_input): + del mock_display # unused + gatherer = preference_comparisons.SynchronousHumanGatherer( + video_dir=pathlib.Path(".") + ) + + # these inputs are designed solely to pass the test. they aren't tested for anything + trajectory_pairs = [ + ( + types.TrajectoryWithRew( + np.array([1, 2]), + np.array([1]), + np.array( + [ + { + "video_path": pathlib.Path( + "tests/algorithms/test_preference_comparisons.py" + ) + } + ] + ), + True, + np.array([1.0]), + ), + types.TrajectoryWithRew( + np.array([1, 2]), + np.array([1]), + np.array( + [ + { + "video_path": pathlib.Path( + "tests/algorithms/test_preference_comparisons.py" + ) + } + ] + ), + True, + np.array([1.0]), + ), + ) + ] + + # this is the actual test + mock_input.return_value = "1" + assert gatherer(trajectory_pairs) == np.array([1.0]) + mock_input.return_value = "2" + assert gatherer(trajectory_pairs) == np.array([0.0]) + + @pytest.mark.parametrize( "schedule", ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], From f7b7a0d18310afaf52bba4d2c97527cbf70c8d88 Mon Sep 17 00:00:00 2001 From: timbauman Date: Fri, 12 May 2023 18:46:03 -0700 Subject: [PATCH 030/143] lint --- .../algorithms/preference_comparisons.py | 11 ++++++++-- src/imitation/scripts/eval_policy.py | 1 - .../scripts/train_preference_comparisons.py | 5 +++-- .../algorithms/test_preference_comparisons.py | 20 +++++++++---------- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 1dbe93471..b46a096fe 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -923,7 +923,12 @@ def __init__( Args: video_dir: directory where videos of the trajectories are saved. + video_width: width of the video in pixels. + video_height: height of the video in pixels. custom_logger: Where to log to; if None (default), creates a new logger. + + Raises: + ValueError: if `video_dir` is not a directory. """ super().__init__(custom_logger=custom_logger) self.video_dir = video_dir @@ -1000,7 +1005,9 @@ def _display_videos_and_gather_preference( return self._display_in_windows(frag1_video_path, frag2_video_path) def _display_in_windows( - self, frag1_video_path: pathlib.Path, frag2_video_path: pathlib.Path + self, + frag1_video_path: pathlib.Path, + frag2_video_path: pathlib.Path, ) -> bool: """Displays the videos in separate windows. @@ -1099,7 +1106,7 @@ def _display_videos_in_notebook( width=self.video_width, html_attributes="controls autoplay muted", embed=False, - ) + ), ) def _in_ipython(self) -> bool: diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 166fb53e9..1a2659c61 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -5,7 +5,6 @@ import time from typing import Any, Mapping, Optional -import gym import numpy as np from sacred.observers import FileStorageObserver from stable_baselines3.common.vec_env import VecEnvWrapper diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index e0933035a..3b5fc9208 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -172,8 +172,9 @@ def train_preference_comparisons( post_wrappers = ( [ video_wrapper.video_wrapper_factory( - pathlib.Path(video_log_dir), single_video=False - ) + pathlib.Path(video_log_dir), + single_video=False, + ), ] if video_log_dir else None diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index e16478449..79f5e476b 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -297,7 +297,7 @@ def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): def test_synchronous_human_gatherer(mock_display, mock_input): del mock_display # unused gatherer = preference_comparisons.SynchronousHumanGatherer( - video_dir=pathlib.Path(".") + video_dir=pathlib.Path("."), ) # these inputs are designed solely to pass the test. they aren't tested for anything @@ -310,10 +310,10 @@ def test_synchronous_human_gatherer(mock_display, mock_input): [ { "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py" - ) - } - ] + "tests/algorithms/test_preference_comparisons.py", + ), + }, + ], ), True, np.array([1.0]), @@ -325,15 +325,15 @@ def test_synchronous_human_gatherer(mock_display, mock_input): [ { "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py" - ) - } - ] + "tests/algorithms/test_preference_comparisons.py", + ), + }, + ], ), True, np.array([1.0]), ), - ) + ), ] # this is the actual test From d82542c82ad145a9ad590a0df3fcab0fcc9c7bcd Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Mon, 15 May 2023 22:05:19 +0200 Subject: [PATCH 031/143] Integrate SynchronousHumanPreferenceGatherer --- .../algorithms/preference_comparisons.py | 137 ++++++++++-------- .../config/train_preference_comparisons.py | 20 +++ .../scripts/train_preference_comparisons.py | 13 +- 3 files changed, 94 insertions(+), 76 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 83678340c..7700e5bb7 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -838,8 +838,11 @@ def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequen # Save fragment videos and submit queries for query_id, query in identified_queries.items(): - self._write_fragment_video(query[0], name=f"{query_id}-left") - self._write_fragment_video(query[1], name=f"{query_id}-right") + output_file_name = os.path.join(self.video_output_dir, f'{query_id}' + '{}.webm') + write_fragment_video(query[0], frames_per_second=self.frames_per_second, + output_path=output_file_name.format("left")) + write_fragment_video(query[1], frames_per_second=self.frames_per_second, + output_path=output_file_name.format("right")) self._query(query_id) return identified_queries @@ -849,59 +852,58 @@ def _query(self, query_id): self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} ) - def _write_fragment_video(self, fragment, name: str) -> None: - output_file_name = os.path.join(self.video_output_dir, f'{name}.webm') - frame_shape = self._get_frame_shape(fragment) - video_writer = cv2.VideoWriter( - output_file_name, - cv2.VideoWriter_fourcc(*'VP90'), - self.frames_per_second, - frame_shape, - ) +def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> None: + frame_shape = get_frame_shape(fragment) + video_writer = cv2.VideoWriter( + output_path, + cv2.VideoWriter_fourcc(*'VP90'), + frames_per_second, + frame_shape, + ) + + # Make videos from rendered observations if available + if "rendered_img" in fragment.infos[0]: + frames = [] + for i in range(len(fragment.infos)): + frame_info = fragment.infos[i]["rendered_img"] + # If path is provided load cached image + if isinstance(frame_info, AnyPath.__args__): + frame = np.load(frame_info) + elif isinstance(frame_info, np.ndarray): + frame = frame_info + frames.append(frame) + else: + frames = fragment.obs + + for frame in frames: + # Transform to RGB frame if necessary + if frame.shape[-1] < 3: + missing_channels = 3 - frame.shape[-1] + frame = np.concatenate( + [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 + ) + video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - # Make videos from rendered observations if available - if "rendered_img" in fragment.infos[0]: - frames = [] - for i in range(len(fragment.infos)): - frame_info = fragment.infos[i]["rendered_img"] - # If path is provided load cached image - if isinstance(frame_info, AnyPath.__args__): - frame = np.load(frame_info) - elif isinstance(frame_info, np.ndarray): - frame = frame_info - frames.append(frame) - else: - frames = fragment.obs - - for frame in frames: - # Transform to RGB frame if necessary - if frame.shape[-1] < 3: - missing_channels = 3 - frame.shape[-1] - frame = np.concatenate( - [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 - ) - video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + video_writer.release() - video_writer.release() - @staticmethod - def _get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: - if "rendered_img" in fragment.infos[0]: - rendered_img_info = fragment.infos[0]["rendered_img"] - # If path is provided load cached image - if isinstance(rendered_img_info, AnyPath.__args__): - single_frame = np.load(rendered_img_info) - else: - single_frame = rendered_img_info +def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: + if "rendered_img" in fragment.infos[0]: + rendered_img_info = fragment.infos[0]["rendered_img"] + # If path is provided load cached image + if isinstance(rendered_img_info, AnyPath.__args__): + single_frame = np.load(rendered_img_info) else: - single_frame = np.array(fragment.obs[0]) - # Check whether obervations are image-like - if len(single_frame.shape) < 2: - raise ValueError("Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!") - # Swap dimensions, because matrix and image dims are swapped - return single_frame.shape[1], single_frame.shape[0] + single_frame = rendered_img_info + else: + single_frame = np.array(fragment.obs[0]) + # Check whether obervations are image-like + if len(single_frame.shape) < 2: + raise ValueError("Observation must be an image, " + f"but shape {single_frame.shape} has too few dimensions!") + # Swap dimensions, because matrix and image dims are swapped + return single_frame.shape[1], single_frame.shape[0] class PreferenceGatherer(abc.ABC): @@ -1052,7 +1054,9 @@ def __init__( video_dir: pathlib.Path, video_width: int = 500, video_height: int = 500, + frames_per_second: int = 25, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + rng: Optional[np.random.Generator] = None, ) -> None: """Initialize the human preference gatherer. @@ -1065,12 +1069,14 @@ def __init__( Raises: ValueError: if `video_dir` is not a directory. """ - super().__init__(custom_logger=custom_logger) + super().__init__(custom_logger=custom_logger, rng=rng) self.video_dir = video_dir + os.makedirs(video_dir, exist_ok=True) self.video_width = video_width self.video_height = video_height + self.frames_per_second = frames_per_second - def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Displays each pair of fragments and asks for a preference. It iteratively requests user feedback for each pair of fragments. If in the @@ -1085,16 +1091,23 @@ def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarra A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape (b, ), where b is the length of `fragment_pairs` """ - preferences = np.zeros(len(fragment_pairs), dtype=np.float32) - for i, (frag1, frag2) in enumerate(fragment_pairs): - if self._display_videos_and_gather_preference(frag1, frag2): + preferences = np.zeros(len(self.pending_queries), dtype=np.float32) + for i, (query_id, query) in enumerate(self.pending_queries.items()): + + write_fragment_video(query[0], frames_per_second=self.frames_per_second, + output_path=os.path.join(self.video_dir, f'{query_id}-left.webm')) + write_fragment_video(query[1], frames_per_second=self.frames_per_second, + output_path=os.path.join(self.video_dir, f'{query_id}-right.webm')) + if self._display_videos_and_gather_preference(query_id): preferences[i] = 1 - return preferences + + queries = list(self.pending_queries.values()) + self.pending_queries.clear() + return queries, preferences def _display_videos_and_gather_preference( self, - frag1: TrajectoryWithRew, - frag2: TrajectoryWithRew, + query_id ) -> bool: """Displays the videos of the two fragments. @@ -1110,12 +1123,8 @@ def _display_videos_and_gather_preference( RuntimeError: if the video files cannot be opened. ValueError: if the trajectory infos are not set. """ - if frag1.infos is None or frag2.infos is None: - raise ValueError( - "TrajectoryWithRew.infos must be set to display videos.", - ) - frag1_video_path = frag1.infos[0]["video_path"] - frag2_video_path = frag2.infos[0]["video_path"] + frag1_video_path = os.path.join(self.video_dir, f'{query_id}-left.webm') + frag2_video_path = os.path.join(self.video_dir, f'{query_id}-right.webm') if self._in_ipython(): self._display_videos_in_notebook(frag1_video_path, frag2_video_path) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 3a6a92da7..7af6f7a55 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -69,6 +69,26 @@ def train_defaults(): video_log_dir = None +@train_preference_comparisons_ex.named_config +def synch_human_preferences(): + gatherer_cls = preference_comparisons.SynchronousHumanGatherer + gatherer_kwargs = dict( + video_dir="videos" + ) + querent_cls = preference_comparisons.PreferenceQuerent + querent_kwargs = dict() + environment = dict( + post_wrappers=dict( + RenderImageInfoWrapper=lambda env, env_id, **kwargs: + RenderImageInfoWrapper(env, **kwargs), + ), + num_vec=2, + post_wrappers_kwargs=dict( + RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), + ), + ) + + @train_preference_comparisons_ex.named_config def human_preferences(): gatherer_cls = preference_comparisons.PrefCollectGatherer diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 9cced26d8..057dfe302 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -173,18 +173,7 @@ def train_preference_comparisons( custom_logger, log_dir = logging_ingredient.setup_logging() - post_wrappers = ( - [ - video_wrapper.video_wrapper_factory( - pathlib.Path(video_log_dir), - single_video=False, - ), - ] - if video_log_dir - else None - ) - - with environment.make_venv(post_wrappers=post_wrappers) as venv: + with environment.make_venv() as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, From 72ad6aa52e8028e55694b039c2948587ef7a9a5a Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 19 May 2023 17:37:18 +0200 Subject: [PATCH 032/143] Fix remaining tests --- .../algorithms/preference_comparisons.py | 62 +-- .../algorithms/test_preference_comparisons.py | 358 ++++++++++-------- 2 files changed, 238 insertions(+), 182 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7700e5bb7..e9cff3fb8 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -800,7 +800,9 @@ def __init__( del rng self.logger = custom_logger or imit_logger.configure() - def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + def __call__( + self, queries: Sequence[TrajectoryWithRewPair] + ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: """Queries the user for their preferences. This dummy implementation does nothing because by default the queries are answered by an oracle. @@ -833,16 +835,26 @@ def __init__( # Create video directory os.makedirs(self.video_output_dir, exist_ok=True) - def __call__(self, queries: Sequence[TrajectoryWithRewPair]) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + def __call__( + self, queries: Sequence[TrajectoryWithRewPair] + ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: identified_queries = super().__call__(queries) # Save fragment videos and submit queries for query_id, query in identified_queries.items(): - output_file_name = os.path.join(self.video_output_dir, f'{query_id}' + '{}.webm') - write_fragment_video(query[0], frames_per_second=self.frames_per_second, - output_path=output_file_name.format("left")) - write_fragment_video(query[1], frames_per_second=self.frames_per_second, - output_path=output_file_name.format("right")) + output_file_name = os.path.join( + self.video_output_dir, f"{query_id}" + "{}.webm" + ) + write_fragment_video( + query[0], + frames_per_second=self.frames_per_second, + output_path=output_file_name.format("left"), + ) + write_fragment_video( + query[1], + frames_per_second=self.frames_per_second, + output_path=output_file_name.format("right"), + ) self._query(query_id) return identified_queries @@ -857,7 +869,7 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( output_path, - cv2.VideoWriter_fourcc(*'VP90'), + cv2.VideoWriter_fourcc(*"VP90"), frames_per_second, frame_shape, ) @@ -900,8 +912,10 @@ def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: single_frame = np.array(fragment.obs[0]) # Check whether obervations are image-like if len(single_frame.shape) < 2: - raise ValueError("Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!") + raise ValueError( + "Observation must be an image, " + f"but shape {single_frame.shape} has too few dimensions!" + ) # Swap dimensions, because matrix and image dims are swapped return single_frame.shape[1], single_frame.shape[0] @@ -1094,10 +1108,16 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: preferences = np.zeros(len(self.pending_queries), dtype=np.float32) for i, (query_id, query) in enumerate(self.pending_queries.items()): - write_fragment_video(query[0], frames_per_second=self.frames_per_second, - output_path=os.path.join(self.video_dir, f'{query_id}-left.webm')) - write_fragment_video(query[1], frames_per_second=self.frames_per_second, - output_path=os.path.join(self.video_dir, f'{query_id}-right.webm')) + write_fragment_video( + query[0], + frames_per_second=self.frames_per_second, + output_path=os.path.join(self.video_dir, f"{query_id}-left.webm"), + ) + write_fragment_video( + query[1], + frames_per_second=self.frames_per_second, + output_path=os.path.join(self.video_dir, f"{query_id}-right.webm"), + ) if self._display_videos_and_gather_preference(query_id): preferences[i] = 1 @@ -1105,10 +1125,7 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: self.pending_queries.clear() return queries, preferences - def _display_videos_and_gather_preference( - self, - query_id - ) -> bool: + def _display_videos_and_gather_preference(self, query_id) -> bool: """Displays the videos of the two fragments. Args: @@ -1123,8 +1140,8 @@ def _display_videos_and_gather_preference( RuntimeError: if the video files cannot be opened. ValueError: if the trajectory infos are not set. """ - frag1_video_path = os.path.join(self.video_dir, f'{query_id}-left.webm') - frag2_video_path = os.path.join(self.video_dir, f'{query_id}-right.webm') + frag1_video_path = os.path.join(self.video_dir, f"{query_id}-left.webm") + frag2_video_path = os.path.join(self.video_dir, f"{query_id}-right.webm") if self._in_ipython(): self._display_videos_in_notebook(frag1_video_path, frag2_video_path) @@ -1265,6 +1282,7 @@ def is_running_pytest_test(self) -> bool: class PrefCollectGatherer(PreferenceGatherer): """Gathers preferences from PrefCollect interface.""" + def __init__( self, pref_collect_address: str, @@ -2017,12 +2035,12 @@ def __init__( if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " - "seeded fragmenter, preference gatherer, and reward_trainer. " + "seeded fragmenter, preference gatherer, preference querent, and reward_trainer. " "You can initialize a random state with `np.random.default_rng(seed)`.", ) elif self.rng is not None and not has_any_rng_args_none: raise ValueError( - "If you provide your own fragmenter, preference gatherer, " + "If you provide your own fragmenter, preference gatherer, preference querent," "and reward trainer, you don't need to provide a random state.", ) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 3a935f7e1..da3ec5cdb 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -5,7 +5,7 @@ import re import uuid from typing import Any, Sequence, Tuple -from unittest.mock import Mock, MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import gym import numpy as np @@ -20,8 +20,13 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons -from imitation.algorithms.preference_comparisons import PreferenceQuerent, PrefCollectQuerent, PreferenceGatherer, \ - SyntheticGatherer, PrefCollectGatherer +from imitation.algorithms.preference_comparisons import ( + PrefCollectGatherer, + PrefCollectQuerent, + PreferenceGatherer, + PreferenceQuerent, + SyntheticGatherer, +) from imitation.data import types from imitation.data.types import TrajectoryWithRew from imitation.regularization import regularizers, updaters @@ -78,21 +83,22 @@ def agent_trainer(agent, reward_net, venv, rng): return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng) -# TODO: trajectory_with_rew fixture already exists in data.test_types, should be moved to a conftest.py @pytest.fixture def trajectory_with_rew(venv): - observations, rewards, dones, infos, actions = [], [], [], [], [] + observations, rewards, _, infos, actions = [], [], [], [], [] observations.append(venv.observation_space.sample()) for _ in range(2): observations.append(venv.observation_space.sample()) actions.append(venv.action_space.sample()) rewards.append(0.0) infos.append({}) - return TrajectoryWithRew(obs=np.array(observations), - acts=np.array(actions), - rews=np.array(rewards), - infos=np.array(infos), - terminal=False) + return TrajectoryWithRew( + obs=np.array(observations), + acts=np.array(actions), + rews=np.array(rewards), + infos=np.array(infos), + terminal=False, + ) def assert_info_arrs_equal(arr1, arr2): # pragma: no cover @@ -111,8 +117,8 @@ def check_possibly_nested_dicts_equal(dict1, dict2): def _check_trajs_equal( - trajs1: Sequence[types.TrajectoryWithRew], - trajs2: Sequence[types.TrajectoryWithRew], + trajs1: Sequence[types.TrajectoryWithRew], + trajs2: Sequence[types.TrajectoryWithRew], ): assert len(trajs1) == len(trajs2) for traj1, traj2 in zip(trajs1, trajs2): @@ -136,8 +142,8 @@ def test_mismatched_spaces(venv, agent, rng): other_venv.action_space, ) with pytest.raises( - ValueError, - match="spaces do not match", + ValueError, + match="spaces do not match", ): preference_comparisons.AgentTrainer( agent, @@ -148,8 +154,8 @@ def test_mismatched_spaces(venv, agent, rng): def test_trajectory_dataset_seeding( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - num_samples: int = 400, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + num_samples: int = 400, ): dataset1 = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -176,9 +182,9 @@ def test_trajectory_dataset_seeding( # CartPole max episode length is 200 @pytest.mark.parametrize("num_steps", [0, 199, 200, 201, 400]) def test_trajectory_dataset_len( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - num_steps: int, - rng, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + num_steps: int, + rng, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -192,8 +198,8 @@ def test_trajectory_dataset_len( def test_trajectory_dataset_too_long( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - rng, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, @@ -204,9 +210,9 @@ def test_trajectory_dataset_too_long( def test_trajectory_dataset_not_static( - cartpole_expert_trajectories: Sequence[TrajectoryWithRew], - rng, - num_steps: int = 400, + cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, + num_steps: int = 400, ): """Tests sample() doesn't always return the same value.""" dataset = preference_comparisons.TrajectoryDataset( @@ -230,27 +236,27 @@ def test_transitions_left_in_buffer(agent_trainer): # with transitions. agent_trainer.buffering_wrapper.n_transitions = 2 with pytest.raises( - RuntimeError, - match=re.escape( - "There are 2 transitions left in the buffer. " - "Call AgentTrainer.sample() first to clear them.", - ), + RuntimeError, + match=re.escape( + "There are 2 transitions left in the buffer. " + "Call AgentTrainer.sample() first to clear them.", + ), ): agent_trainer.train(steps=1) @pytest.mark.parametrize( "schedule", - ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t ** 3)], + ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], ) def test_preference_comparisons_raises( - agent_trainer, - reward_net, - random_fragmenter, - preference_model, - custom_logger, - schedule, - rng, + agent_trainer, + reward_net, + random_fragmenter, + preference_model, + custom_logger, + schedule, + rng, ): loss = preference_comparisons.CrossEntropyRewardLoss() reward_trainer = preference_comparisons.BasicRewardTrainer( @@ -258,21 +264,25 @@ def test_preference_comparisons_raises( loss, rng=rng, ) - # TODO: also instantiate querent + + querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) # no rng, must provide fragmenter, preference gatherer, preference_querent, reward trainer no_rng_msg = ( ".*don't provide.*random state.*provide.*fragmenter" - ".*preference gatherer.*reward_trainer.*" + ".*preference gatherer.*preference querent.*reward_trainer.*" ) - def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): + def build_preference_comparisons( + gatherer, querent, reward_trainer, fragmenter, rng + ): preference_comparisons.PreferenceComparisons( agent_trainer, reward_net, num_iterations=2, transition_oversampling=2, reward_trainer=reward_trainer, + preference_querent=querent, preference_gatherer=gatherer, fragmenter=fragmenter, custom_logger=custom_logger, @@ -281,37 +291,43 @@ def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(gatherer, None, None, rng=None) + build_preference_comparisons(gatherer, None, None, None, rng=None) + + with pytest.raises(ValueError, match=no_rng_msg): + build_preference_comparisons(None, None, reward_trainer, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(None, reward_trainer, None, rng=None) + build_preference_comparisons(None, None, None, random_fragmenter, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(None, None, random_fragmenter, rng=None) + build_preference_comparisons(None, querent, None, None, rng=None) - # TODO: this raises because querent not passed and has to be instantiated # This should not raise - build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=None) + build_preference_comparisons( + gatherer, querent, reward_trainer, random_fragmenter, rng=None + ) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. with_rng_msg = ( - "provide.*fragmenter.*preference gatherer.*reward trainer" + "provide.*fragmenter.*preference gatherer.*preference querent.*reward trainer" ".*don't need.*random state.*" ) with pytest.raises(ValueError, match=with_rng_msg): build_preference_comparisons( gatherer, + querent, reward_trainer, random_fragmenter, rng=rng, ) # This should not raise - build_preference_comparisons(None, None, None, rng=rng) - build_preference_comparisons(gatherer, None, None, rng=rng) - build_preference_comparisons(None, reward_trainer, None, rng=rng) - build_preference_comparisons(None, None, random_fragmenter, rng=rng) + build_preference_comparisons(None, None, None, None, rng=rng) + build_preference_comparisons(gatherer, None, None, None, rng=rng) + build_preference_comparisons(None, querent, None, None, rng=rng) + build_preference_comparisons(None, None, reward_trainer, None, rng=rng) + build_preference_comparisons(None, None, None, random_fragmenter, rng=rng) @patch("builtins.input") @@ -367,15 +383,15 @@ def test_synchronous_human_gatherer(mock_display, mock_input): @pytest.mark.parametrize( "schedule", - ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t ** 3)], + ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], ) def test_trainer_no_crash( - agent_trainer, - reward_net, - random_fragmenter, - custom_logger, - schedule, - rng, + agent_trainer, + reward_net, + random_fragmenter, + custom_logger, + schedule, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -407,8 +423,8 @@ def test_reward_ensemble_trainer_raises_type_error(venv, rng): loss = preference_comparisons.CrossEntropyRewardLoss() with pytest.raises( - TypeError, - match=r"PreferenceModel of a RewardEnsemble expected by EnsembleTrainer.", + TypeError, + match=r"PreferenceModel of a RewardEnsemble expected by EnsembleTrainer.", ): preference_comparisons.EnsembleTrainer( preference_model, @@ -418,11 +434,11 @@ def test_reward_ensemble_trainer_raises_type_error(venv, rng): def test_correct_reward_trainer_used_by_default( - agent_trainer, - reward_net, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + reward_net, + random_fragmenter, + custom_logger, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -449,11 +465,11 @@ def test_correct_reward_trainer_used_by_default( def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): reward_net = testing_reward_nets.make_ensemble( venv.observation_space, @@ -465,8 +481,8 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( r"AddSTDRewardWrapper but found NormalizedRewardNet." ) with pytest.raises( - ValueError, - match=rgx, + ValueError, + match=rgx, ): preference_comparisons.PreferenceComparisons( agent_trainer, @@ -481,11 +497,11 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( def test_discount_rate_no_crash( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) @@ -517,10 +533,10 @@ def test_discount_rate_no_crash( def create_reward_trainer( - venv, - seed: int, - batch_size: int, - **kwargs: Any, + venv, + seed: int, + batch_size: int, + **kwargs: Any, ): th.manual_seed(seed) reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) @@ -538,11 +554,11 @@ def create_reward_trainer( def test_gradient_accumulation( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, ): # Test that training steps on the same dataset with different minibatch sizes # result in the same reward network. @@ -550,7 +566,7 @@ def test_gradient_accumulation( minibatch_size = 3 num_trajectories = 5 - # TODO: create querent + preference_querent = preference_comparisons.PreferenceQuerent(rng=rng) preference_gatherer = preference_comparisons.SyntheticGatherer( custom_logger=custom_logger, rng=rng, @@ -558,11 +574,12 @@ def test_gradient_accumulation( dataset = preference_comparisons.PreferenceDataset() trajectory = agent_trainer.sample(num_trajectories) fragments = random_fragmenter(trajectory, 1, num_trajectories) - # TODO: call querent with fragments (= queries) - preferences = preference_gatherer(fragments) + identified_queries = preference_querent(fragments) + preference_gatherer.add(identified_queries) + fragments, preferences = preference_gatherer() dataset.push(fragments, preferences) - seed = rng.integers(2 ** 32) + seed = rng.integers(2**32) reward_trainer1, reward_net1 = create_reward_trainer(venv, seed, batch_size) reward_trainer2, reward_net2 = create_reward_trainer( venv, @@ -573,7 +590,7 @@ def test_gradient_accumulation( for step in range(8): print("Step", step) - seed = rng.integers(2 ** 32) + seed = rng.integers(2**32) th.manual_seed(seed) reward_trainer1.train(dataset) @@ -591,30 +608,34 @@ def test_gradient_accumulation( for p1, p2 in zip(reward_net1.parameters(), reward_net2.parameters()): th.testing.assert_close(p1, p2, atol=atol, rtol=rtol) -# TODO: fix test (same as test_gradient_accumulation) + def test_synthetic_gatherer_deterministic( - agent_trainer, - random_fragmenter, - rng, + agent_trainer, + random_fragmenter, + rng, ): + preference_querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer( temperature=0, rng=rng, ) trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) - preferences1 = gatherer(fragments) - preferences2 = gatherer(fragments) + identified_queries = preference_querent(fragments) + gatherer.add(identified_queries) + _, preferences1 = gatherer() + gatherer.add(identified_queries) + _, preferences2 = gatherer() assert np.all(preferences1 == preferences2) def test_synthetic_gatherer_raises( - agent_trainer, - random_fragmenter, + agent_trainer, + random_fragmenter, ): with pytest.raises( - ValueError, - match="If `sample` is True, then `rng` must be provided", + ValueError, + match="If `sample` is True, then `rng` must be provided", ): preference_comparisons.SyntheticGatherer( temperature=0, @@ -652,8 +673,8 @@ def test_fragments_too_short_error(agent_trainer): warning_threshold=0, ) with pytest.raises( - ValueError, - match="No trajectories are long enough for the desired fragment length.", + ValueError, + match="No trajectories are long enough for the desired fragment length.", ): # the only important bit is that fragment_length is higher than # we'll ever reach @@ -675,15 +696,17 @@ def test_preference_dataset_errors(agent_trainer, random_fragmenter): dataset.push(fragments, preferences) -# TODO: fix test (same as test_gradient_accumulation) def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): dataset = preference_comparisons.PreferenceDataset(max_size=5) trajectories = agent_trainer.sample(10) + querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) for i in range(6): fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=1) - preferences = gatherer(fragments) + identified_queries = querent(fragments) + gatherer.add(identified_queries) + fragments, preferences = gatherer() assert len(dataset) == min(i, 5) dataset.push(fragments, preferences) assert len(dataset) == min(i + 1, 5) @@ -692,18 +715,20 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): assert len(dataset) == 5 -# TODO: fix test (same as test_gradient_accumulation) def test_store_and_load_preference_dataset( - agent_trainer, - random_fragmenter, - tmp_path, - rng, + agent_trainer, + random_fragmenter, + tmp_path, + rng, ): dataset = preference_comparisons.PreferenceDataset() trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) + querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - preferences = gatherer(fragments) + identified_queries = querent(fragments) + gatherer.add(identified_queries) + fragments, preferences = gatherer() dataset.push(fragments, preferences) path = tmp_path / "preferences.pkl" @@ -719,12 +744,12 @@ def test_store_and_load_preference_dataset( def test_exploration_no_crash( - agent, - reward_net, - venv, - random_fragmenter, - custom_logger, - rng, + agent, + reward_net, + venv, + random_fragmenter, + custom_logger, + rng, ): agent_trainer = preference_comparisons.AgentTrainer( agent, @@ -748,12 +773,12 @@ def test_exploration_no_crash( @pytest.mark.parametrize("uncertainty_on", UNCERTAINTY_ON) def test_active_fragmenter_discount_rate_no_crash( - agent_trainer, - venv, - random_fragmenter, - uncertainty_on, - custom_logger, - rng, + agent_trainer, + venv, + random_fragmenter, + uncertainty_on, + custom_logger, + rng, ): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.RewardEnsemble( @@ -816,13 +841,13 @@ def interval_param_scaler() -> updaters.IntervalParamScaler: def test_reward_trainer_regularization_no_crash( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - preference_model, - interval_param_scaler, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -856,13 +881,13 @@ def test_reward_trainer_regularization_no_crash( def test_reward_trainer_regularization_raises( - agent_trainer, - venv, - random_fragmenter, - custom_logger, - preference_model, - interval_param_scaler, - rng, + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -893,8 +918,8 @@ def test_reward_trainer_regularization_raises( rng=rng, ) with pytest.raises( - ValueError, - match="Not enough data samples to split " "into training and validation.*", + ValueError, + match="Not enough data samples to split " "into training and validation.*", ): main_trainer.train(100, 10) @@ -929,8 +954,8 @@ def preference_model(venv) -> preference_comparisons.PreferenceModel: def test_active_fragmenter_uncertainty_on_not_supported_error( - ensemble_preference_model, - random_fragmenter, + ensemble_preference_model, + random_fragmenter, ): re_match = r".* not supported\.\n\s+`uncertainty_on` should be from .*" with pytest.raises(ValueError, match=re_match): @@ -954,12 +979,12 @@ def test_active_fragmenter_uncertainty_on_not_supported_error( def test_active_selection_raises_error_when_initialized_without_an_ensemble( - preference_model, - random_fragmenter, + preference_model, + random_fragmenter, ): with pytest.raises( - ValueError, - match=r"PreferenceModel not wrapped over an ensemble.*", + ValueError, + match=r"PreferenceModel not wrapped over an ensemble.*", ): preference_comparisons.ActiveSelectionFragmenter( preference_model=preference_model, @@ -1112,12 +1137,12 @@ def ensemble_reward_trainer(venv, rng): [basic_reward_trainer, ensemble_reward_trainer], ) def test_that_trainer_improves( - action_is_reward_venv, - action_is_reward_agent, - action_is_reward_trainer_func, - random_fragmenter, - custom_logger, - rng, + action_is_reward_venv, + action_is_reward_agent, + action_is_reward_trainer_func, + random_fragmenter, + custom_logger, + rng, ): """Tests that training improves performance of the reward network and agent.""" action_is_reward_trainer = action_is_reward_trainer_func(action_is_reward_venv, rng) @@ -1155,8 +1180,8 @@ def test_that_trainer_improves( later_reward_network_stats = main_trainer.train(50, 20) assert ( - first_reward_network_stats["reward_loss"] - > later_reward_network_stats["reward_loss"] + first_reward_network_stats["reward_loss"] + > later_reward_network_stats["reward_loss"] ) # The agent should have also improved @@ -1204,7 +1229,6 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): - def __call__(self) -> Tuple[np.ndarray, np.ndarray]: pass @@ -1236,7 +1260,10 @@ def test_returns_none_for_unanswered_query(requests_mock): gatherer = PrefCollectGatherer(pref_collect_address=address) - requests_mock.get(f"{address}/preferences/query/{query_id}", json={"query_id": query_id, "label": answer}) + requests_mock.get( + f"{address}/preferences/query/{query_id}", + json={"query_id": query_id, "label": answer}, + ) preference = gatherer._gather_preference(query_id) @@ -1250,7 +1277,10 @@ def test_returns_preference_for_answered_query(requests_mock): gatherer = PrefCollectGatherer(pref_collect_address=address) - requests_mock.get(f"{address}/preferences/query/{query_id}", json={"query_id": query_id, "label": answer}) + requests_mock.get( + f"{address}/preferences/query/{query_id}", + json={"query_id": query_id, "label": answer}, + ) preference = gatherer._gather_preference(query_id) @@ -1258,7 +1288,9 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): - gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer = PrefCollectGatherer( + pref_collect_address="https://test.de", wait_for_user=False + ) gatherer._gather_preference = MagicMock(return_value=None) gatherer.pending_queries = {"1234": Mock()} @@ -1269,7 +1301,9 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): - gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer = PrefCollectGatherer( + pref_collect_address="https://test.de", wait_for_user=False + ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) gatherer.pending_queries = {"1234": Mock()} @@ -1280,7 +1314,9 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): - gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer = PrefCollectGatherer( + pref_collect_address="https://test.de", wait_for_user=False + ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) query = Mock() @@ -1293,9 +1329,11 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): - gatherer = PrefCollectGatherer(pref_collect_address="https://test.de", wait_for_user=False) + gatherer = PrefCollectGatherer( + pref_collect_address="https://test.de", wait_for_user=False + ) # incomparable preference value = -1 - gatherer._gather_preference = MagicMock(return_value=-1.) + gatherer._gather_preference = MagicMock(return_value=-1.0) gatherer.pending_queries = {"1234": Mock()} gathered_queries, gathered_preferences = gatherer() From 5be908703387e84ab82b6e00dfdbd36bf09b902e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 25 May 2023 20:45:07 +0200 Subject: [PATCH 033/143] Fix flake8 --- experiments/human_preferences.sh | 11 -------- .../algorithms/preference_comparisons.py | 26 ++++++++++++++----- src/imitation/data/wrappers.py | 13 ++++------ .../config/train_preference_comparisons.py | 14 +++++----- .../scripts/train_preference_comparisons.py | 1 - .../algorithms/test_preference_comparisons.py | 14 +++++----- 6 files changed, 39 insertions(+), 40 deletions(-) delete mode 100644 experiments/human_preferences.sh diff --git a/experiments/human_preferences.sh b/experiments/human_preferences.sh deleted file mode 100644 index 1690d731e..000000000 --- a/experiments/human_preferences.sh +++ /dev/null @@ -1,11 +0,0 @@ -python -m imitation.scripts.train_preference_comparisons \ - with \ - pendulum \ - human_preferences \ - total_comparisons=5000 \ - total_timesteps=1000000 \ - gatherer_kwargs.pref_collect_address=http://127.0.0.1:8000 \ - gatherer_kwargs.video_output_dir=../pref-collect/videofiles \ - gatherer_kwargs.wait_for_user=True \ - common.post_wrappers_kwargs.RenderImageInfoWrapper.scale_factor=0.5 \ - common.post_wrappers_kwargs.RenderImageInfoWrapper.use_file_cache=True \ \ No newline at end of file diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e9cff3fb8..f31990d00 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -801,10 +801,12 @@ def __init__( self.logger = custom_logger or imit_logger.configure() def __call__( - self, queries: Sequence[TrajectoryWithRewPair] + self, queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: """Queries the user for their preferences. - This dummy implementation does nothing because by default the queries are answered by an oracle. + + This dummy implementation does nothing because by default the queries are + answered by an oracle. Args: queries: sequence of pairs of trajectory fragments @@ -816,7 +818,7 @@ def __call__( class PrefCollectQuerent(PreferenceQuerent): - """Sends queries to the PrefCollect interface.""" + """Sends queries to a preference collection web service via HTTP requests.""" def __init__( self, @@ -826,6 +828,15 @@ def __init__( rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): + """Initializes the PrefCollect querent. + + Args: + pref_collect_address: end point of the PrefCollect web service. + video_output_dir: path to the video clip directory. + video_fps: frames per second of the generated videos. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + """ super().__init__(custom_logger) self.rng = rng self.query_endpoint = pref_collect_address + "/preferences/query/" @@ -836,7 +847,7 @@ def __init__( os.makedirs(self.video_output_dir, exist_ok=True) def __call__( - self, queries: Sequence[TrajectoryWithRewPair] + self, queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: identified_queries = super().__call__(queries) @@ -861,11 +872,12 @@ def __call__( def _query(self, query_id): requests.put( - self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)} + self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)}, ) def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> None: + """Write fragment video clip.""" frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( output_path, @@ -901,6 +913,7 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: + """Calculate frame shape.""" if "rendered_img" in fragment.infos[0]: rendered_img_info = fragment.infos[0]["rendered_img"] # If path is provided load cached image @@ -1079,6 +1092,7 @@ def __init__( video_width: width of the video in pixels. video_height: height of the video in pixels. custom_logger: Where to log to; if None (default), creates a new logger. + rng: random number generator Raises: ValueError: if `video_dir` is not a directory. @@ -2163,7 +2177,7 @@ def train( with self.logger.accumulate_means("preferences"): self.logger.log("Gathering preferences") - # Gather fragment pairs (queries) for which preferences have been provided + # Gather fragment pairs for which preferences have been provided queries, preferences = self.preference_gatherer() # Free up RAM or disk space from keeping rendered images diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 2e16bd0b4..58f4846b1 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -21,20 +21,22 @@ class RenderImageInfoWrapper(gym.Wrapper): Can be very memory intensive for large render images. Use `scale_factor` to reduce render image size. If you need to preserve the resolution and memory - runs out, you can activate `ues_file_cache` to save - render images and instead put their path into `info`. + runs out, you can activate `use_file_cache` to save + rendered images and instead put their path into `info`. """ def __init__( self, env: gym.Env, - scale_factor: float = 1., + scale_factor: float = 1.0, use_file_cache: bool = False, ): """Builds RenderImageInfoWrapper. Args: env: Environment to wrap. + scale_factor: scales rendered images to be stored. + use_file_cache: whether to save rendered images to disk. """ super().__init__(env) self.scale_factor = scale_factor @@ -42,11 +44,6 @@ def __init__( if self.use_file_cache: self.file_cache = tempfile.mkdtemp("imitation_RenderImageInfoWrapper") - self._active = True - - def set_render_image_active(self, active: bool): - self._active = active - def step(self, action): obs, rew, done, info = self.env.step(action) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 7af6f7a55..c177fadce 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -72,15 +72,14 @@ def train_defaults(): @train_preference_comparisons_ex.named_config def synch_human_preferences(): gatherer_cls = preference_comparisons.SynchronousHumanGatherer - gatherer_kwargs = dict( - video_dir="videos" - ) + gatherer_kwargs = dict(video_dir="videos") querent_cls = preference_comparisons.PreferenceQuerent querent_kwargs = dict() environment = dict( post_wrappers=dict( - RenderImageInfoWrapper=lambda env, env_id, **kwargs: - RenderImageInfoWrapper(env, **kwargs), + RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( + env, **kwargs, + ), ), num_vec=2, post_wrappers_kwargs=dict( @@ -104,8 +103,9 @@ def human_preferences(): ) environment = dict( post_wrappers=dict( - RenderImageInfoWrapper=lambda env, env_id, **kwargs: - RenderImageInfoWrapper(env, **kwargs), + RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( + env, **kwargs, + ), ), post_wrappers_kwargs=dict( RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 057dfe302..774c84f5f 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -23,7 +23,6 @@ from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward from imitation.scripts.ingredients import rl as rl_common -from imitation.util import video_wrapper def save_model( diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index da3ec5cdb..46707ddb3 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -267,14 +267,13 @@ def test_preference_comparisons_raises( querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - # no rng, must provide fragmenter, preference gatherer, preference_querent, reward trainer no_rng_msg = ( ".*don't provide.*random state.*provide.*fragmenter" ".*preference gatherer.*preference querent.*reward_trainer.*" ) def build_preference_comparisons( - gatherer, querent, reward_trainer, fragmenter, rng + gatherer, querent, reward_trainer, fragmenter, rng, ): preference_comparisons.PreferenceComparisons( agent_trainer, @@ -304,7 +303,7 @@ def build_preference_comparisons( # This should not raise build_preference_comparisons( - gatherer, querent, reward_trainer, random_fragmenter, rng=None + gatherer, querent, reward_trainer, random_fragmenter, rng=None, ) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. @@ -1229,6 +1228,7 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): + """A concrete preference gatherer for unit testing purposes only.""" def __call__(self) -> Tuple[np.ndarray, np.ndarray]: pass @@ -1289,7 +1289,7 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False + pref_collect_address="https://test.de", wait_for_user=False, ) gatherer._gather_preference = MagicMock(return_value=None) gatherer.pending_queries = {"1234": Mock()} @@ -1302,7 +1302,7 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False + pref_collect_address="https://test.de", wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1315,7 +1315,7 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False + pref_collect_address="https://test.de", wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1330,7 +1330,7 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False + pref_collect_address="https://test.de", wait_for_user=False, ) # incomparable preference value = -1 gatherer._gather_preference = MagicMock(return_value=-1.0) From 8e3c7b7e869e178bccf3fd8ac0fd9aee01a055c8 Mon Sep 17 00:00:00 2001 From: rklassert Date: Thu, 25 May 2023 21:10:52 +0200 Subject: [PATCH 034/143] Fix flake8 and codespell --- .../algorithms/preference_comparisons.py | 39 +++++++++---------- .../config/train_preference_comparisons.py | 6 ++- .../algorithms/test_preference_comparisons.py | 25 +++++++++--- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index f31990d00..ca4df3125 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -801,7 +801,8 @@ def __init__( self.logger = custom_logger or imit_logger.configure() def __call__( - self, queries: Sequence[TrajectoryWithRewPair], + self, + queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: """Queries the user for their preferences. @@ -847,14 +848,16 @@ def __init__( os.makedirs(self.video_output_dir, exist_ok=True) def __call__( - self, queries: Sequence[TrajectoryWithRewPair], + self, + queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: identified_queries = super().__call__(queries) # Save fragment videos and submit queries for query_id, query in identified_queries.items(): output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "{}.webm" + self.video_output_dir, + f"{query_id}" + "{}.webm", ) write_fragment_video( query[0], @@ -872,7 +875,8 @@ def __call__( def _query(self, query_id): requests.put( - self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)}, + self.query_endpoint + query_id, + json={"uuid": "{}".format(query_id)}, ) @@ -905,7 +909,8 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> if frame.shape[-1] < 3: missing_channels = 3 - frame.shape[-1] frame = np.concatenate( - [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 + [frame] + missing_channels * [frame[..., -1][..., None]], + axis=-1, ) video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) @@ -923,11 +928,11 @@ def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: single_frame = rendered_img_info else: single_frame = np.array(fragment.obs[0]) - # Check whether obervations are image-like + # Check whether observations are image-like if len(single_frame.shape) < 2: raise ValueError( "Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!" + f"but shape {single_frame.shape} has too few dimensions!", ) # Swap dimensions, because matrix and image dims are swapped return single_frame.shape[1], single_frame.shape[0] @@ -1091,11 +1096,9 @@ def __init__( video_dir: directory where videos of the trajectories are saved. video_width: width of the video in pixels. video_height: height of the video in pixels. + frames_per_second: frames per second of the video. custom_logger: Where to log to; if None (default), creates a new logger. rng: random number generator - - Raises: - ValueError: if `video_dir` is not a directory. """ super().__init__(custom_logger=custom_logger, rng=rng) self.video_dir = video_dir @@ -1112,9 +1115,6 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: notebook, it will display the videos. Either way, it will request 1 or 2 to indicate which is preferred. - Args: - fragment_pairs: sequence of pairs of trajectory fragments - Returns: A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape (b, ), where b is the length of `fragment_pairs` @@ -1143,16 +1143,13 @@ def _display_videos_and_gather_preference(self, query_id) -> bool: """Displays the videos of the two fragments. Args: - frag1: first fragment - frag2: second fragment + query_id: the id of the fragment pair to be displayed. Returns: True if the first fragment is preferred, False if not. Raises: KeyboardInterrupt: if the user presses q to quit. - RuntimeError: if the video files cannot be opened. - ValueError: if the trajectory infos are not set. """ frag1_video_path = os.path.join(self.video_dir, f"{query_id}-left.webm") frag2_video_path = os.path.join(self.video_dir, f"{query_id}-right.webm") @@ -2049,13 +2046,15 @@ def __init__( if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " - "seeded fragmenter, preference gatherer, preference querent, and reward_trainer. " + "seeded fragmenter, preference gatherer, preference querent, " + "and reward_trainer. " "You can initialize a random state with `np.random.default_rng(seed)`.", ) elif self.rng is not None and not has_any_rng_args_none: raise ValueError( - "If you provide your own fragmenter, preference gatherer, preference querent," - "and reward trainer, you don't need to provide a random state.", + "If you provide your own fragmenter, preference gatherer, " + "preference querent, and reward trainer, " + "you don't need to provide a random state.", ) if reward_trainer is None: diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index c177fadce..6ba4b6bfd 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -78,7 +78,8 @@ def synch_human_preferences(): environment = dict( post_wrappers=dict( RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( - env, **kwargs, + env, + **kwargs, ), ), num_vec=2, @@ -104,7 +105,8 @@ def human_preferences(): environment = dict( post_wrappers=dict( RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( - env, **kwargs, + env, + **kwargs, ), ), post_wrappers_kwargs=dict( diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 46707ddb3..26332d022 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -273,7 +273,11 @@ def test_preference_comparisons_raises( ) def build_preference_comparisons( - gatherer, querent, reward_trainer, fragmenter, rng, + gatherer, + querent, + reward_trainer, + fragmenter, + rng, ): preference_comparisons.PreferenceComparisons( agent_trainer, @@ -303,7 +307,11 @@ def build_preference_comparisons( # This should not raise build_preference_comparisons( - gatherer, querent, reward_trainer, random_fragmenter, rng=None, + gatherer, + querent, + reward_trainer, + random_fragmenter, + rng=None, ) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. @@ -1229,6 +1237,7 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" + def __call__(self) -> Tuple[np.ndarray, np.ndarray]: pass @@ -1289,7 +1298,8 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) gatherer._gather_preference = MagicMock(return_value=None) gatherer.pending_queries = {"1234": Mock()} @@ -1302,7 +1312,8 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1315,7 +1326,8 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1330,7 +1342,8 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) # incomparable preference value = -1 gatherer._gather_preference = MagicMock(return_value=-1.0) From b23abcda4dafa07a04aedf8490bbb30600613149 Mon Sep 17 00:00:00 2001 From: rklassert Date: Thu, 8 Jun 2023 17:38:54 +0200 Subject: [PATCH 035/143] Fix some mypy errors --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- tests/algorithms/test_preference_comparisons.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ca4df3125..63a8f9c1f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -961,7 +961,7 @@ def __init__( self.pending_queries = {} @abc.abstractmethod - def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Gathers the probabilities that fragment 1 is preferred in `queries`. Returns: @@ -1349,7 +1349,7 @@ def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: for info in traj.infos: try: rendered_img_info = info["rendered_img"] - if isinstance(rendered_img_info, AnyPath.__args__): + if isinstance(rendered_img_info, (str, bytes, os.PathLike)): os.remove(rendered_img_info) elif isinstance(rendered_img_info, np.ndarray): del info["rendered_img"] diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 26332d022..48688e396 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,5 +1,6 @@ """Tests for the preference comparisons reward learning implementation.""" +import abc import math import pathlib import re @@ -28,7 +29,7 @@ SyntheticGatherer, ) from imitation.data import types -from imitation.data.types import TrajectoryWithRew +from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair from imitation.regularization import regularizers, updaters from imitation.rewards import reward_nets from imitation.testing import reward_improvement @@ -1238,7 +1239,8 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + @abc.abstractmethod + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: pass From b6fb9147c8b8acdbbf25b165190a5b5c5a3f6c67 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 8 Jun 2023 17:39:49 +0200 Subject: [PATCH 036/143] Fix mypy --- setup.py | 1 + src/imitation/algorithms/preference_comparisons.py | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 90a659354..1a94e4f36 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "wandb==0.12.21", "setuptools_scm~=7.0.5", "pre-commit>=2.20.0", + "types-requests~=2.31.0.1" # here or in general requirements? ] + PARALLEL_REQUIRE + ATARI_REQUIRE diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 63a8f9c1f..4fe100a63 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -788,22 +788,19 @@ class PreferenceQuerent: def __init__( self, - rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: """Initializes the preference querent. Args: - rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - del rng self.logger = custom_logger or imit_logger.configure() def __call__( self, queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + ) -> Dict[str, TrajectoryWithRewPair]: """Queries the user for their preferences. This dummy implementation does nothing because by default the queries are @@ -825,7 +822,7 @@ def __init__( self, pref_collect_address: str, video_output_dir: AnyPath, - video_fps: str = 20, + video_fps: int = 20, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): From d230f44753c1acec1dd07b238725598fbc6e980a Mon Sep 17 00:00:00 2001 From: rklassert Date: Thu, 8 Jun 2023 17:56:10 +0200 Subject: [PATCH 037/143] Fix mypy error --- setup.py | 2 +- src/imitation/algorithms/preference_comparisons.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 1a94e4f36..1756979a0 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ "wandb==0.12.21", "setuptools_scm~=7.0.5", "pre-commit>=2.20.0", - "types-requests~=2.31.0.1" # here or in general requirements? + "types-requests~=2.31.0.1", ] + PARALLEL_REQUIRE + ATARI_REQUIRE diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 4fe100a63..ccb0fc93e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -847,7 +847,7 @@ def __init__( def __call__( self, queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + ) -> Dict[str, TrajectoryWithRewPair]: identified_queries = super().__call__(queries) # Save fragment videos and submit queries From c7d90105fbc89413c41c5b8b7f0c86b1df4e3b80 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 8 Jun 2023 18:47:52 +0200 Subject: [PATCH 038/143] Fix mypy --- .../algorithms/preference_comparisons.py | 38 +++++++++---------- .../scripts/train_preference_comparisons.py | 1 - 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ccb0fc93e..a8dab2c16 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -821,7 +821,7 @@ class PrefCollectQuerent(PreferenceQuerent): def __init__( self, pref_collect_address: str, - video_output_dir: AnyPath, + video_output_dir: str, video_fps: int = 20, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, @@ -877,7 +877,7 @@ def _query(self, query_id): ) -def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> None: +def write_fragment_video(fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath) -> None: """Write fragment video clip.""" frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( @@ -888,16 +888,18 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> ) # Make videos from rendered observations if available - if "rendered_img" in fragment.infos[0]: - frames = [] + frames: np.ndarray + if fragment.infos is not None and "rendered_img" in fragment.infos[0]: + frames_list = [] for i in range(len(fragment.infos)): frame_info = fragment.infos[i]["rendered_img"] # If path is provided load cached image - if isinstance(frame_info, AnyPath.__args__): + if isinstance(frame_info, (str, bytes, os.PathLike)): frame = np.load(frame_info) elif isinstance(frame_info, np.ndarray): frame = frame_info - frames.append(frame) + frames_list.append(frame) + frames = np.array(frames_list) else: frames = fragment.obs @@ -916,10 +918,10 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: """Calculate frame shape.""" - if "rendered_img" in fragment.infos[0]: + if fragment.infos is not None and "rendered_img" in fragment.infos[0]: rendered_img_info = fragment.infos[0]["rendered_img"] # If path is provided load cached image - if isinstance(rendered_img_info, AnyPath.__args__): + if isinstance(rendered_img_info, (str, bytes, os.PathLike)): single_frame = np.load(rendered_img_info) else: single_frame = rendered_img_info @@ -955,7 +957,7 @@ def __init__( # the PreferenceGatherer we use needs one). del rng self.logger = custom_logger or imit_logger.configure() - self.pending_queries = {} + self.pending_queries: Dict = {} @abc.abstractmethod def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: @@ -1136,7 +1138,7 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: self.pending_queries.clear() return queries, preferences - def _display_videos_and_gather_preference(self, query_id) -> bool: + def _display_videos_and_gather_preference(self, query_id: uuid.UUID) -> bool: """Displays the videos of the two fragments. Args: @@ -1148,8 +1150,8 @@ def _display_videos_and_gather_preference(self, query_id) -> bool: Raises: KeyboardInterrupt: if the user presses q to quit. """ - frag1_video_path = os.path.join(self.video_dir, f"{query_id}-left.webm") - frag2_video_path = os.path.join(self.video_dir, f"{query_id}-right.webm") + frag1_video_path = pathlib.Path(self.video_dir, f"{query_id}-left.webm") + frag2_video_path = pathlib.Path(self.video_dir, f"{query_id}-right.webm") if self._in_ipython(): self._display_videos_in_notebook(frag1_video_path, frag2_video_path) @@ -1306,13 +1308,12 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(custom_logger) - self.rng = rng + super().__init__(rng, custom_logger) self.query_endpoint = pref_collect_address + "/preferences/query/" self.pending_queries = {} self.wait_for_user = wait_for_user - def __call__(self) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: # TODO: create user-independent (automated) waiting policy if self.wait_for_user: @@ -1343,15 +1344,13 @@ def _gather_preference(self, query_id: str) -> float: def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: """Removes rendered images of the provided trajectories list.""" for traj in trajectories: - for info in traj.infos: - try: + if traj.infos is not None and "rendered_img" in traj.infos[0]: + for info in traj.infos: rendered_img_info = info["rendered_img"] if isinstance(rendered_img_info, (str, bytes, os.PathLike)): os.remove(rendered_img_info) elif isinstance(rendered_img_info, np.ndarray): del info["rendered_img"] - except KeyError: - pass class PreferenceDataset(data_th.Dataset): @@ -2087,7 +2086,6 @@ def __init__( assert self.rng is not None self.preference_querent = PreferenceQuerent( custom_logger=self.logger, - rng=self.rng, ) if preference_gatherer: self.preference_gatherer = preference_gatherer diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 774c84f5f..423248a08 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -245,7 +245,6 @@ def train_preference_comparisons( ) querent = querent_cls( **querent_kwargs, - rng=_rnd, custom_logger=custom_logger, ) From 2de7491a538fabe1234ab6355c6bac9e6244ecde Mon Sep 17 00:00:00 2001 From: rk1a Date: Sun, 25 Jun 2023 23:04:56 +0200 Subject: [PATCH 039/143] Fix mypy, flake8, codespell --- setup.py | 1 + .../algorithms/preference_comparisons.py | 90 +++++++++---------- .../config/train_preference_comparisons.py | 6 +- .../scripts/train_preference_comparisons.py | 1 - .../algorithms/test_preference_comparisons.py | 31 +++++-- 5 files changed, 70 insertions(+), 59 deletions(-) diff --git a/setup.py b/setup.py index 90a659354..1756979a0 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "wandb==0.12.21", "setuptools_scm~=7.0.5", "pre-commit>=2.20.0", + "types-requests~=2.31.0.1", ] + PARALLEL_REQUIRE + ATARI_REQUIRE diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index f31990d00..a8dab2c16 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -788,21 +788,19 @@ class PreferenceQuerent: def __init__( self, - rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: """Initializes the preference querent. Args: - rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - del rng self.logger = custom_logger or imit_logger.configure() def __call__( - self, queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: """Queries the user for their preferences. This dummy implementation does nothing because by default the queries are @@ -823,8 +821,8 @@ class PrefCollectQuerent(PreferenceQuerent): def __init__( self, pref_collect_address: str, - video_output_dir: AnyPath, - video_fps: str = 20, + video_output_dir: str, + video_fps: int = 20, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): @@ -847,14 +845,16 @@ def __init__( os.makedirs(self.video_output_dir, exist_ok=True) def __call__( - self, queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, Sequence[TrajectoryWithRewPair]]: + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: identified_queries = super().__call__(queries) # Save fragment videos and submit queries for query_id, query in identified_queries.items(): output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "{}.webm" + self.video_output_dir, + f"{query_id}" + "{}.webm", ) write_fragment_video( query[0], @@ -872,11 +872,12 @@ def __call__( def _query(self, query_id): requests.put( - self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)}, + self.query_endpoint + query_id, + json={"uuid": "{}".format(query_id)}, ) -def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> None: +def write_fragment_video(fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath) -> None: """Write fragment video clip.""" frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( @@ -887,16 +888,18 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> ) # Make videos from rendered observations if available - if "rendered_img" in fragment.infos[0]: - frames = [] + frames: np.ndarray + if fragment.infos is not None and "rendered_img" in fragment.infos[0]: + frames_list = [] for i in range(len(fragment.infos)): frame_info = fragment.infos[i]["rendered_img"] # If path is provided load cached image - if isinstance(frame_info, AnyPath.__args__): + if isinstance(frame_info, (str, bytes, os.PathLike)): frame = np.load(frame_info) elif isinstance(frame_info, np.ndarray): frame = frame_info - frames.append(frame) + frames_list.append(frame) + frames = np.array(frames_list) else: frames = fragment.obs @@ -905,7 +908,8 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> if frame.shape[-1] < 3: missing_channels = 3 - frame.shape[-1] frame = np.concatenate( - [frame] + missing_channels * [frame[..., -1][..., None]], axis=-1 + [frame] + missing_channels * [frame[..., -1][..., None]], + axis=-1, ) video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) @@ -914,20 +918,20 @@ def write_fragment_video(fragment, frames_per_second: int, output_path: str) -> def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: """Calculate frame shape.""" - if "rendered_img" in fragment.infos[0]: + if fragment.infos is not None and "rendered_img" in fragment.infos[0]: rendered_img_info = fragment.infos[0]["rendered_img"] # If path is provided load cached image - if isinstance(rendered_img_info, AnyPath.__args__): + if isinstance(rendered_img_info, (str, bytes, os.PathLike)): single_frame = np.load(rendered_img_info) else: single_frame = rendered_img_info else: single_frame = np.array(fragment.obs[0]) - # Check whether obervations are image-like + # Check whether observations are image-like if len(single_frame.shape) < 2: raise ValueError( "Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!" + f"but shape {single_frame.shape} has too few dimensions!", ) # Swap dimensions, because matrix and image dims are swapped return single_frame.shape[1], single_frame.shape[0] @@ -953,10 +957,10 @@ def __init__( # the PreferenceGatherer we use needs one). del rng self.logger = custom_logger or imit_logger.configure() - self.pending_queries = {} + self.pending_queries: Dict = {} @abc.abstractmethod - def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Gathers the probabilities that fragment 1 is preferred in `queries`. Returns: @@ -1091,11 +1095,9 @@ def __init__( video_dir: directory where videos of the trajectories are saved. video_width: width of the video in pixels. video_height: height of the video in pixels. + frames_per_second: frames per second of the video. custom_logger: Where to log to; if None (default), creates a new logger. rng: random number generator - - Raises: - ValueError: if `video_dir` is not a directory. """ super().__init__(custom_logger=custom_logger, rng=rng) self.video_dir = video_dir @@ -1112,9 +1114,6 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: notebook, it will display the videos. Either way, it will request 1 or 2 to indicate which is preferred. - Args: - fragment_pairs: sequence of pairs of trajectory fragments - Returns: A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape (b, ), where b is the length of `fragment_pairs` @@ -1139,23 +1138,20 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: self.pending_queries.clear() return queries, preferences - def _display_videos_and_gather_preference(self, query_id) -> bool: + def _display_videos_and_gather_preference(self, query_id: uuid.UUID) -> bool: """Displays the videos of the two fragments. Args: - frag1: first fragment - frag2: second fragment + query_id: the id of the fragment pair to be displayed. Returns: True if the first fragment is preferred, False if not. Raises: KeyboardInterrupt: if the user presses q to quit. - RuntimeError: if the video files cannot be opened. - ValueError: if the trajectory infos are not set. """ - frag1_video_path = os.path.join(self.video_dir, f"{query_id}-left.webm") - frag2_video_path = os.path.join(self.video_dir, f"{query_id}-right.webm") + frag1_video_path = pathlib.Path(self.video_dir, f"{query_id}-left.webm") + frag2_video_path = pathlib.Path(self.video_dir, f"{query_id}-right.webm") if self._in_ipython(): self._display_videos_in_notebook(frag1_video_path, frag2_video_path) @@ -1312,13 +1308,12 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(custom_logger) - self.rng = rng + super().__init__(rng, custom_logger) self.query_endpoint = pref_collect_address + "/preferences/query/" self.pending_queries = {} self.wait_for_user = wait_for_user - def __call__(self) -> Tuple[Sequence[TrajectoryPair], np.ndarray]: + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: # TODO: create user-independent (automated) waiting policy if self.wait_for_user: @@ -1349,15 +1344,13 @@ def _gather_preference(self, query_id: str) -> float: def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: """Removes rendered images of the provided trajectories list.""" for traj in trajectories: - for info in traj.infos: - try: + if traj.infos is not None and "rendered_img" in traj.infos[0]: + for info in traj.infos: rendered_img_info = info["rendered_img"] - if isinstance(rendered_img_info, AnyPath.__args__): + if isinstance(rendered_img_info, (str, bytes, os.PathLike)): os.remove(rendered_img_info) elif isinstance(rendered_img_info, np.ndarray): del info["rendered_img"] - except KeyError: - pass class PreferenceDataset(data_th.Dataset): @@ -2049,13 +2042,15 @@ def __init__( if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " - "seeded fragmenter, preference gatherer, preference querent, and reward_trainer. " + "seeded fragmenter, preference gatherer, preference querent, " + "and reward_trainer. " "You can initialize a random state with `np.random.default_rng(seed)`.", ) elif self.rng is not None and not has_any_rng_args_none: raise ValueError( - "If you provide your own fragmenter, preference gatherer, preference querent," - "and reward trainer, you don't need to provide a random state.", + "If you provide your own fragmenter, preference gatherer, " + "preference querent, and reward trainer, " + "you don't need to provide a random state.", ) if reward_trainer is None: @@ -2091,7 +2086,6 @@ def __init__( assert self.rng is not None self.preference_querent = PreferenceQuerent( custom_logger=self.logger, - rng=self.rng, ) if preference_gatherer: self.preference_gatherer = preference_gatherer diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index c177fadce..6ba4b6bfd 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -78,7 +78,8 @@ def synch_human_preferences(): environment = dict( post_wrappers=dict( RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( - env, **kwargs, + env, + **kwargs, ), ), num_vec=2, @@ -104,7 +105,8 @@ def human_preferences(): environment = dict( post_wrappers=dict( RenderImageInfoWrapper=lambda env, env_id, **kwargs: RenderImageInfoWrapper( - env, **kwargs, + env, + **kwargs, ), ), post_wrappers_kwargs=dict( diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 774c84f5f..423248a08 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -245,7 +245,6 @@ def train_preference_comparisons( ) querent = querent_cls( **querent_kwargs, - rng=_rnd, custom_logger=custom_logger, ) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 46707ddb3..48688e396 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,5 +1,6 @@ """Tests for the preference comparisons reward learning implementation.""" +import abc import math import pathlib import re @@ -28,7 +29,7 @@ SyntheticGatherer, ) from imitation.data import types -from imitation.data.types import TrajectoryWithRew +from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair from imitation.regularization import regularizers, updaters from imitation.rewards import reward_nets from imitation.testing import reward_improvement @@ -273,7 +274,11 @@ def test_preference_comparisons_raises( ) def build_preference_comparisons( - gatherer, querent, reward_trainer, fragmenter, rng, + gatherer, + querent, + reward_trainer, + fragmenter, + rng, ): preference_comparisons.PreferenceComparisons( agent_trainer, @@ -303,7 +308,11 @@ def build_preference_comparisons( # This should not raise build_preference_comparisons( - gatherer, querent, reward_trainer, random_fragmenter, rng=None, + gatherer, + querent, + reward_trainer, + random_fragmenter, + rng=None, ) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. @@ -1229,7 +1238,9 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - def __call__(self) -> Tuple[np.ndarray, np.ndarray]: + + @abc.abstractmethod + def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: pass @@ -1289,7 +1300,8 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) gatherer._gather_preference = MagicMock(return_value=None) gatherer.pending_queries = {"1234": Mock()} @@ -1302,7 +1314,8 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1315,7 +1328,8 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1330,7 +1344,8 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", wait_for_user=False, + pref_collect_address="https://test.de", + wait_for_user=False, ) # incomparable preference value = -1 gatherer._gather_preference = MagicMock(return_value=-1.0) From c7f6bacbf9052d6bd67e0d9da9b1fac0822e383a Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Mon, 26 Jun 2023 23:13:54 +0200 Subject: [PATCH 040/143] Adds preference_querent to args-is-none-check --- src/imitation/algorithms/preference_comparisons.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a8dab2c16..f65e0f13a 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -877,7 +877,9 @@ def _query(self, query_id): ) -def write_fragment_video(fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath) -> None: +def write_fragment_video( + fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath +) -> None: """Write fragment video clip.""" frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( @@ -2033,6 +2035,7 @@ def __init__( # are any of the optional args that require a rng None? has_any_rng_args_none = None in ( + preference_querent, preference_gatherer, fragmenter, reward_trainer, From 9f1de17dc3c75742be869591ae4ae0d41d2557a3 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Sun, 2 Jul 2023 20:18:38 +0200 Subject: [PATCH 041/143] Adds querent to pref comparisons in tutorial 5 --- .../5_train_preference_comparisons.ipynb | 557 +++++++++++++++++- 1 file changed, 543 insertions(+), 14 deletions(-) diff --git a/docs/tutorials/5_train_preference_comparisons.ipynb b/docs/tutorials/5_train_preference_comparisons.ipynb index b2cf6a500..9fbfd429e 100644 --- a/docs/tutorials/5_train_preference_comparisons.ipynb +++ b/docs/tutorials/5_train_preference_comparisons.ipynb @@ -19,8 +19,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:16:31.401489014Z", + "start_time": "2023-07-02T18:16:26.574553894Z" + } + }, "outputs": [], "source": [ "import random\n", @@ -45,6 +50,7 @@ " warning_threshold=0,\n", " rng=rng,\n", ")\n", + "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -82,6 +88,7 @@ " reward_net,\n", " num_iterations=5,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=100,\n", @@ -101,9 +108,499 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:17:18.887827019Z", + "start_time": "2023-07-02T18:16:31.404169166Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query schedule: [20, 51, 41, 34, 29, 25]\n", + "Collecting 40 fragments (4000 transitions)\n", + "Requested 4000 transitions but only 0 in buffer. Sampling 4000 additional transitions.\n", + "Creating fragment pairs\n", + "Gathering preferences\n", + "Dataset now contains 20 comparisons\n" + ] + }, + { + "data": { + "text/plain": "Training reward model: 0%| | 0/3 [00:00" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", @@ -168,9 +684,22 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:17:21.983164120Z", + "start_time": "2023-07-02T18:17:20.917257800Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1057.3085665\n" + ] + } + ], "source": [ "from stable_baselines3.common.evaluation import evaluate_policy\n", "\n", @@ -203,4 +732,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} From 27881a7cae944378102426a044f2ae036689690f Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Sun, 2 Jul 2023 20:30:04 +0200 Subject: [PATCH 042/143] Adds querent to pref comparisons in tutorial 5a --- ...rain_preference_comparisons_with_cnn.ipynb | 51 +++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index 570e6154f..7d6a32207 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -22,10 +22,53 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "93187e19", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:29:25.396097810Z", + "start_time": "2023-07-02T18:29:18.742766046Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n", + "[Powered by Stella]\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'numpy.random._generator.Generator' object has no attribute 'randint'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[1], line 67\u001B[0m\n\u001B[1;32m 49\u001B[0m reward_trainer \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mBasicRewardTrainer(\n\u001B[1;32m 50\u001B[0m preference_model\u001B[38;5;241m=\u001B[39mpreference_model,\n\u001B[1;32m 51\u001B[0m loss\u001B[38;5;241m=\u001B[39mpreference_comparisons\u001B[38;5;241m.\u001B[39mCrossEntropyRewardLoss(),\n\u001B[1;32m 52\u001B[0m epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m3\u001B[39m,\n\u001B[1;32m 53\u001B[0m rng\u001B[38;5;241m=\u001B[39mrng,\n\u001B[1;32m 54\u001B[0m )\n\u001B[1;32m 56\u001B[0m agent \u001B[38;5;241m=\u001B[39m PPO(\n\u001B[1;32m 57\u001B[0m policy\u001B[38;5;241m=\u001B[39mCnnPolicy,\n\u001B[1;32m 58\u001B[0m env\u001B[38;5;241m=\u001B[39mvenv,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 64\u001B[0m n_epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4\u001B[39m,\n\u001B[1;32m 65\u001B[0m )\n\u001B[0;32m---> 67\u001B[0m trajectory_generator \u001B[38;5;241m=\u001B[39m \u001B[43mpreference_comparisons\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mAgentTrainer\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 68\u001B[0m \u001B[43m \u001B[49m\u001B[43malgorithm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43magent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 69\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mreward_net\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 70\u001B[0m \u001B[43m \u001B[49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 71\u001B[0m \u001B[43m \u001B[49m\u001B[43mexploration_frac\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.0\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 72\u001B[0m \u001B[43m \u001B[49m\u001B[43mrng\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mrng\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 73\u001B[0m \u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m pref_comparisons \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mPreferenceComparisons(\n\u001B[1;32m 76\u001B[0m trajectory_generator,\n\u001B[1;32m 77\u001B[0m reward_net,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 87\u001B[0m initial_epoch_multiplier\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m,\n\u001B[1;32m 88\u001B[0m )\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/algorithms/preference_comparisons.py:187\u001B[0m, in \u001B[0;36mAgentTrainer.__init__\u001B[0;34m(self, algorithm, reward_fn, venv, rng, exploration_frac, switch_prob, random_prob, custom_logger)\u001B[0m\n\u001B[1;32m 177\u001B[0m \u001B[38;5;66;03m# The BufferingWrapper records all trajectories, so we can return\u001B[39;00m\n\u001B[1;32m 178\u001B[0m \u001B[38;5;66;03m# them after training. This should come first (before the wrapper that\u001B[39;00m\n\u001B[1;32m 179\u001B[0m \u001B[38;5;66;03m# changes the reward function), so that we return the original environment\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 184\u001B[0m \u001B[38;5;66;03m# SB3 may move the image-channel dimension in the observation space, making\u001B[39;00m\n\u001B[1;32m 185\u001B[0m \u001B[38;5;66;03m# `algorithm.get_env()` not match with `reward_fn`.\u001B[39;00m\n\u001B[1;32m 186\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuffering_wrapper \u001B[38;5;241m=\u001B[39m wrappers\u001B[38;5;241m.\u001B[39mBufferingWrapper(venv)\n\u001B[0;32m--> 187\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper \u001B[38;5;241m=\u001B[39m \u001B[43mreward_wrapper\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mRewardVecEnvWrapper\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 188\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuffering_wrapper\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 189\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreward_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 190\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 192\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlog_callback \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper\u001B[38;5;241m.\u001B[39mmake_log_callback()\n\u001B[1;32m 194\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malgorithm\u001B[38;5;241m.\u001B[39mset_env(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:73\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.__init__\u001B[0;34m(self, venv, reward_fn, ep_history)\u001B[0m\n\u001B[1;32m 71\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_actions \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m---> 73\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:84\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 83\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m---> 84\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 85\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:126\u001B[0m, in \u001B[0;36mBufferingWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 124\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_init_reset \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m 125\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_transitions \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m--> 126\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 127\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_traj_accum \u001B[38;5;241m=\u001B[39m rollout\u001B[38;5;241m.\u001B[39mTrajectoryAccumulator()\n\u001B[1;32m 128\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, ob \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(obs):\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py:58\u001B[0m, in \u001B[0;36mVecFrameStack.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 54\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Union[np\u001B[38;5;241m.\u001B[39mndarray, Dict[\u001B[38;5;28mstr\u001B[39m, np\u001B[38;5;241m.\u001B[39mndarray]]:\n\u001B[1;32m 55\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 56\u001B[0m \u001B[38;5;124;03m Reset all environments\u001B[39;00m\n\u001B[1;32m 57\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[0;32m---> 58\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;66;03m# pytype:disable=annotation-type-mismatch\u001B[39;00m\n\u001B[1;32m 60\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mstackedobs\u001B[38;5;241m.\u001B[39mreset(observation)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m observation\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:74\u001B[0m, in \u001B[0;36mDummyVecEnv.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m VecEnvObs:\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m env_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_envs):\n\u001B[0;32m---> 74\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menvs\u001B[49m\u001B[43m[\u001B[49m\u001B[43menv_idx\u001B[49m\u001B[43m]\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_save_obs(env_idx, obs)\n\u001B[1;32m 76\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs_from_buf()\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/monitor.py:83\u001B[0m, in \u001B[0;36mMonitor.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 81\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mExpected you to pass keyword argument \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m into reset\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 82\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcurrent_reset_info[key] \u001B[38;5;241m=\u001B[39m value\n\u001B[0;32m---> 83\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:261\u001B[0m, in \u001B[0;36mRolloutInfoWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 260\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 261\u001B[0m new_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 262\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs \u001B[38;5;241m=\u001B[39m [new_obs]\n\u001B[1;32m 263\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_rews \u001B[38;5;241m=\u001B[39m []\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/wrappers/time_limit.py:27\u001B[0m, in \u001B[0;36mTimeLimit.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 25\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m 26\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_elapsed_steps \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m---> 27\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:333\u001B[0m, in \u001B[0;36mRewardWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 332\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 333\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:319\u001B[0m, in \u001B[0;36mObservationWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 318\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 319\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 320\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mobservation(observation)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:59\u001B[0m, in \u001B[0;36mFireResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m np\u001B[38;5;241m.\u001B[39mndarray:\n\u001B[0;32m---> 59\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 60\u001B[0m obs, _, done, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m done:\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:106\u001B[0m, in \u001B[0;36mEpisodicLifeEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 97\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 98\u001B[0m \u001B[38;5;124;03mCalls the Gym environment reset, only when lives are exhausted.\u001B[39;00m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;124;03mThis way all states are still reachable even though lives are episodic,\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 103\u001B[0m \u001B[38;5;124;03m:return: the first observation of the environment\u001B[39;00m\n\u001B[1;32m 104\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 105\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mwas_real_done:\n\u001B[0;32m--> 106\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 107\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 108\u001B[0m \u001B[38;5;66;03m# no-op step to advance from terminal/lost life state\u001B[39;00m\n\u001B[1;32m 109\u001B[0m obs, _, _, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m0\u001B[39m)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:154\u001B[0m, in \u001B[0;36mMaxAndSkipEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 153\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m GymObs:\n\u001B[0;32m--> 154\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:36\u001B[0m, in \u001B[0;36mNoopResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 34\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moverride_num_noops\n\u001B[1;32m 35\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 36\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43munwrapped\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnp_random\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrandint\u001B[49m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnoop_max \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m noops \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[1;32m 38\u001B[0m obs \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros(\u001B[38;5;241m0\u001B[39m)\n", + "\u001B[0;31mAttributeError\u001B[0m: 'numpy.random._generator.Generator' object has no attribute 'randint'" + ] + } + ], "source": [ "import torch as th\n", "import gym\n", @@ -72,6 +115,7 @@ ").to(device)\n", "\n", "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)\n", + "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -105,6 +149,7 @@ " reward_net,\n", " num_iterations=2,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=10,\n", From 69983ef4ec3a456a6b126ba433383548a0e44db6 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Sun, 2 Jul 2023 20:30:38 +0200 Subject: [PATCH 043/143] Adds querent to pref comparisons in tutorial 5b and docs --- docs/algorithms/preference_comparisons.rst | 2 ++ ...reference_comparisons_with_synchronous_human_feedback.ipynb | 3 +++ 2 files changed, 5 insertions(+) diff --git a/docs/algorithms/preference_comparisons.rst b/docs/algorithms/preference_comparisons.rst index 5cf46b277..7638f4f08 100644 --- a/docs/algorithms/preference_comparisons.rst +++ b/docs/algorithms/preference_comparisons.rst @@ -48,6 +48,7 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) + querent = preference_comparisons.PreferenceQuerent() preference_model = preference_comparisons.PreferenceModel(reward_net) reward_trainer = preference_comparisons.BasicRewardTrainer( preference_model=preference_model, @@ -79,6 +80,7 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` reward_net, num_iterations=5, fragmenter=fragmenter, + preference_querent=querent, preference_gatherer=gatherer, reward_trainer=reward_trainer, initial_epoch_multiplier=1, diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb index 40d9c18ce..97261aeeb 100644 --- a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -64,6 +64,8 @@ " rng=rng,\n", ")\n", "\n", + "querent = preference_comparisons.PreferenceQuerent()\n", + "\n", "# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n", "# which one is better. It will then use the user's feedback to train the reward network.\n", "gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n", @@ -104,6 +106,7 @@ " reward_net,\n", " num_iterations=5,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=100,\n", From 54be4be1781b88f3db454bf14e3aa28971057a50 Mon Sep 17 00:00:00 2001 From: rk1a Date: Sun, 2 Jul 2023 20:41:03 +0200 Subject: [PATCH 044/143] Fix notebooks and docs --- docs/algorithms/preference_comparisons.rst | 2 + .../5_train_preference_comparisons.ipynb | 557 +++++++++++++++++- ...rain_preference_comparisons_with_cnn.ipynb | 51 +- ...sons_with_synchronous_human_feedback.ipynb | 3 + 4 files changed, 596 insertions(+), 17 deletions(-) diff --git a/docs/algorithms/preference_comparisons.rst b/docs/algorithms/preference_comparisons.rst index 5cf46b277..7638f4f08 100644 --- a/docs/algorithms/preference_comparisons.rst +++ b/docs/algorithms/preference_comparisons.rst @@ -48,6 +48,7 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) + querent = preference_comparisons.PreferenceQuerent() preference_model = preference_comparisons.PreferenceModel(reward_net) reward_trainer = preference_comparisons.BasicRewardTrainer( preference_model=preference_model, @@ -79,6 +80,7 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` reward_net, num_iterations=5, fragmenter=fragmenter, + preference_querent=querent, preference_gatherer=gatherer, reward_trainer=reward_trainer, initial_epoch_multiplier=1, diff --git a/docs/tutorials/5_train_preference_comparisons.ipynb b/docs/tutorials/5_train_preference_comparisons.ipynb index b2cf6a500..9fbfd429e 100644 --- a/docs/tutorials/5_train_preference_comparisons.ipynb +++ b/docs/tutorials/5_train_preference_comparisons.ipynb @@ -19,8 +19,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:16:31.401489014Z", + "start_time": "2023-07-02T18:16:26.574553894Z" + } + }, "outputs": [], "source": [ "import random\n", @@ -45,6 +50,7 @@ " warning_threshold=0,\n", " rng=rng,\n", ")\n", + "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -82,6 +88,7 @@ " reward_net,\n", " num_iterations=5,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=100,\n", @@ -101,9 +108,499 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:17:18.887827019Z", + "start_time": "2023-07-02T18:16:31.404169166Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query schedule: [20, 51, 41, 34, 29, 25]\n", + "Collecting 40 fragments (4000 transitions)\n", + "Requested 4000 transitions but only 0 in buffer. Sampling 4000 additional transitions.\n", + "Creating fragment pairs\n", + "Gathering preferences\n", + "Dataset now contains 20 comparisons\n" + ] + }, + { + "data": { + "text/plain": "Training reward model: 0%| | 0/3 [00:00" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", @@ -168,9 +684,22 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:17:21.983164120Z", + "start_time": "2023-07-02T18:17:20.917257800Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1057.3085665\n" + ] + } + ], "source": [ "from stable_baselines3.common.evaluation import evaluate_policy\n", "\n", @@ -203,4 +732,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index 570e6154f..7d6a32207 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -22,10 +22,53 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "93187e19", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2023-07-02T18:29:25.396097810Z", + "start_time": "2023-07-02T18:29:18.742766046Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n", + "[Powered by Stella]\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'numpy.random._generator.Generator' object has no attribute 'randint'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[1], line 67\u001B[0m\n\u001B[1;32m 49\u001B[0m reward_trainer \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mBasicRewardTrainer(\n\u001B[1;32m 50\u001B[0m preference_model\u001B[38;5;241m=\u001B[39mpreference_model,\n\u001B[1;32m 51\u001B[0m loss\u001B[38;5;241m=\u001B[39mpreference_comparisons\u001B[38;5;241m.\u001B[39mCrossEntropyRewardLoss(),\n\u001B[1;32m 52\u001B[0m epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m3\u001B[39m,\n\u001B[1;32m 53\u001B[0m rng\u001B[38;5;241m=\u001B[39mrng,\n\u001B[1;32m 54\u001B[0m )\n\u001B[1;32m 56\u001B[0m agent \u001B[38;5;241m=\u001B[39m PPO(\n\u001B[1;32m 57\u001B[0m policy\u001B[38;5;241m=\u001B[39mCnnPolicy,\n\u001B[1;32m 58\u001B[0m env\u001B[38;5;241m=\u001B[39mvenv,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 64\u001B[0m n_epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4\u001B[39m,\n\u001B[1;32m 65\u001B[0m )\n\u001B[0;32m---> 67\u001B[0m trajectory_generator \u001B[38;5;241m=\u001B[39m \u001B[43mpreference_comparisons\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mAgentTrainer\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 68\u001B[0m \u001B[43m \u001B[49m\u001B[43malgorithm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43magent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 69\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mreward_net\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 70\u001B[0m \u001B[43m \u001B[49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 71\u001B[0m \u001B[43m \u001B[49m\u001B[43mexploration_frac\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.0\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 72\u001B[0m \u001B[43m \u001B[49m\u001B[43mrng\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mrng\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 73\u001B[0m \u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m pref_comparisons \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mPreferenceComparisons(\n\u001B[1;32m 76\u001B[0m trajectory_generator,\n\u001B[1;32m 77\u001B[0m reward_net,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 87\u001B[0m initial_epoch_multiplier\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m,\n\u001B[1;32m 88\u001B[0m )\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/algorithms/preference_comparisons.py:187\u001B[0m, in \u001B[0;36mAgentTrainer.__init__\u001B[0;34m(self, algorithm, reward_fn, venv, rng, exploration_frac, switch_prob, random_prob, custom_logger)\u001B[0m\n\u001B[1;32m 177\u001B[0m \u001B[38;5;66;03m# The BufferingWrapper records all trajectories, so we can return\u001B[39;00m\n\u001B[1;32m 178\u001B[0m \u001B[38;5;66;03m# them after training. This should come first (before the wrapper that\u001B[39;00m\n\u001B[1;32m 179\u001B[0m \u001B[38;5;66;03m# changes the reward function), so that we return the original environment\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 184\u001B[0m \u001B[38;5;66;03m# SB3 may move the image-channel dimension in the observation space, making\u001B[39;00m\n\u001B[1;32m 185\u001B[0m \u001B[38;5;66;03m# `algorithm.get_env()` not match with `reward_fn`.\u001B[39;00m\n\u001B[1;32m 186\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuffering_wrapper \u001B[38;5;241m=\u001B[39m wrappers\u001B[38;5;241m.\u001B[39mBufferingWrapper(venv)\n\u001B[0;32m--> 187\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper \u001B[38;5;241m=\u001B[39m \u001B[43mreward_wrapper\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mRewardVecEnvWrapper\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 188\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuffering_wrapper\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 189\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreward_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 190\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 192\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlog_callback \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper\u001B[38;5;241m.\u001B[39mmake_log_callback()\n\u001B[1;32m 194\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malgorithm\u001B[38;5;241m.\u001B[39mset_env(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:73\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.__init__\u001B[0;34m(self, venv, reward_fn, ep_history)\u001B[0m\n\u001B[1;32m 71\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_actions \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m---> 73\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:84\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 83\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m---> 84\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 85\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:126\u001B[0m, in \u001B[0;36mBufferingWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 124\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_init_reset \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m 125\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_transitions \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m--> 126\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 127\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_traj_accum \u001B[38;5;241m=\u001B[39m rollout\u001B[38;5;241m.\u001B[39mTrajectoryAccumulator()\n\u001B[1;32m 128\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, ob \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(obs):\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py:58\u001B[0m, in \u001B[0;36mVecFrameStack.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 54\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Union[np\u001B[38;5;241m.\u001B[39mndarray, Dict[\u001B[38;5;28mstr\u001B[39m, np\u001B[38;5;241m.\u001B[39mndarray]]:\n\u001B[1;32m 55\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 56\u001B[0m \u001B[38;5;124;03m Reset all environments\u001B[39;00m\n\u001B[1;32m 57\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[0;32m---> 58\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;66;03m# pytype:disable=annotation-type-mismatch\u001B[39;00m\n\u001B[1;32m 60\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mstackedobs\u001B[38;5;241m.\u001B[39mreset(observation)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m observation\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:74\u001B[0m, in \u001B[0;36mDummyVecEnv.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m VecEnvObs:\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m env_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_envs):\n\u001B[0;32m---> 74\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menvs\u001B[49m\u001B[43m[\u001B[49m\u001B[43menv_idx\u001B[49m\u001B[43m]\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_save_obs(env_idx, obs)\n\u001B[1;32m 76\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs_from_buf()\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/monitor.py:83\u001B[0m, in \u001B[0;36mMonitor.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 81\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mExpected you to pass keyword argument \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m into reset\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 82\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcurrent_reset_info[key] \u001B[38;5;241m=\u001B[39m value\n\u001B[0;32m---> 83\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:261\u001B[0m, in \u001B[0;36mRolloutInfoWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 260\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 261\u001B[0m new_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 262\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs \u001B[38;5;241m=\u001B[39m [new_obs]\n\u001B[1;32m 263\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_rews \u001B[38;5;241m=\u001B[39m []\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/wrappers/time_limit.py:27\u001B[0m, in \u001B[0;36mTimeLimit.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 25\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m 26\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_elapsed_steps \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m---> 27\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:333\u001B[0m, in \u001B[0;36mRewardWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 332\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 333\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:319\u001B[0m, in \u001B[0;36mObservationWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 318\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 319\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 320\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mobservation(observation)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:59\u001B[0m, in \u001B[0;36mFireResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m np\u001B[38;5;241m.\u001B[39mndarray:\n\u001B[0;32m---> 59\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 60\u001B[0m obs, _, done, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m done:\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:106\u001B[0m, in \u001B[0;36mEpisodicLifeEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 97\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 98\u001B[0m \u001B[38;5;124;03mCalls the Gym environment reset, only when lives are exhausted.\u001B[39;00m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;124;03mThis way all states are still reachable even though lives are episodic,\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 103\u001B[0m \u001B[38;5;124;03m:return: the first observation of the environment\u001B[39;00m\n\u001B[1;32m 104\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 105\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mwas_real_done:\n\u001B[0;32m--> 106\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 107\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 108\u001B[0m \u001B[38;5;66;03m# no-op step to advance from terminal/lost life state\u001B[39;00m\n\u001B[1;32m 109\u001B[0m obs, _, _, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m0\u001B[39m)\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:154\u001B[0m, in \u001B[0;36mMaxAndSkipEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 153\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m GymObs:\n\u001B[0;32m--> 154\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:36\u001B[0m, in \u001B[0;36mNoopResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 34\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moverride_num_noops\n\u001B[1;32m 35\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 36\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43munwrapped\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnp_random\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrandint\u001B[49m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnoop_max \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m noops \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[1;32m 38\u001B[0m obs \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros(\u001B[38;5;241m0\u001B[39m)\n", + "\u001B[0;31mAttributeError\u001B[0m: 'numpy.random._generator.Generator' object has no attribute 'randint'" + ] + } + ], "source": [ "import torch as th\n", "import gym\n", @@ -72,6 +115,7 @@ ").to(device)\n", "\n", "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)\n", + "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -105,6 +149,7 @@ " reward_net,\n", " num_iterations=2,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=10,\n", diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb index 40d9c18ce..97261aeeb 100644 --- a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -64,6 +64,8 @@ " rng=rng,\n", ")\n", "\n", + "querent = preference_comparisons.PreferenceQuerent()\n", + "\n", "# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n", "# which one is better. It will then use the user's feedback to train the reward network.\n", "gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n", @@ -104,6 +106,7 @@ " reward_net,\n", " num_iterations=5,\n", " fragmenter=fragmenter,\n", + " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=100,\n", From 657a8b25df95d69c78f3986a1da23297549764dd Mon Sep 17 00:00:00 2001 From: rk1a Date: Sun, 2 Jul 2023 20:41:24 +0200 Subject: [PATCH 045/143] Fix bug --- src/imitation/algorithms/preference_comparisons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a8dab2c16..d658012d9 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -2033,6 +2033,7 @@ def __init__( # are any of the optional args that require a rng None? has_any_rng_args_none = None in ( + preference_querent, preference_gatherer, fragmenter, reward_trainer, From a0564763d1c598ae025af000b94f33e1f294256f Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 27 Jun 2023 15:58:12 +0200 Subject: [PATCH 046/143] Removes active check --- src/imitation/data/wrappers.py | 41 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 58f4846b1..0090af87f 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -47,28 +47,27 @@ def __init__( def step(self, action): obs, rew, done, info = self.env.step(action) - if self._active: - rendered_image = self.render(mode="rgb_array") - # Scale the render image - scaled_size = ( - int(self.scale_factor * rendered_image.shape[0]), - int(self.scale_factor * rendered_image.shape[1]), - ) - scaled_rendered_image = cv2.resize( - rendered_image, - scaled_size, - interpolation=cv2.INTER_AREA, + rendered_image = self.render(mode="rgb_array") + # Scale the render image + scaled_size = ( + int(self.scale_factor * rendered_image.shape[0]), + int(self.scale_factor * rendered_image.shape[1]), + ) + scaled_rendered_image = cv2.resize( + rendered_image, + scaled_size, + interpolation=cv2.INTER_AREA, + ) + # Store the render image + if not self.use_file_cache: + info["rendered_img"] = scaled_rendered_image + else: + unique_file_path = os.path.join( + self.file_cache, + str(uuid.uuid4()) + ".npy", ) - # Store the render image - if not self.use_file_cache: - info["rendered_img"] = scaled_rendered_image - else: - unique_file_path = os.path.join( - self.file_cache, - str(uuid.uuid4()) + ".npy", - ) - np.save(unique_file_path, scaled_rendered_image) - info["rendered_img"] = unique_file_path + np.save(unique_file_path, scaled_rendered_image) + info["rendered_img"] = unique_file_path # Do not show window of classic control envs if self.env.viewer is not None and self.env.viewer.window.visible: From a2ec16de18e18edb01f001045c33e66be7b3c09e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 27 Jun 2023 17:02:08 +0200 Subject: [PATCH 047/143] Adds missing hyphen --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index d658012d9..e46786b14 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -854,7 +854,7 @@ def __call__( for query_id, query in identified_queries.items(): output_file_name = os.path.join( self.video_output_dir, - f"{query_id}" + "{}.webm", + f"{query_id}" + "-{}.webm", ) write_fragment_video( query[0], From 97d4565351c197cb71a116f763bc7fcb3a4497fc Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 16 Aug 2023 15:31:16 +0200 Subject: [PATCH 048/143] Adds rng argument to querent --- src/imitation/algorithms/preference_comparisons.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e46786b14..aab7a5cbd 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -788,14 +788,17 @@ class PreferenceQuerent: def __init__( self, + rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ) -> None: """Initializes the preference querent. Args: + rng: random number generator custom_logger: Where to log to; if None (default), creates a new logger. """ self.logger = custom_logger or imit_logger.configure() + self.rng = rng def __call__( self, From a861ce8c0b428c30af7472856f0559ec1d5ed2ac Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 16 Aug 2023 15:46:01 +0200 Subject: [PATCH 049/143] Add querent default config --- src/imitation/scripts/config/train_preference_comparisons.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 6ba4b6bfd..d4e0579b9 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -50,6 +50,8 @@ def train_defaults(): gatherer_cls = preference_comparisons.SyntheticGatherer # arguments passed on to the PreferenceGatherer specified by gatherer_cls gatherer_kwargs = {} + querent_cls = preference_comparisons.PreferenceQuerent + querent_kwargs = dict() active_selection = False active_selection_oversampling = 2 uncertainty_on = "logit" From fa70a3b6eeb66495eb857bb088f19387fa8e3e9f Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 16 Aug 2023 16:19:50 +0200 Subject: [PATCH 050/143] Fix tests --- setup.py | 1 + .../algorithms/preference_comparisons.py | 4 +++- .../algorithms/test_preference_comparisons.py | 22 ++++++++++--------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index f70f1a86c..6df8182ff 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ "setuptools_scm~=7.0.5", "pre-commit>=2.20.0", "types-requests~=2.31.0.1", + "requests-mock~=1.11.0", ] + PARALLEL_REQUIRE + ATARI_REQUIRE diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index aab7a5cbd..149d9f4b0 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -880,7 +880,9 @@ def _query(self, query_id): ) -def write_fragment_video(fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath) -> None: +def write_fragment_video( + fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath +) -> None: """Write fragment video clip.""" frame_shape = get_frame_shape(fragment) video_writer = cv2.VideoWriter( diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 48688e396..10cf4f824 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -342,6 +342,7 @@ def build_preference_comparisons( @patch("IPython.display.display") def test_synchronous_human_gatherer(mock_display, mock_input): del mock_display # unused + querent = PreferenceQuerent() gatherer = preference_comparisons.SynchronousHumanGatherer( video_dir=pathlib.Path("."), ) @@ -350,7 +351,7 @@ def test_synchronous_human_gatherer(mock_display, mock_input): trajectory_pairs = [ ( types.TrajectoryWithRew( - np.array([1, 2]), + np.zeros((2, 200, 200, 3,), np.uint8), np.array([1]), np.array( [ @@ -365,9 +366,9 @@ def test_synchronous_human_gatherer(mock_display, mock_input): np.array([1.0]), ), types.TrajectoryWithRew( - np.array([1, 2]), - np.array([1]), - np.array( + np.zeros((2, 200, 200, 3,), np.uint8), + np.array([1]), # act + np.array( # info [ { "video_path": pathlib.Path( @@ -376,17 +377,19 @@ def test_synchronous_human_gatherer(mock_display, mock_input): }, ], ), - True, - np.array([1.0]), + True, # done + np.array([1.0]), # reward ), ), ] - + identified_queries = querent(trajectory_pairs) + gatherer.add(identified_queries) # this is the actual test mock_input.return_value = "1" - assert gatherer(trajectory_pairs) == np.array([1.0]) + assert gatherer()[1] == np.array([1.0]) + gatherer.add(identified_queries) mock_input.return_value = "2" - assert gatherer(trajectory_pairs) == np.array([0.0]) + assert gatherer()[1] == np.array([0.0]) @pytest.mark.parametrize( @@ -1239,7 +1242,6 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - @abc.abstractmethod def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: pass From e0b44e5c3ef25a7214c4b789180231e48f50b6d3 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 25 Aug 2023 15:48:58 +0200 Subject: [PATCH 051/143] Fix test --- src/imitation/scripts/train_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 96d35122c..c50393f37 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -98,7 +98,7 @@ def train_rl( rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) - post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] + post_wrappers = {"RolloutInfoWrapper": lambda env, idx: wrappers.RolloutInfoWrapper(env)} with environment.make_venv(post_wrappers=post_wrappers) as venv: callback_objs = [] if reward_type is not None: From 30578ca4485028e39d877e19e58f9a5bfe7a37f0 Mon Sep 17 00:00:00 2001 From: rk1a Date: Tue, 5 Sep 2023 16:07:30 +0200 Subject: [PATCH 052/143] Fix errors related to post_wrappers --- docs/algorithms/airl.rst | 2 +- docs/algorithms/gail.rst | 2 +- docs/tutorials/3_train_gail.ipynb | 4 ++-- docs/tutorials/4_train_airl.ipynb | 4 ++-- ...ence_comparisons_with_synchronous_human_feedback.ipynb | 8 ++++---- docs/tutorials/8_train_custom_env.ipynb | 4 ++-- src/imitation/scripts/eval_policy.py | 4 ++-- src/imitation/scripts/ingredients/environment.py | 2 +- src/imitation/testing/expert_trajectories.py | 2 +- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/algorithms/airl.rst b/docs/algorithms/airl.rst index 18f9a6fb1..1fca64089 100644 --- a/docs/algorithms/airl.rst +++ b/docs/algorithms/airl.rst @@ -48,7 +48,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl` "seals/CartPole-v0", rng=rng, n_envs=5, - post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], + post_wrappers={"RolloutInfoWrapper": lambda env, _: RolloutInfoWrapper(env)}, ), rollout.make_sample_until(min_timesteps=None, min_episodes=60), rng=rng, diff --git a/docs/algorithms/gail.rst b/docs/algorithms/gail.rst index d21829ec8..48284c0d0 100644 --- a/docs/algorithms/gail.rst +++ b/docs/algorithms/gail.rst @@ -44,7 +44,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail` make_vec_env( "seals/CartPole-v0", n_envs=5, - post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], + post_wrappers={"RolloutInfoWrapper": lambda env, _: RolloutInfoWrapper(env)}, rng=rng, ), rollout.make_sample_until(min_timesteps=None, min_episodes=60), diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb index 5cdeca671..12005f895 100644 --- a/docs/tutorials/3_train_gail.ipynb +++ b/docs/tutorials/3_train_gail.ipynb @@ -69,7 +69,7 @@ " make_vec_env(\n", " \"seals/CartPole-v0\",\n", " n_envs=5,\n", - " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)},\n", " rng=rng,\n", " ),\n", " rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n", @@ -187,4 +187,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb index 1067ae2df..762bf4ce4 100644 --- a/docs/tutorials/4_train_airl.ipynb +++ b/docs/tutorials/4_train_airl.ipynb @@ -66,7 +66,7 @@ " make_vec_env(\n", " \"seals/CartPole-v0\",\n", " n_envs=5,\n", - " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)},\n", " rng=rng,\n", " ),\n", " rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n", @@ -181,4 +181,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb index 97261aeeb..e10f824f7 100644 --- a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -50,9 +50,9 @@ "venv = make_vec_env(\n", " \"Pendulum-v1\",\n", " rng=rng,\n", - " post_wrappers=[\n", - " video_wrapper.video_wrapper_factory(pathlib.Path(video_dir), single_video=False)\n", - " ],\n", + " post_wrappers={\n", + " \"VideoWrapper\": video_wrapper.video_wrapper_factory(pathlib.Path(video_dir), single_video=False)\n", + " },\n", ")\n", "\n", "reward_net = BasicRewardNet(\n", @@ -241,7 +241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/docs/tutorials/8_train_custom_env.ipynb b/docs/tutorials/8_train_custom_env.ipynb index b81d603c1..8e49bfcc5 100644 --- a/docs/tutorials/8_train_custom_env.ipynb +++ b/docs/tutorials/8_train_custom_env.ipynb @@ -138,12 +138,12 @@ "\n", "# Create a vectorized environment for training with `imitation`\n", "\n", - "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`\n", + "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)}`\n", "venv = make_vec_env(\n", " \"custom/ObservationMatching-v0\",\n", " rng=np.random.default_rng(),\n", " n_envs=4,\n", - " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)},\n", ")\n", "\n", "\n", diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 1a2659c61..80de048b5 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -83,9 +83,9 @@ def eval_policy( log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) post_wrappers = ( - [video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)] + {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} if videos - else None + else {} ) with environment.make_venv(post_wrappers=post_wrappers) as venv: if render: diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py index b197b6a86..616dd8d8c 100644 --- a/src/imitation/scripts/ingredients/environment.py +++ b/src/imitation/scripts/ingredients/environment.py @@ -123,7 +123,7 @@ def make_rollout_venv( max_episode_steps=max_episode_steps, log_dir=None, env_make_kwargs=env_make_kwargs, - post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], + post_wrappers={"RolloutInfoWrapper": lambda env, i: wrappers.RolloutInfoWrapper(env)}, ) try: yield venv diff --git a/src/imitation/testing/expert_trajectories.py b/src/imitation/testing/expert_trajectories.py index dc640b8c1..ff87b1c6d 100644 --- a/src/imitation/testing/expert_trajectories.py +++ b/src/imitation/testing/expert_trajectories.py @@ -38,7 +38,7 @@ def generate_expert_trajectories( """ env = util.make_vec_env( env_id, - post_wrappers=[lambda e, _: wrappers.RolloutInfoWrapper(e)], + post_wrappers={"RolloutInfoWrapper": lambda e, _: wrappers.RolloutInfoWrapper(e)}, rng=rng, ) try: From d15cab1aab0fe98c56f0065eea4a7561db74dc57 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 14 Sep 2023 17:31:02 +0200 Subject: [PATCH 053/143] Fix tests --- src/imitation/scripts/ingredients/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py index 616dd8d8c..b197b6a86 100644 --- a/src/imitation/scripts/ingredients/environment.py +++ b/src/imitation/scripts/ingredients/environment.py @@ -123,7 +123,7 @@ def make_rollout_venv( max_episode_steps=max_episode_steps, log_dir=None, env_make_kwargs=env_make_kwargs, - post_wrappers={"RolloutInfoWrapper": lambda env, i: wrappers.RolloutInfoWrapper(env)}, + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], ) try: yield venv From f0728cddb7acaa7aa5674e24f31f753460bfdeb4 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 19 Sep 2023 16:53:21 +0200 Subject: [PATCH 054/143] Fix test --- src/imitation/scripts/ingredients/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py index 616dd8d8c..b197b6a86 100644 --- a/src/imitation/scripts/ingredients/environment.py +++ b/src/imitation/scripts/ingredients/environment.py @@ -123,7 +123,7 @@ def make_rollout_venv( max_episode_steps=max_episode_steps, log_dir=None, env_make_kwargs=env_make_kwargs, - post_wrappers={"RolloutInfoWrapper": lambda env, i: wrappers.RolloutInfoWrapper(env)}, + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], ) try: yield venv From 4dbf20c6d42286377901dd538cd6c5001225fa40 Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 6 Oct 2023 14:38:03 +0200 Subject: [PATCH 055/143] Fix test --- src/imitation/util/video_wrapper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 1babbd227..e250da5e3 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -52,7 +52,6 @@ def __init__( self.video_recorder = None self.single_video = single_video self.delete_on_close = delete_on_close - self.current_video_path: Optional[pathlib.Path] = None self.directory = directory self.directory.mkdir(parents=True, exist_ok=True) @@ -77,7 +76,6 @@ def _reset_video_recorder(self) -> None: base_path=str(self.directory / f"video.{self.episode_id:06}"), metadata={"episode_id": self.episode_id}, ) - self.current_video_path = pathlib.Path(self.video_recorder.path) def reset( self, From ef1134a677205937ed24eccf9c676aa434a26cc4 Mon Sep 17 00:00:00 2001 From: rk1a Date: Fri, 6 Oct 2023 14:40:33 +0200 Subject: [PATCH 056/143] Refactors gatherer and adapts some tests --- .../algorithms/preference_comparisons.py | 35 ++++++++----------- .../algorithms/test_preference_comparisons.py | 16 ++++----- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 149d9f4b0..9dee8406f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -961,11 +961,12 @@ def __init__( # pass in a seed in training scripts (without worrying about whether # the PreferenceGatherer we use needs one). del rng + self.querent = PreferenceQuerent() self.logger = custom_logger or imit_logger.configure() self.pending_queries: Dict = {} @abc.abstractmethod - def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Gathers the probabilities that fragment 1 is preferred in `queries`. Returns: @@ -980,7 +981,11 @@ def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: probabilities. """ # noqa: DAR202 - def add(self, new_queries: Dict[str, TrajectoryWithRewPair]) -> None: + def query(self, queries: Sequence[TrajectoryWithRewPair]) -> None: + identified_queries = self.querent(queries) + self._add(identified_queries) + + def _add(self, new_queries: Dict[str, TrajectoryWithRewPair]) -> None: """Adds queries to pending queries. Args: @@ -1036,7 +1041,7 @@ def __init__( if self.sample and self.rng is None: raise ValueError("If `sample` is True, then `rng` must be provided.") - def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Computes probability fragment 1 is preferred over fragment 2.""" returns1, returns2 = self._reward_sums(self.pending_queries.values()) @@ -1111,7 +1116,7 @@ def __init__( self.video_height = video_height self.frames_per_second = frames_per_second - def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Displays each pair of fragments and asks for a preference. It iteratively requests user feedback for each pair of fragments. If in the @@ -1314,11 +1319,12 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(rng, custom_logger) + self.querent = PrefCollectQuerent(pref_collect_address, "videos") self.query_endpoint = pref_collect_address + "/preferences/query/" self.pending_queries = {} self.wait_for_user = wait_for_user - def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: # TODO: create user-independent (automated) waiting policy if self.wait_for_user: @@ -1940,7 +1946,6 @@ def __init__( reward_model: reward_nets.RewardNet, num_iterations: int, fragmenter: Optional[Fragmenter] = None, - preference_querent: Optional[PreferenceQuerent] = None, preference_gatherer: Optional[PreferenceGatherer] = None, reward_trainer: Optional[RewardTrainer] = None, comparison_queue_size: Optional[int] = None, @@ -1969,7 +1974,6 @@ def __init__( for which preferences will be gathered. These fragments could be random, or they could be selected more deliberately (active learning). Default is a random fragmenter. - preference_querent: queries preferences between trajectory fragments. preference_gatherer: gathers preferences between trajectory fragments. Default (and currently the only option) is to use synthetic preferences based on ground-truth rewards. Human preferences could be implemented @@ -2038,7 +2042,6 @@ def __init__( # are any of the optional args that require a rng None? has_any_rng_args_none = None in ( - preference_querent, preference_gatherer, fragmenter, reward_trainer, @@ -2048,14 +2051,14 @@ def __init__( if self.rng is None and has_any_rng_args_none: raise ValueError( "If you don't provide a random state, you must provide your own " - "seeded fragmenter, preference gatherer, preference querent, " + "seeded fragmenter, preference gatherer, " "and reward_trainer. " "You can initialize a random state with `np.random.default_rng(seed)`.", ) elif self.rng is not None and not has_any_rng_args_none: raise ValueError( "If you provide your own fragmenter, preference gatherer, " - "preference querent, and reward trainer, " + "and reward trainer, " "you don't need to provide a random state.", ) @@ -2086,13 +2089,6 @@ def __init__( rng=self.rng, ) self.fragmenter.logger = self.logger - if preference_querent: - self.preference_querent = preference_querent - else: - assert self.rng is not None - self.preference_querent = PreferenceQuerent( - custom_logger=self.logger, - ) if preference_gatherer: self.preference_gatherer = preference_gatherer else: @@ -2172,13 +2168,12 @@ def train( queries = self.fragmenter(trajectories, self.fragment_length, num_queries) - identified_queries = self.preference_querent(queries) - self.preference_gatherer.add(identified_queries) + self.preference_gatherer.query(queries) with self.logger.accumulate_means("preferences"): self.logger.log("Gathering preferences") # Gather fragment pairs for which preferences have been provided - queries, preferences = self.preference_gatherer() + queries, preferences = self.preference_gatherer.gather() # Free up RAM or disk space from keeping rendered images remove_rendered_images(trajectories) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 9f62a2d9a..ceea95c0e 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1242,26 +1242,26 @@ def test_sends_put_request_for_each_query(requests_mock): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - def __call__(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: pass def test_adds_queries_to_pending_queries(): gatherer = ConcretePreferenceGatherer() - query_id = "id" - queries = {query_id: Mock()} + query = Mock() + queries = [query] - gatherer.add(new_queries=queries) - assert query_id in list(gatherer.pending_queries.keys()) + gatherer.query(queries) + assert query in list(gatherer.pending_queries.values()) def test_clears_pending_queries(trajectory_with_rew): gatherer = SyntheticGatherer(sample=False) - queries = {"id": (trajectory_with_rew, trajectory_with_rew)} - gatherer.add(new_queries=queries) + queries = [(trajectory_with_rew, trajectory_with_rew)] + gatherer.query(queries) - gatherer() + gatherer.gather() assert len(gatherer.pending_queries) == 0 From 046183f90caf1dd71c0fc0f51ee6fd9f115a58cc Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 2 Nov 2023 16:55:50 +0100 Subject: [PATCH 057/143] Fixes gym import bug --- src/imitation/scripts/ingredients/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/scripts/ingredients/environment.py b/src/imitation/scripts/ingredients/environment.py index b197b6a86..0c1f52a30 100644 --- a/src/imitation/scripts/ingredients/environment.py +++ b/src/imitation/scripts/ingredients/environment.py @@ -3,7 +3,7 @@ import functools from typing import Any, Callable, Generator, Mapping -import gym +import gymnasium as gym import numpy as np import sacred from stable_baselines3.common import vec_env From d8c5d98547769404b5151c7a748358b035eda10e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 2 Nov 2023 17:01:38 +0100 Subject: [PATCH 058/143] Fixes bug --- src/imitation/algorithms/preference_comparisons.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 9dee8406f..b2adfb0f6 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1043,10 +1043,13 @@ def __init__( def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Computes probability fragment 1 is preferred over fragment 2.""" - returns1, returns2 = self._reward_sums(self.pending_queries.values()) + queries = list(self.pending_queries.values()) + self.pending_queries.clear() # Clear pending queries because the oracle will have answered all + + returns1, returns2 = self._reward_sums(queries) if self.temperature == 0: - return (np.sign(returns1 - returns2) + 1) / 2 + return queries, (np.sign(returns1 - returns2) + 1) / 2 returns1 /= self.temperature returns2 /= self.temperature @@ -1065,13 +1068,10 @@ def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: ).mean() self.logger.record("entropy", entropy) - # Clear pending queries because the oracle has answered all - queries = list(self.pending_queries.values()) - self.pending_queries.clear() - if self.sample: assert self.rng is not None return queries, self.rng.binomial(n=1, p=choice_probs).astype(np.float32) + return queries, choice_probs def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: From d341090b88d066d5332a0579f9cd8335d3d1bf76 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 2 Nov 2023 17:39:52 +0100 Subject: [PATCH 059/143] Fix tests --- .../algorithms/test_preference_comparisons.py | 63 ++++++------------- 1 file changed, 20 insertions(+), 43 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index ceea95c0e..bfdbef848 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -266,27 +266,19 @@ def test_preference_comparisons_raises( rng=rng, ) - querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) no_rng_msg = ( ".*don't provide.*random state.*provide.*fragmenter" - ".*preference gatherer.*preference querent.*reward_trainer.*" + ".*preference gatherer.*reward_trainer.*" ) - def build_preference_comparisons( - gatherer, - querent, - reward_trainer, - fragmenter, - rng, - ): + def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): preference_comparisons.PreferenceComparisons( agent_trainer, reward_net, num_iterations=2, transition_oversampling=2, reward_trainer=reward_trainer, - preference_querent=querent, preference_gatherer=gatherer, fragmenter=fragmenter, custom_logger=custom_logger, @@ -295,47 +287,31 @@ def build_preference_comparisons( ) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(gatherer, None, None, None, rng=None) + build_preference_comparisons(gatherer, None, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(None, None, reward_trainer, None, rng=None) + build_preference_comparisons(None, reward_trainer, None, rng=None) with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(None, None, None, random_fragmenter, rng=None) - - with pytest.raises(ValueError, match=no_rng_msg): - build_preference_comparisons(None, querent, None, None, rng=None) + build_preference_comparisons(None, None, random_fragmenter, rng=None) # This should not raise - build_preference_comparisons( - gatherer, - querent, - reward_trainer, - random_fragmenter, - rng=None, - ) + build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=None) # if providing fragmenter, preference gatherer, reward trainer, does not need rng. with_rng_msg = ( - "provide.*fragmenter.*preference gatherer.*preference querent.*reward trainer" + "provide.*fragmenter.*preference gatherer.*reward trainer" ".*don't need.*random state.*" ) with pytest.raises(ValueError, match=with_rng_msg): - build_preference_comparisons( - gatherer, - querent, - reward_trainer, - random_fragmenter, - rng=rng, - ) + build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=rng) # This should not raise - build_preference_comparisons(None, None, None, None, rng=rng) - build_preference_comparisons(gatherer, None, None, None, rng=rng) - build_preference_comparisons(None, querent, None, None, rng=rng) - build_preference_comparisons(None, None, reward_trainer, None, rng=rng) - build_preference_comparisons(None, None, None, random_fragmenter, rng=rng) + build_preference_comparisons(None, None, None, rng=rng) + build_preference_comparisons(gatherer, None, None, rng=rng) + build_preference_comparisons(None, reward_trainer, None, rng=rng) + build_preference_comparisons(None, None, random_fragmenter, rng=rng) @patch("builtins.input") @@ -382,14 +358,15 @@ def test_synchronous_human_gatherer(mock_display, mock_input): ), ), ] - identified_queries = querent(trajectory_pairs) - gatherer.add(identified_queries) + gatherer.query(trajectory_pairs) + # this is the actual test mock_input.return_value = "1" - assert gatherer()[1] == np.array([1.0]) - gatherer.add(identified_queries) + assert gatherer.gather()[1] == np.array([1.0]) + + gatherer.query(trajectory_pairs) mock_input.return_value = "2" - assert gatherer()[1] == np.array([0.0]) + assert gatherer.gather()[1] == np.array([0.0]) @pytest.mark.parametrize( @@ -1338,7 +1315,7 @@ def test_gathers_valid_preference(): query = Mock() gatherer.pending_queries = {"1234": query} - gathered_queries, gathered_preferences = gatherer() + gathered_queries, gathered_preferences = gatherer.gather() assert gathered_preferences[0] == preference assert gathered_queries[0] == query @@ -1353,7 +1330,7 @@ def test_ignores_incomparable_answer(): gatherer._gather_preference = MagicMock(return_value=-1.0) gatherer.pending_queries = {"1234": Mock()} - gathered_queries, gathered_preferences = gatherer() + gathered_queries, gathered_preferences = gatherer.gather() assert len(gathered_preferences) == 0 assert len(gathered_queries) == 0 From f8c42bc36bc2468dcc8fc87d8461394c91139751 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 2 Nov 2023 17:39:56 +0100 Subject: [PATCH 060/143] Fix tests --- .../algorithms/test_preference_comparisons.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index ceea95c0e..9974d3e75 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -577,7 +577,6 @@ def test_gradient_accumulation( minibatch_size = 3 num_trajectories = 5 - preference_querent = preference_comparisons.PreferenceQuerent(rng=rng) preference_gatherer = preference_comparisons.SyntheticGatherer( custom_logger=custom_logger, rng=rng, @@ -585,9 +584,8 @@ def test_gradient_accumulation( dataset = preference_comparisons.PreferenceDataset() trajectory = agent_trainer.sample(num_trajectories) fragments = random_fragmenter(trajectory, 1, num_trajectories) - identified_queries = preference_querent(fragments) - preference_gatherer.add(identified_queries) - fragments, preferences = preference_gatherer() + preference_gatherer.query(fragments) + fragments, preferences = preference_gatherer.gather() dataset.push(fragments, preferences) seed = rng.integers(2**32) @@ -625,18 +623,16 @@ def test_synthetic_gatherer_deterministic( random_fragmenter, rng, ): - preference_querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer( temperature=0, rng=rng, ) trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) - identified_queries = preference_querent(fragments) - gatherer.add(identified_queries) - _, preferences1 = gatherer() - gatherer.add(identified_queries) - _, preferences2 = gatherer() + gatherer.query(fragments) + _, preferences1 = gatherer.gather() + gatherer.query(fragments) + _, preferences2 = gatherer.gather() assert np.all(preferences1 == preferences2) @@ -711,13 +707,11 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): dataset = preference_comparisons.PreferenceDataset(max_size=5) trajectories = agent_trainer.sample(10) - querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) for i in range(6): fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=1) - identified_queries = querent(fragments) - gatherer.add(identified_queries) - fragments, preferences = gatherer() + gatherer.query(fragments) + fragments, preferences = gatherer.gather() assert len(dataset) == min(i, 5) dataset.push(fragments, preferences) assert len(dataset) == min(i + 1, 5) @@ -735,11 +729,9 @@ def test_store_and_load_preference_dataset( dataset = preference_comparisons.PreferenceDataset() trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) - querent = preference_comparisons.PreferenceQuerent(rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - identified_queries = querent(fragments) - gatherer.add(identified_queries) - fragments, preferences = gatherer() + gatherer.query(fragments) + fragments, preferences = gatherer.gather() dataset.push(fragments, preferences) path = tmp_path / "preferences.pkl" @@ -1309,7 +1301,7 @@ def test_keeps_pending_query_for_unanswered_query(): gatherer.pending_queries = {"1234": Mock()} pending_queries_pre = gatherer.pending_queries.copy() - gatherer() + gatherer.gather() assert pending_queries_pre == gatherer.pending_queries @@ -1323,7 +1315,7 @@ def test_deletes_pending_query_for_answered_query(): gatherer._gather_preference = MagicMock(return_value=preference) gatherer.pending_queries = {"1234": Mock()} - gatherer() + gatherer.gather() assert len(gatherer.pending_queries) == 0 From b74bb5a3135e95c09bad44ba9e6c227cfb638725 Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 6 Nov 2023 16:25:39 +0100 Subject: [PATCH 061/143] Fix tests --- src/imitation/scripts/eval_policy.py | 2 +- src/imitation/scripts/train_preference_comparisons.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 8aa765d47..ee2da8f9e 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -94,7 +94,7 @@ def eval_policy( """ log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) - post_wrappers = {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} if videos else None + post_wrappers = {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} if videos else {} render_mode = "rgb_array" if videos else None with environment.make_venv( # type: ignore[wrong-keyword-args] post_wrappers=post_wrappers, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 1e49c0c0b..e06df02de 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -77,8 +77,6 @@ def train_preference_comparisons( reward_trainer_kwargs: Mapping[str, Any], gatherer_cls: Type[preference_comparisons.PreferenceGatherer], gatherer_kwargs: Mapping[str, Any], - querent_cls: Type[preference_comparisons.PreferenceQuerent], - querent_kwargs: Mapping[str, Any], active_selection: bool, active_selection_oversampling: int, uncertainty_on: str, @@ -124,8 +122,6 @@ def train_preference_comparisons( reward_trainer_kwargs: passed to BasicRewardTrainer or EnsembleRewardTrainer gatherer_cls: type of PreferenceGatherer to use (defaults to SyntheticGatherer) gatherer_kwargs: passed to the PreferenceGatherer specified by gatherer_cls - querent_cls: type of PreferenceQuerent to use (defaults to PreferenceQuerent) - querent_kwargs: passed to the PreferenceQuerent specified by querent_cls active_selection: use active selection fragmenter instead of random fragmenter active_selection_oversampling: factor by which to oversample random fragments from the base fragmenter of active selection. @@ -243,10 +239,6 @@ def train_preference_comparisons( rng=_rnd, custom_logger=custom_logger, ) - querent = querent_cls( - **querent_kwargs, - custom_logger=custom_logger, - ) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -263,7 +255,6 @@ def train_preference_comparisons( num_iterations=num_iterations, fragmenter=fragmenter, preference_gatherer=gatherer, - preference_querent=querent, reward_trainer=reward_trainer, comparison_queue_size=comparison_queue_size, fragment_length=fragment_length, From ba20500dd98322ab2739b8a6888829caa43a08e3 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 10 Nov 2023 12:10:50 +0100 Subject: [PATCH 062/143] Adapt render image infor wrapper to gymnasium --- src/imitation/data/wrappers.py | 13 ++++++------ tests/data/test_wrappers.py | 36 ++++++++++++++++++++++++++-------- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index aaa27de58..825a72f8b 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -38,6 +38,9 @@ def __init__( scale_factor: scales rendered images to be stored. use_file_cache: whether to save rendered images to disk. """ + assert env.render_mode == "rgb_array", \ + f'The environment must be in render mode "rgb_array" in order to use this wrapper but render_mode is ' \ + f'"{env.render_mode}".' super().__init__(env) self.scale_factor = scale_factor self.use_file_cache = use_file_cache @@ -45,9 +48,9 @@ def __init__( self.file_cache = tempfile.mkdtemp("imitation_RenderImageInfoWrapper") def step(self, action): - obs, rew, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) - rendered_image = self.render(mode="rgb_array") + rendered_image = self.render() # Scale the render image scaled_size = ( int(self.scale_factor * rendered_image.shape[0]), @@ -69,11 +72,7 @@ def step(self, action): np.save(unique_file_path, scaled_rendered_image) info["rendered_img"] = unique_file_path - # Do not show window of classic control envs - if self.env.viewer is not None and self.env.viewer.window.visible: - self.env.viewer.window.set_visible(False) - - return obs, rew, done, info + return observation, reward, terminated, truncated, info def close(self) -> None: if self.use_file_cache: diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 33677c68f..fcc85c760 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env import DummyVecEnv from imitation.data import types -from imitation.data.wrappers import BufferingWrapper +from imitation.data.wrappers import BufferingWrapper, RenderImageInfoWrapper class _CountingEnv(gym.Env): # pragma: no cover @@ -78,8 +78,8 @@ def step(self, action): def _make_buffering_venv( - Env: Type[gym.Env], - error_on_premature_reset: bool, + Env: Type[gym.Env], + error_on_premature_reset: bool, ) -> BufferingWrapper: venv = DummyVecEnv([Env] * 2) wrapped_venv = BufferingWrapper(venv, error_on_premature_reset) @@ -95,7 +95,7 @@ def _assert_equal_scrambled_vectors(a: np.ndarray, b: np.ndarray) -> None: def _join_transitions( - trans_list: Sequence[types.TransitionsWithRew], + trans_list: Sequence[types.TransitionsWithRew], ) -> types.TransitionsWithRew: def concat(x): return np.concatenate(list(x)) @@ -121,10 +121,10 @@ def concat(x): @pytest.mark.parametrize("n_steps", [1, 2, 20, 21]) @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)]) def test_pop( - Env: Type[gym.Env], - episode_lengths: Sequence[int], - n_steps: int, - extra_pop_timesteps: Sequence[int], + Env: Type[gym.Env], + episode_lengths: Sequence[int], + n_steps: int, + extra_pop_timesteps: Sequence[int], ) -> None: """Check pop_transitions() results for BufferWrapper. @@ -278,3 +278,23 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): assert venv.n_transitions == 0 with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() + + +def test_writes_to_info(): + env = gym.make("CartPole-v0", render_mode="rgb_array") + wrapped_env = RenderImageInfoWrapper(env) + wrapped_env.reset() + _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) + assert "rendered_img" in info + + +def test_raises_assertion_error_if_env_not_in_correct_render_mode(): + wrong_mode = "human" + env = gym.make("CartPole-v0", render_mode=wrong_mode) + + with pytest.raises( + AssertionError, + match='The environment must be in render mode "rgb_array" in order to use this wrapper ' + f'but render_mode is "{wrong_mode}"' + ): + RenderImageInfoWrapper(env) From 92bbf95306f86d04944282e04a8e46cc1e7e3863 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 14 Nov 2023 17:41:28 +0100 Subject: [PATCH 063/143] Add querent kwargs and change config accordingly --- .../algorithms/preference_comparisons.py | 8 ++++++-- .../config/train_preference_comparisons.py | 16 ++++------------ .../scripts/train_preference_comparisons.py | 2 -- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 6108f444e..bf3d760ca 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -949,6 +949,7 @@ def __init__( self, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None ) -> None: """Initializes the preference gatherer. @@ -961,7 +962,8 @@ def __init__( # pass in a seed in training scripts (without worrying about whether # the PreferenceGatherer we use needs one). del rng - self.querent = PreferenceQuerent() + querent_kwargs = querent_kwargs or {} + self.querent = PreferenceQuerent(**querent_kwargs) self.logger = custom_logger or imit_logger.configure() self.pending_queries: Dict = {} @@ -1309,6 +1311,7 @@ def __init__( wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, ) -> None: """Initializes the preference gatherer. @@ -1319,7 +1322,8 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(rng, custom_logger) - self.querent = PrefCollectQuerent(pref_collect_address, "videos") + querent_kwargs = querent_kwargs if querent_kwargs else {} + self.querent = PrefCollectQuerent(pref_collect_address=pref_collect_address, **querent_kwargs) self.query_endpoint = pref_collect_address + "/preferences/query/" self.pending_queries = {} self.wait_for_user = wait_for_user diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 825b08e0a..26530cc8a 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -55,8 +55,6 @@ def train_defaults(): gatherer_cls = preference_comparisons.SyntheticGatherer # arguments passed on to the PreferenceGatherer specified by gatherer_cls gatherer_kwargs = {} - querent_cls = preference_comparisons.PreferenceQuerent - querent_kwargs = dict() active_selection = False active_selection_oversampling = 2 uncertainty_on = "logit" @@ -71,10 +69,6 @@ def train_defaults(): checkpoint_interval = 0 # Num epochs between saving (<0 disables, =0 final only) query_schedule = "hyperbolic" - # If set, save trajectory videos to this directory. Must be present if gather_cls is - # SynchronousCLIGatherer - video_log_dir = None - @train_preference_comparisons_ex.named_config def synch_human_preferences(): @@ -102,12 +96,10 @@ def human_preferences(): gatherer_kwargs = dict( pref_collect_address="http://127.0.0.1:8000", wait_for_user=True, - ) - querent_cls = preference_comparisons.PrefCollectQuerent - querent_kwargs = dict( - pref_collect_address="http://127.0.0.1:8000", - video_output_dir="../pref-collect/videofiles", - video_fps=20, + querent_kwargs=dict( + video_output_dir="../pref-collect/videofiles", + video_fps=20, + ), ) environment = dict( post_wrappers=dict( diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index e06df02de..71363daee 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -84,7 +84,6 @@ def train_preference_comparisons( allow_variable_horizon: bool, checkpoint_interval: int, query_schedule: Union[str, type_aliases.Schedule], - video_log_dir: Optional[str], _rnd: np.random.Generator, ) -> Mapping[str, Any]: """Train a reward model using preference comparisons. @@ -145,7 +144,6 @@ def train_preference_comparisons( be allocated to each iteration. "hyperbolic" and "inverse_quadratic" apportion fewer queries to later iterations when the policy is assumed to be better and more stable. - video_log_dir: If set, save videos to this directory. _rnd: Random number generator provided by Sacred. Returns: From 6b0e92eb742152eb7a1744a5d39dc737bb30bd1c Mon Sep 17 00:00:00 2001 From: rk1a Date: Tue, 21 Nov 2023 16:36:35 +0100 Subject: [PATCH 064/143] Fix default video dir --- src/imitation/scripts/config/train_preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 26530cc8a..bd96c329b 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -97,7 +97,7 @@ def human_preferences(): pref_collect_address="http://127.0.0.1:8000", wait_for_user=True, querent_kwargs=dict( - video_output_dir="../pref-collect/videofiles", + video_output_dir="../pref_collect/videofiles", video_fps=20, ), ) @@ -111,6 +111,7 @@ def human_preferences(): post_wrappers_kwargs=dict( RenderImageInfoWrapper=dict(scale_factor=0.5, use_file_cache=True), ), + env_make_kwargs=dict(render_mode="rgb_array"), ) From 85f01d4525c88260b2a90631b802af71ce229474 Mon Sep 17 00:00:00 2001 From: rk1a Date: Sun, 10 Dec 2023 14:33:12 +0100 Subject: [PATCH 065/143] Add test for video writing, fix some precommit errors --- docs/main-concepts/benchmark_summary.md | 69 +++++++ docs/tutorials/3_train_gail.ipynb | 4 +- docs/tutorials/4_train_airl.ipynb | 4 +- ...rain_preference_comparisons_with_cnn.ipynb | 42 +--- ...sons_with_synchronous_human_feedback.ipynb | 4 +- .../algorithms/preference_comparisons.py | 100 ++++----- src/imitation/data/wrappers.py | 15 +- src/imitation/scripts/eval_policy.py | 6 +- src/imitation/scripts/train_rl.py | 4 +- src/imitation/testing/expert_trajectories.py | 2 +- .../algorithms/test_preference_comparisons.py | 191 ++++++++++++------ tests/data/test_wrappers.py | 21 +- 12 files changed, 286 insertions(+), 176 deletions(-) create mode 100644 docs/main-concepts/benchmark_summary.md diff --git a/docs/main-concepts/benchmark_summary.md b/docs/main-concepts/benchmark_summary.md new file mode 100644 index 000000000..a2480f292 --- /dev/null +++ b/docs/main-concepts/benchmark_summary.md @@ -0,0 +1,69 @@ +# Benchmark Summary + +This is a summary of the sacred runs in `benchmark_runs` generated by `sacred_output_to_markdown_summary.py`. +## Scores + +The scores are normalized based on the performance of a random agent as the baseline and the expert as the maximum possible score as explained [in this blog post](https://araffin.github.io/post/rliable/): +> `(score - random_score) / (expert_score - random_score)` + +Aggregate scores and confidence intervals are computed using the [rliable library](https://agarwl.github.io/rliable/). +### AIRL +Environment | Score (mean/std)| Normalized Score (mean/std) | N + --- | --- | --- | --- +seals/Ant-v1 | 2485.889 / 533.471 | 0.981 / 0.184 | 10 +seals/HalfCheetah-v1 | 938.450 / 804.871 | 0.627 / 0.412 | 10 +seals/Hopper-v1 | 183.780 / 93.295 | 0.921 / 0.373 | 10 +seals/Swimmer-v1 | 286.699 / 7.763 | 0.970 / 0.027 | 10 +seals/Walker2d-v1 | 1154.921 / 659.564 | 0.461 / 0.264 | 10 + +#### Aggregate Normalized scores +Metric | Value | 95% CI + --- | --- | --- +Mean | 0.792 | [0.709, 0.792] +IQM | 0.918 | [0.871, 0.974] + +### BC +Environment | Score (mean/std)| Normalized Score (mean/std) | N + --- | --- | --- | --- +seals/Ant-v1 | 2090.551 / 180.340 | 0.844 / 0.062 | 10 +seals/HalfCheetah-v1 | 1516.476 / 37.487 | 0.923 / 0.019 | 10 +seals/Hopper-v1 | 204.271 / 0.609 | 1.003 / 0.002 | 10 +seals/Swimmer-v1 | 276.242 / 9.328 | 0.935 / 0.032 | 10 +seals/Walker2d-v1 | 2393.254 / 37.641 | 0.956 / 0.015 | 10 + +#### Aggregate Normalized scores +Metric | Value | 95% CI + --- | --- | --- +Mean | 0.932 | [0.922, 0.932] +IQM | 0.941 | [0.941, 0.949] + +### DAGGER +Environment | Score (mean/std)| Normalized Score (mean/std) | N + --- | --- | --- | --- +seals/Ant-v1 | 2302.527 / 108.315 | 0.957 / 0.052 | 10 +seals/HalfCheetah-v1 | 1615.004 / 8.262 | 1.017 / 0.008 | 10 +seals/Hopper-v1 | 204.789 / 1.599 | 1.011 / 0.012 | 10 +seals/Swimmer-v1 | 283.776 / 6.524 | 0.988 / 0.024 | 10 +seals/Walker2d-v1 | 2419.748 / 52.215 | 1.002 / 0.026 | 10 + +#### Aggregate Normalized scores +Metric | Value | 95% CI + --- | --- | --- +Mean | 0.995 | [0.987, 0.998] +IQM | 1.004 | [1.003, 1.008] + +### GAIL +Environment | Score (mean/std)| Normalized Score (mean/std) | N + --- | --- | --- | --- +seals/Ant-v1 | 2527.566 / 148.034 | 0.995 / 0.051 | 10 +seals/HalfCheetah-v1 | 1595.129 / 37.374 | 0.963 / 0.019 | 10 +seals/Hopper-v1 | 187.105 / 14.298 | 0.935 / 0.057 | 10 +seals/Swimmer-v1 | 249.949 / 74.295 | 0.845 / 0.254 | 10 +seals/Walker2d-v1 | 2399.196 / 89.949 | 0.959 / 0.036 | 10 + +#### Aggregate Normalized scores +Metric | Value | 95% CI + --- | --- | --- +Mean | 0.939 | [0.900, 0.944] +IQM | 0.957 | [0.965, 0.970] + diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb index 63c2649b8..a96833c10 100644 --- a/docs/tutorials/3_train_gail.ipynb +++ b/docs/tutorials/3_train_gail.ipynb @@ -37,7 +37,9 @@ " \"seals:seals/CartPole-v0\",\n", " rng=np.random.default_rng(SEED),\n", " n_envs=8,\n", - " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)}, # needed for computing rollouts later\n", + " post_wrappers={\n", + " \"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)\n", + " }, # needed for computing rollouts later\n", ")\n", "expert = load_policy(\n", " \"ppo-huggingface\",\n", diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb index b610705fe..3af6107f6 100644 --- a/docs/tutorials/4_train_airl.ipynb +++ b/docs/tutorials/4_train_airl.ipynb @@ -41,7 +41,9 @@ " \"seals:seals/CartPole-v0\",\n", " rng=np.random.default_rng(SEED),\n", " n_envs=8,\n", - " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)}, # needed for computing rollouts later\n", + " post_wrappers={\n", + " \"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)\n", + " }, # needed for computing rollouts later\n", ")\n", "expert = load_policy(\n", " \"ppo-huggingface\",\n", diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index 5a8f2ee7c..f72b9ca50 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "93187e19", "metadata": { "ExecuteTime": { @@ -30,45 +30,7 @@ "start_time": "2023-07-02T18:29:18.742766046Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n", - "[Powered by Stella]\n" - ] - }, - { - "ename": "AttributeError", - "evalue": "'numpy.random._generator.Generator' object has no attribute 'randint'", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[1], line 67\u001B[0m\n\u001B[1;32m 49\u001B[0m reward_trainer \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mBasicRewardTrainer(\n\u001B[1;32m 50\u001B[0m preference_model\u001B[38;5;241m=\u001B[39mpreference_model,\n\u001B[1;32m 51\u001B[0m loss\u001B[38;5;241m=\u001B[39mpreference_comparisons\u001B[38;5;241m.\u001B[39mCrossEntropyRewardLoss(),\n\u001B[1;32m 52\u001B[0m epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m3\u001B[39m,\n\u001B[1;32m 53\u001B[0m rng\u001B[38;5;241m=\u001B[39mrng,\n\u001B[1;32m 54\u001B[0m )\n\u001B[1;32m 56\u001B[0m agent \u001B[38;5;241m=\u001B[39m PPO(\n\u001B[1;32m 57\u001B[0m policy\u001B[38;5;241m=\u001B[39mCnnPolicy,\n\u001B[1;32m 58\u001B[0m env\u001B[38;5;241m=\u001B[39mvenv,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 64\u001B[0m n_epochs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4\u001B[39m,\n\u001B[1;32m 65\u001B[0m )\n\u001B[0;32m---> 67\u001B[0m trajectory_generator \u001B[38;5;241m=\u001B[39m \u001B[43mpreference_comparisons\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mAgentTrainer\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 68\u001B[0m \u001B[43m \u001B[49m\u001B[43malgorithm\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43magent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 69\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mreward_net\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 70\u001B[0m \u001B[43m \u001B[49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 71\u001B[0m \u001B[43m \u001B[49m\u001B[43mexploration_frac\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.0\u001B[39;49m\u001B[43m,\u001B[49m\n\u001B[1;32m 72\u001B[0m \u001B[43m \u001B[49m\u001B[43mrng\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mrng\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 73\u001B[0m \u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m pref_comparisons \u001B[38;5;241m=\u001B[39m preference_comparisons\u001B[38;5;241m.\u001B[39mPreferenceComparisons(\n\u001B[1;32m 76\u001B[0m trajectory_generator,\n\u001B[1;32m 77\u001B[0m reward_net,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 87\u001B[0m initial_epoch_multiplier\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m,\n\u001B[1;32m 88\u001B[0m )\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/algorithms/preference_comparisons.py:187\u001B[0m, in \u001B[0;36mAgentTrainer.__init__\u001B[0;34m(self, algorithm, reward_fn, venv, rng, exploration_frac, switch_prob, random_prob, custom_logger)\u001B[0m\n\u001B[1;32m 177\u001B[0m \u001B[38;5;66;03m# The BufferingWrapper records all trajectories, so we can return\u001B[39;00m\n\u001B[1;32m 178\u001B[0m \u001B[38;5;66;03m# them after training. This should come first (before the wrapper that\u001B[39;00m\n\u001B[1;32m 179\u001B[0m \u001B[38;5;66;03m# changes the reward function), so that we return the original environment\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 184\u001B[0m \u001B[38;5;66;03m# SB3 may move the image-channel dimension in the observation space, making\u001B[39;00m\n\u001B[1;32m 185\u001B[0m \u001B[38;5;66;03m# `algorithm.get_env()` not match with `reward_fn`.\u001B[39;00m\n\u001B[1;32m 186\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuffering_wrapper \u001B[38;5;241m=\u001B[39m wrappers\u001B[38;5;241m.\u001B[39mBufferingWrapper(venv)\n\u001B[0;32m--> 187\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper \u001B[38;5;241m=\u001B[39m \u001B[43mreward_wrapper\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mRewardVecEnvWrapper\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 188\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbuffering_wrapper\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 189\u001B[0m \u001B[43m \u001B[49m\u001B[43mreward_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreward_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 190\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 192\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlog_callback \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreward_venv_wrapper\u001B[38;5;241m.\u001B[39mmake_log_callback()\n\u001B[1;32m 194\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malgorithm\u001B[38;5;241m.\u001B[39mset_env(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mvenv)\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:73\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.__init__\u001B[0;34m(self, venv, reward_fn, ep_history)\u001B[0m\n\u001B[1;32m 71\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_actions \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m---> 73\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/rewards/reward_wrapper.py:84\u001B[0m, in \u001B[0;36mRewardVecEnvWrapper.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 83\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m---> 84\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 85\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_old_obs\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:126\u001B[0m, in \u001B[0;36mBufferingWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 124\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_init_reset \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m 125\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_transitions \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m--> 126\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 127\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_traj_accum \u001B[38;5;241m=\u001B[39m rollout\u001B[38;5;241m.\u001B[39mTrajectoryAccumulator()\n\u001B[1;32m 128\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, ob \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(obs):\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py:58\u001B[0m, in \u001B[0;36mVecFrameStack.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 54\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Union[np\u001B[38;5;241m.\u001B[39mndarray, Dict[\u001B[38;5;28mstr\u001B[39m, np\u001B[38;5;241m.\u001B[39mndarray]]:\n\u001B[1;32m 55\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 56\u001B[0m \u001B[38;5;124;03m Reset all environments\u001B[39;00m\n\u001B[1;32m 57\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[0;32m---> 58\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvenv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;66;03m# pytype:disable=annotation-type-mismatch\u001B[39;00m\n\u001B[1;32m 60\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mstackedobs\u001B[38;5;241m.\u001B[39mreset(observation)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m observation\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:74\u001B[0m, in \u001B[0;36mDummyVecEnv.reset\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 72\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m VecEnvObs:\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m env_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_envs):\n\u001B[0;32m---> 74\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menvs\u001B[49m\u001B[43m[\u001B[49m\u001B[43menv_idx\u001B[49m\u001B[43m]\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_save_obs(env_idx, obs)\n\u001B[1;32m 76\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs_from_buf()\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/monitor.py:83\u001B[0m, in \u001B[0;36mMonitor.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 81\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mExpected you to pass keyword argument \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mkey\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m into reset\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 82\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcurrent_reset_info[key] \u001B[38;5;241m=\u001B[39m value\n\u001B[0;32m---> 83\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/src/imitation/data/wrappers.py:261\u001B[0m, in \u001B[0;36mRolloutInfoWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 260\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 261\u001B[0m new_obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 262\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_obs \u001B[38;5;241m=\u001B[39m [new_obs]\n\u001B[1;32m 263\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_rews \u001B[38;5;241m=\u001B[39m []\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/wrappers/time_limit.py:27\u001B[0m, in \u001B[0;36mTimeLimit.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 25\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m 26\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_elapsed_steps \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m---> 27\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:292\u001B[0m, in \u001B[0;36mWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 291\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 292\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:333\u001B[0m, in \u001B[0;36mRewardWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 332\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 333\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/gym/core.py:319\u001B[0m, in \u001B[0;36mObservationWrapper.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 318\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 319\u001B[0m observation \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 320\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mobservation(observation)\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:59\u001B[0m, in \u001B[0;36mFireResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m np\u001B[38;5;241m.\u001B[39mndarray:\n\u001B[0;32m---> 59\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 60\u001B[0m obs, _, done, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 61\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m done:\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:106\u001B[0m, in \u001B[0;36mEpisodicLifeEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 97\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 98\u001B[0m \u001B[38;5;124;03mCalls the Gym environment reset, only when lives are exhausted.\u001B[39;00m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;124;03mThis way all states are still reachable even though lives are episodic,\u001B[39;00m\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 103\u001B[0m \u001B[38;5;124;03m:return: the first observation of the environment\u001B[39;00m\n\u001B[1;32m 104\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 105\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mwas_real_done:\n\u001B[0;32m--> 106\u001B[0m obs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 107\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 108\u001B[0m \u001B[38;5;66;03m# no-op step to advance from terminal/lost life state\u001B[39;00m\n\u001B[1;32m 109\u001B[0m obs, _, _, _ \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39menv\u001B[38;5;241m.\u001B[39mstep(\u001B[38;5;241m0\u001B[39m)\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:154\u001B[0m, in \u001B[0;36mMaxAndSkipEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 153\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mreset\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m GymObs:\n\u001B[0;32m--> 154\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43menv\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreset\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/PycharmProjects/rk1a-imitation/venv/lib/python3.9/site-packages/stable_baselines3/common/atari_wrappers.py:36\u001B[0m, in \u001B[0;36mNoopResetEnv.reset\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m 34\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moverride_num_noops\n\u001B[1;32m 35\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 36\u001B[0m noops \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43munwrapped\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnp_random\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrandint\u001B[49m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnoop_max \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m noops \u001B[38;5;241m>\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[1;32m 38\u001B[0m obs \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros(\u001B[38;5;241m0\u001B[39m)\n", - "\u001B[0;31mAttributeError\u001B[0m: 'numpy.random._generator.Generator' object has no attribute 'randint'" - ] - } - ], + "outputs": [], "source": [ "import torch as th\n", "import gymnasium as gym\n", diff --git a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb index e10f824f7..5f768784d 100644 --- a/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb +++ b/docs/tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback.ipynb @@ -51,7 +51,9 @@ " \"Pendulum-v1\",\n", " rng=rng,\n", " post_wrappers={\n", - " \"VideoWrapper\": video_wrapper.video_wrapper_factory(pathlib.Path(video_dir), single_video=False)\n", + " \"VideoWrapper\": video_wrapper.video_wrapper_factory(\n", + " pathlib.Path(video_dir), single_video=False\n", + " )\n", " },\n", ")\n", "\n", diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index bf3d760ca..94977bbc1 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -3,6 +3,7 @@ Trains a reward model and optionally a policy based on preferences between trajectory fragments. """ + import abc import math import os @@ -31,6 +32,7 @@ import numpy as np import requests import torch as th +from moviepy.editor import ImageSequenceClip from scipy import special from stable_baselines3.common import base_class, type_aliases, utils, vec_env from torch import nn @@ -838,7 +840,7 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(custom_logger) + super().__init__(custom_logger=custom_logger) self.rng = rng self.query_endpoint = pref_collect_address + "/preferences/query/" self.video_output_dir = video_output_dir @@ -880,66 +882,48 @@ def _query(self, query_id): ) +def add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: + if frames.shape[-1] < 3: + missing_channels = 3 - frames.shape[-1] + frames = np.concatenate( + [frames] + missing_channels * [frames[..., -1][..., None]], + axis=-1, + ) + return frames + + def write_fragment_video( - fragment: TrajectoryWithRew, frames_per_second: int, output_path: AnyPath + fragment: TrajectoryWithRew, + frames_per_second: int, + output_path: AnyPath, + progress_logger: bool = True, ) -> None: """Write fragment video clip.""" - frame_shape = get_frame_shape(fragment) - video_writer = cv2.VideoWriter( - output_path, - cv2.VideoWriter_fourcc(*"VP90"), - frames_per_second, - frame_shape, - ) - - # Make videos from rendered observations if available - frames: np.ndarray + frames_list: List[Union[os.PathLike, np.ndarray]] = [] + # Create fragment videos from environment's render images if available if fragment.infos is not None and "rendered_img" in fragment.infos[0]: - frames_list = [] for i in range(len(fragment.infos)): - frame_info = fragment.infos[i]["rendered_img"] - # If path is provided load cached image - if isinstance(frame_info, (str, bytes, os.PathLike)): - frame = np.load(frame_info) - elif isinstance(frame_info, np.ndarray): - frame = frame_info + frame: Union[os.PathLike, np.ndarray] = fragment.infos[i][ + "rendered_img" + ] + if isinstance(frame, np.ndarray): + frame = add_missing_rgb_channels(frame) frames_list.append(frame) - frames = np.array(frames_list) + # Create fragment video from observations if possible else: - frames = fragment.obs - - for frame in frames: - # Transform to RGB frame if necessary - if frame.shape[-1] < 3: - missing_channels = 3 - frame.shape[-1] - frame = np.concatenate( - [frame] + missing_channels * [frame[..., -1][..., None]], - axis=-1, - ) - video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - - video_writer.release() - - -def get_frame_shape(fragment: TrajectoryWithRew) -> Tuple[int, int]: - """Calculate frame shape.""" - if fragment.infos is not None and "rendered_img" in fragment.infos[0]: - rendered_img_info = fragment.infos[0]["rendered_img"] - # If path is provided load cached image - if isinstance(rendered_img_info, (str, bytes, os.PathLike)): - single_frame = np.load(rendered_img_info) + if isinstance(fragment.obs, np.ndarray): + frames_list = [frame for frame in add_missing_rgb_channels(fragment.obs[1:])] else: - single_frame = rendered_img_info - else: - single_frame = np.array(fragment.obs[0]) - # Check whether observations are image-like - if len(single_frame.shape) < 2: + # TODO add support for DictObs raise ValueError( - "Observation must be an image, " - f"but shape {single_frame.shape} has too few dimensions!", + "Unsupported observation type " + f"for writing fragment video: {type(fragment.obs)}", ) - # Swap dimensions, because matrix and image dims are swapped - return single_frame.shape[1], single_frame.shape[0] + # Note: `ImageSeqeuenceClip` handily accepts both + # lists of image paths or numpy arrays + clip = ImageSequenceClip(frames_list, fps=frames_per_second) + moviepy_logger = None if not progress_logger else "bar" + clip.write_videofile(output_path, logger=moviepy_logger) class PreferenceGatherer(abc.ABC): @@ -949,13 +933,14 @@ def __init__( self, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None + querent_kwargs: Optional[Mapping] = None, ) -> None: """Initializes the preference gatherer. Args: rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. + querent_kwargs: Keyword arguments passed to the querent. """ # The random seed isn't used here, but it's useful to have this # as an argument nevertheless because that means we can always @@ -1046,7 +1031,8 @@ def __init__( def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Computes probability fragment 1 is preferred over fragment 2.""" queries = list(self.pending_queries.values()) - self.pending_queries.clear() # Clear pending queries because the oracle will have answered all + # Clear pending queries because the oracle will have answered all + self.pending_queries.clear() returns1, returns2 = self._reward_sums(queries) @@ -1132,7 +1118,6 @@ def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """ preferences = np.zeros(len(self.pending_queries), dtype=np.float32) for i, (query_id, query) in enumerate(self.pending_queries.items()): - write_fragment_video( query[0], frames_per_second=self.frames_per_second, @@ -1320,16 +1305,19 @@ def __init__( wait_for_user: Waits for user to input their preferences. rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. + querent_kwargs: Keyword arguments passed to the querent. """ super().__init__(rng, custom_logger) querent_kwargs = querent_kwargs if querent_kwargs else {} - self.querent = PrefCollectQuerent(pref_collect_address=pref_collect_address, **querent_kwargs) + self.querent = PrefCollectQuerent( + pref_collect_address=pref_collect_address, + **querent_kwargs, + ) self.query_endpoint = pref_collect_address + "/preferences/query/" self.pending_queries = {} self.wait_for_user = wait_for_user def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - # TODO: create user-independent (automated) waiting policy if self.wait_for_user: print("Waiting for user to provide preferences. Press enter to continue.") diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 825a72f8b..122252f50 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -6,8 +6,9 @@ import uuid from typing import List, Optional, Sequence, Tuple -import gymnasium as gym import cv2 +import gymnasium as gym +import imageio import numpy as np import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper @@ -38,9 +39,11 @@ def __init__( scale_factor: scales rendered images to be stored. use_file_cache: whether to save rendered images to disk. """ - assert env.render_mode == "rgb_array", \ - f'The environment must be in render mode "rgb_array" in order to use this wrapper but render_mode is ' \ - f'"{env.render_mode}".' + assert env.render_mode == "rgb_array", ( + "The environment must be in render mode 'rgb_array' in order" + " to use this wrapper but render_mode is " + f"'{env.render_mode}'." + ) super().__init__(env) self.scale_factor = scale_factor self.use_file_cache = use_file_cache @@ -67,9 +70,9 @@ def step(self, action): else: unique_file_path = os.path.join( self.file_cache, - str(uuid.uuid4()) + ".npy", + str(uuid.uuid4()) + ".png", ) - np.save(unique_file_path, scaled_rendered_image) + imageio.imwrite(unique_file_path, scaled_rendered_image) info["rendered_img"] = unique_file_path return observation, reward, terminated, truncated, info diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index ee2da8f9e..5bb68a9e3 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -94,7 +94,11 @@ def eval_policy( """ log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) - post_wrappers = {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} if videos else {} + post_wrappers = ( + {"VideoWrapper": video_wrapper.video_wrapper_factory(log_dir, **video_kwargs)} + if videos + else {} + ) render_mode = "rgb_array" if videos else None with environment.make_venv( # type: ignore[wrong-keyword-args] post_wrappers=post_wrappers, diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 56d7bf199..c1764f6f8 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -98,7 +98,9 @@ def train_rl( rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) - post_wrappers = {"RolloutInfoWrapper": lambda env, idx: wrappers.RolloutInfoWrapper(env)} + post_wrappers = { + "RolloutInfoWrapper": lambda env, idx: wrappers.RolloutInfoWrapper(env), + } with environment.make_venv( # type: ignore[wrong-keyword-args] post_wrappers=post_wrappers, ) as venv: diff --git a/src/imitation/testing/expert_trajectories.py b/src/imitation/testing/expert_trajectories.py index 7477cc268..220f830e4 100644 --- a/src/imitation/testing/expert_trajectories.py +++ b/src/imitation/testing/expert_trajectories.py @@ -38,7 +38,7 @@ def generate_expert_trajectories( """ env = util.make_vec_env( env_id, - post_wrappers={"RolloutInfoWrapper": lambda e, _: wrappers.RolloutInfoWrapper(e)}, + post_wrappers=[lambda e, _: wrappers.RolloutInfoWrapper(e)], rng=rng, ) try: diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 504d2d928..5b73e8569 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,14 +1,17 @@ """Tests for the preference comparisons reward learning implementation.""" -import abc import math +import os import pathlib import re +import shutil +import tempfile import uuid from typing import Any, Sequence, Tuple from unittest.mock import MagicMock, Mock, patch import gymnasium as gym +import imageio import numpy as np import pytest import seals # noqa: F401 @@ -27,6 +30,7 @@ PreferenceGatherer, PreferenceQuerent, SyntheticGatherer, + write_fragment_video, ) from imitation.data import types from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair @@ -308,7 +312,9 @@ def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): ) with pytest.raises(ValueError, match=with_rng_msg): - build_preference_comparisons(gatherer, reward_trainer, random_fragmenter, rng=rng) + build_preference_comparisons( + gatherer, reward_trainer, random_fragmenter, rng=rng, + ) # This should not raise build_preference_comparisons(None, None, None, rng=rng) @@ -317,61 +323,6 @@ def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): build_preference_comparisons(None, None, random_fragmenter, rng=rng) -@patch("builtins.input") -@patch("IPython.display.display") -def test_synchronous_human_gatherer(mock_display, mock_input): - del mock_display # unused - querent = PreferenceQuerent() - gatherer = preference_comparisons.SynchronousHumanGatherer( - video_dir=pathlib.Path("."), - ) - - # these inputs are designed solely to pass the test. they aren't tested for anything - trajectory_pairs = [ - ( - types.TrajectoryWithRew( - np.zeros((2, 200, 200, 3,), np.uint8), - np.array([1]), - np.array( - [ - { - "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py", - ), - }, - ], - ), - True, - np.array([1.0]), - ), - types.TrajectoryWithRew( - np.zeros((2, 200, 200, 3,), np.uint8), - np.array([1]), # act - np.array( # info - [ - { - "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py", - ), - }, - ], - ), - True, # done - np.array([1.0]), # reward - ), - ), - ] - gatherer.query(trajectory_pairs) - - # this is the actual test - mock_input.return_value = "1" - assert gatherer.gather()[1] == np.array([1.0]) - - gatherer.query(trajectory_pairs) - mock_input.return_value = "2" - assert gatherer.gather()[1] == np.array([0.0]) - - @pytest.mark.parametrize( "schedule", ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], @@ -1181,6 +1132,7 @@ def test_that_trainer_improves( ) +# PreferenceQuerent def test_returns_query_dict_from_query_sequence_with_correct_length(): querent = PreferenceQuerent() query_sequence = [Mock()] @@ -1199,6 +1151,7 @@ def test_returned_queries_have_uuid(): pytest.fail() +# PrefCollectQuerent def test_sends_put_request_for_each_query(requests_mock): address = "https://test.de" querent = PrefCollectQuerent(pref_collect_address=address, video_output_dir="video") @@ -1211,11 +1164,60 @@ def test_sends_put_request_for_each_query(requests_mock): assert requests_mock.last_request.text == f'{{"uuid": "{query_id}"}}' +@pytest.fixture( + params=["obs_only", "dictobs", "with_render_images", "with_render_image_paths"], +) +def fragment(request): + num_frames = 10 + frame_shape = (200, 200) + obs = np.zeros((num_frames, *frame_shape, 3), np.uint8) + acts = np.zeros((num_frames - 1,), np.uint8) + rews = np.zeros((num_frames - 1,)) + infos = None + if request.param == "dictobs": + obs = types.DictObs({"obs": obs}) + elif request.param == "with_render_images": + infos = np.array([{"rendered_img": frame} for frame in obs[1:]]) + elif request.param == "with_render_image_paths": + tmp_dir = tempfile.mkdtemp() + infos = [] + for frame in obs[1:]: + unique_file_path = os.path.join( + tmp_dir, + str(uuid.uuid4()) + ".png", + ) + imageio.imwrite(unique_file_path, frame) + infos.append({"rendered_img": unique_file_path}) + infos = np.array(infos) + yield types.TrajectoryWithRew( + obs=obs, + acts=acts, + infos=infos, + terminal=True, + rews=rews, + ) + if request.param == "with_render_image_paths": + shutil.rmtree(tmp_dir) + + +@pytest.mark.parametrize("codec", ["webm", "mp4"]) +def test_write_fragment_video(fragment, codec): + video_path = f"video.{codec}" + if isinstance(fragment.obs, types.DictObs): + with pytest.raises(ValueError): + write_fragment_video(fragment, frames_per_second=5, output_path=video_path) + else: + write_fragment_video(fragment, frames_per_second=5, output_path=video_path) + assert os.path.isfile(video_path) + os.remove(video_path) + + +# PreferenceGatherer class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - pass + return np.zeros(shape=(1,)) def test_adds_queries_to_pending_queries(): @@ -1238,6 +1240,7 @@ def test_clears_pending_queries(trajectory_with_rew): assert len(gatherer.pending_queries) == 0 +# PrefCollectGatherer def test_returns_none_for_unanswered_query(requests_mock): address = "https://test.de" query_id = "1234" @@ -1329,3 +1332,75 @@ def test_ignores_incomparable_answer(): assert len(gathered_preferences) == 0 assert len(gathered_queries) == 0 + + +# SynchronousHumanGatherer +@patch("builtins.input") +@patch("IPython.display.display") +def test_synchronous_human_gatherer(mock_display, mock_input): + del mock_display # unused + querent = PreferenceQuerent() + gatherer = preference_comparisons.SynchronousHumanGatherer( + video_dir=pathlib.Path("."), + ) + + # these inputs are designed solely to pass the test. they aren't tested for anything + trajectory_pairs = [ + ( + types.TrajectoryWithRew( + np.zeros( + ( + 2, + 200, + 200, + 3, + ), + np.uint8, + ), + np.array([1]), + np.array( + [ + { + "video_path": pathlib.Path( + "tests/algorithms/test_preference_comparisons.py", + ), + }, + ], + ), + True, + np.array([1.0]), + ), + types.TrajectoryWithRew( + np.zeros( + ( + 2, + 200, + 200, + 3, + ), + np.uint8, + ), + np.array([1]), # act + np.array( # info + [ + { + "video_path": pathlib.Path( + "tests/algorithms/test_preference_comparisons.py", + ), + }, + ], + ), + True, # done + np.array([1.0]), # reward + ), + ), + ] + gatherer.query(trajectory_pairs) + + # this is the actual test + mock_input.return_value = "1" + assert gatherer.gather()[1] == np.array([1.0]) + + gatherer.query(trajectory_pairs) + mock_input.return_value = "2" + assert gatherer.gather()[1] == np.array([0.0]) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index fcc85c760..1c78e73e2 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -78,8 +78,8 @@ def step(self, action): def _make_buffering_venv( - Env: Type[gym.Env], - error_on_premature_reset: bool, + Env: Type[gym.Env], + error_on_premature_reset: bool, ) -> BufferingWrapper: venv = DummyVecEnv([Env] * 2) wrapped_venv = BufferingWrapper(venv, error_on_premature_reset) @@ -95,7 +95,7 @@ def _assert_equal_scrambled_vectors(a: np.ndarray, b: np.ndarray) -> None: def _join_transitions( - trans_list: Sequence[types.TransitionsWithRew], + trans_list: Sequence[types.TransitionsWithRew], ) -> types.TransitionsWithRew: def concat(x): return np.concatenate(list(x)) @@ -121,10 +121,10 @@ def concat(x): @pytest.mark.parametrize("n_steps", [1, 2, 20, 21]) @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)]) def test_pop( - Env: Type[gym.Env], - episode_lengths: Sequence[int], - n_steps: int, - extra_pop_timesteps: Sequence[int], + Env: Type[gym.Env], + episode_lengths: Sequence[int], + n_steps: int, + extra_pop_timesteps: Sequence[int], ) -> None: """Check pop_transitions() results for BufferWrapper. @@ -293,8 +293,9 @@ def test_raises_assertion_error_if_env_not_in_correct_render_mode(): env = gym.make("CartPole-v0", render_mode=wrong_mode) with pytest.raises( - AssertionError, - match='The environment must be in render mode "rgb_array" in order to use this wrapper ' - f'but render_mode is "{wrong_mode}"' + AssertionError, + match='The environment must be in render mode "rgb_array" ' + 'in order to use this wrapper ' + f'but render_mode is "{wrong_mode}"', ): RenderImageInfoWrapper(env) From be8e0a8b33d9bf3a80aef3c66b59048a73bebb93 Mon Sep 17 00:00:00 2001 From: rk1a Date: Sun, 10 Dec 2023 16:01:27 +0100 Subject: [PATCH 066/143] Add and fix more tests --- .../algorithms/preference_comparisons.py | 8 +-- .../algorithms/test_preference_comparisons.py | 66 ++++++++++++++----- tests/data/test_wrappers.py | 4 +- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 94977bbc1..50c380270 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -903,16 +903,16 @@ def write_fragment_video( # Create fragment videos from environment's render images if available if fragment.infos is not None and "rendered_img" in fragment.infos[0]: for i in range(len(fragment.infos)): - frame: Union[os.PathLike, np.ndarray] = fragment.infos[i][ - "rendered_img" - ] + frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] if isinstance(frame, np.ndarray): frame = add_missing_rgb_channels(frame) frames_list.append(frame) # Create fragment video from observations if possible else: if isinstance(fragment.obs, np.ndarray): - frames_list = [frame for frame in add_missing_rgb_channels(fragment.obs[1:])] + frames_list = [ + frame for frame in add_missing_rgb_channels(fragment.obs[1:]) + ] else: # TODO add support for DictObs raise ValueError( diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 5b73e8569..00ae768a9 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -313,7 +313,10 @@ def build_preference_comparisons(gatherer, reward_trainer, fragmenter, rng): with pytest.raises(ValueError, match=with_rng_msg): build_preference_comparisons( - gatherer, reward_trainer, random_fragmenter, rng=rng, + gatherer, + reward_trainer, + random_fragmenter, + rng=rng, ) # This should not raise @@ -1164,20 +1167,43 @@ def test_sends_put_request_for_each_query(requests_mock): assert requests_mock.last_request.text == f'{{"uuid": "{query_id}"}}' +@pytest.fixture +def empty_trajectory_with_rew(): + num_frames = 10 + frame_shape = (200, 200) + return types.TrajectoryWithRew( + obs=np.zeros((num_frames, *frame_shape, 3), np.uint8), + acts=np.zeros((num_frames - 1,), np.uint8), + infos=np.array([{} for _ in range(num_frames - 1)]), + rews=np.zeros((num_frames - 1,)), + terminal=True, + ) + + +def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew): + address = "https://test.de" + queries = [(empty_trajectory_with_rew, empty_trajectory_with_rew)] + querent = PrefCollectQuerent(pref_collect_address=address, video_output_dir="video") + identified_queries = querent(queries) + for query_id, _ in identified_queries.items(): + file = os.path.join(querent.video_output_dir, query_id + "-{}.webm") + for part in ["left", "right"]: + assert os.path.isfile(file.format(part)) + os.remove(file.format(part)) + + @pytest.fixture( params=["obs_only", "dictobs", "with_render_images", "with_render_image_paths"], ) -def fragment(request): - num_frames = 10 - frame_shape = (200, 200) - obs = np.zeros((num_frames, *frame_shape, 3), np.uint8) - acts = np.zeros((num_frames - 1,), np.uint8) - rews = np.zeros((num_frames - 1,)) - infos = None +def fragment(request, empty_trajectory_with_rew): + obs = empty_trajectory_with_rew.obs + infos = empty_trajectory_with_rew.infos if request.param == "dictobs": - obs = types.DictObs({"obs": obs}) + obs = types.DictObs({"obs": empty_trajectory_with_rew.obs}) elif request.param == "with_render_images": - infos = np.array([{"rendered_img": frame} for frame in obs[1:]]) + infos = np.array( + [{"rendered_img": frame} for frame in empty_trajectory_with_rew.obs[1:]], + ) elif request.param == "with_render_image_paths": tmp_dir = tempfile.mkdtemp() infos = [] @@ -1191,15 +1217,16 @@ def fragment(request): infos = np.array(infos) yield types.TrajectoryWithRew( obs=obs, - acts=acts, + acts=empty_trajectory_with_rew.acts, infos=infos, terminal=True, - rews=rews, + rews=empty_trajectory_with_rew.rews, ) if request.param == "with_render_image_paths": shutil.rmtree(tmp_dir) +# utils @pytest.mark.parametrize("codec", ["webm", "mp4"]) def test_write_fragment_video(fragment, codec): video_path = f"video.{codec}" @@ -1246,7 +1273,10 @@ def test_returns_none_for_unanswered_query(requests_mock): query_id = "1234" answer = None - gatherer = PrefCollectGatherer(pref_collect_address=address) + gatherer = PrefCollectGatherer( + pref_collect_address=address, + querent_kwargs={"video_output_dir": "videos"}, + ) requests_mock.get( f"{address}/preferences/query/{query_id}", @@ -1263,7 +1293,10 @@ def test_returns_preference_for_answered_query(requests_mock): query_id = "1234" answer = 1.0 - gatherer = PrefCollectGatherer(pref_collect_address=address) + gatherer = PrefCollectGatherer( + pref_collect_address=address, + querent_kwargs={"video_output_dir": "videos"}, + ) requests_mock.get( f"{address}/preferences/query/{query_id}", @@ -1279,6 +1312,7 @@ def test_keeps_pending_query_for_unanswered_query(): gatherer = PrefCollectGatherer( pref_collect_address="https://test.de", wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, ) gatherer._gather_preference = MagicMock(return_value=None) gatherer.pending_queries = {"1234": Mock()} @@ -1293,6 +1327,7 @@ def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer( pref_collect_address="https://test.de", wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1307,6 +1342,7 @@ def test_gathers_valid_preference(): gatherer = PrefCollectGatherer( pref_collect_address="https://test.de", wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, ) preference = 0.5 gatherer._gather_preference = MagicMock(return_value=preference) @@ -1323,6 +1359,7 @@ def test_ignores_incomparable_answer(): gatherer = PrefCollectGatherer( pref_collect_address="https://test.de", wait_for_user=False, + querent_kwargs={"video_output_dir": "videos"}, ) # incomparable preference value = -1 gatherer._gather_preference = MagicMock(return_value=-1.0) @@ -1339,7 +1376,6 @@ def test_ignores_incomparable_answer(): @patch("IPython.display.display") def test_synchronous_human_gatherer(mock_display, mock_input): del mock_display # unused - querent = PreferenceQuerent() gatherer = preference_comparisons.SynchronousHumanGatherer( video_dir=pathlib.Path("."), ) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 1c78e73e2..25bcc86d3 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -295,7 +295,7 @@ def test_raises_assertion_error_if_env_not_in_correct_render_mode(): with pytest.raises( AssertionError, match='The environment must be in render mode "rgb_array" ' - 'in order to use this wrapper ' - f'but render_mode is "{wrong_mode}"', + "in order to use this wrapper " + f'but render_mode is "{wrong_mode}"', ): RenderImageInfoWrapper(env) From 7f6be46279c0a6170fbe014a7da7e1ff935a735c Mon Sep 17 00:00:00 2001 From: rk1a Date: Wed, 3 Jan 2024 21:35:57 +0100 Subject: [PATCH 067/143] add test for remove_rendered_images --- tests/algorithms/test_preference_comparisons.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 00ae768a9..e869231da 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -30,6 +30,7 @@ PreferenceGatherer, PreferenceQuerent, SyntheticGatherer, + remove_rendered_images, write_fragment_video, ) from imitation.data import types @@ -1239,6 +1240,12 @@ def test_write_fragment_video(fragment, codec): os.remove(video_path) +def test_remove_rendered_images(fragment): + trajs = [fragment] + remove_rendered_images(trajs) + assert not any("rendered_img" in info for traj in trajs for info in traj.infos) + + # PreferenceGatherer class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" From ba25756acaaed521c605e69bda02ab0232f72919 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 11 Jan 2024 13:49:35 +0100 Subject: [PATCH 068/143] Add test for preference comparisons with collected preferences --- tests/scripts/test_scripts.py | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index ae39116e7..0fa60ae2b 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -11,6 +11,7 @@ import pathlib import pickle import platform +import re import shutil import sys import tempfile @@ -22,6 +23,7 @@ import pandas as pd import pytest import ray.tune as tune +import requests_mock import sacred import sacred.utils import stable_baselines3 @@ -168,6 +170,41 @@ def test_train_preference_comparisons_main(tmpdir, preference_comparison_config) assert isinstance(run.result, dict) +def response_callback(request, context): + return {"query_id": request.path.split("/")[0], "label": 1.0} + + +def test_train_preference_comparisons_with_collected_preferences(tmpdir): + address = "http://localhost:8000" + config_updates = dict( + logging=dict(log_root=tmpdir), + environment=dict(env_make_kwargs=dict(render_mode="rgb_array")), + gatherer_kwargs=dict( + wait_for_user=False, + querent_kwargs=dict(video_output_dir=tmpdir), + pref_collect_address=address, + ), + ) + + with requests_mock.Mocker() as m: + request_matcher = re.compile(f"{address}/preferences/query/") + + m.put(url=request_matcher) + m.get( + url=request_matcher, + json=response_callback, + ) + + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=["cartpole", "human_preferences"] + + ALGO_FAST_CONFIGS["preference_comparison"], + config_updates=config_updates, + ) + + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + @pytest.mark.parametrize( "env_name", ["seals_cartpole", "mountain_car", "seals_mountain_car"], From cc0303919e38d2e20172f97dc0e7c7fa29d0c255 Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 22 Jan 2024 22:46:51 +0100 Subject: [PATCH 069/143] Add tests and fix bug for RenderImageWrapper --- src/imitation/data/wrappers.py | 2 +- tests/data/test_wrappers.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 122252f50..1b033c811 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -56,8 +56,8 @@ def step(self, action): rendered_image = self.render() # Scale the render image scaled_size = ( - int(self.scale_factor * rendered_image.shape[0]), int(self.scale_factor * rendered_image.shape[1]), + int(self.scale_factor * rendered_image.shape[0]), ) scaled_rendered_image = cv2.resize( rendered_image, diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 25bcc86d3..82133fd65 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -2,6 +2,8 @@ from typing import List, Sequence, Type +import imageio +from pathlib import Path import gymnasium as gym import numpy as np import pytest @@ -279,13 +281,19 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() - -def test_writes_to_info(): +@pytest.mark.parametrize("scale_factor", [0.1, 0.5, 1.]) +def test_writes_rendered_img_to_info(scale_factor): env = gym.make("CartPole-v0", render_mode="rgb_array") - wrapped_env = RenderImageInfoWrapper(env) + wrapped_env = RenderImageInfoWrapper(env, scale_factor=scale_factor) wrapped_env.reset() + rendered_img = wrapped_env.render() _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) assert "rendered_img" in info + assert isinstance(info["rendered_img"], np.ndarray) + if scale_factor == 1.: + assert np.allclose(info["rendered_img"], rendered_img) + assert int(scale_factor * rendered_img.shape[0]) == info["rendered_img"].shape[0] + assert int(scale_factor * rendered_img.shape[1]) == info["rendered_img"].shape[1] def test_raises_assertion_error_if_env_not_in_correct_render_mode(): @@ -294,8 +302,22 @@ def test_raises_assertion_error_if_env_not_in_correct_render_mode(): with pytest.raises( AssertionError, - match='The environment must be in render mode "rgb_array" ' + match="The environment must be in render mode 'rgb_array' " "in order to use this wrapper " - f'but render_mode is "{wrong_mode}"', + f"but render_mode is '{wrong_mode}'.", ): RenderImageInfoWrapper(env) + + +def test_rendered_img_file_cache(): + env = gym.make("CartPole-v0", render_mode="rgb_array") + wrapped_env = RenderImageInfoWrapper(env, use_file_cache=True) + assert Path(wrapped_env.file_cache).exists() + wrapped_env.reset() + _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) + rendered_img_path = info["rendered_img"] + assert Path(rendered_img_path).exists() + assert (imageio.imread(rendered_img_path) == wrapped_env.render()).all() + wrapped_env.close() + assert not Path(wrapped_env.file_cache).exists() + From 2632a88b65e95e47e3d2b6367d3467a267c14d69 Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 22 Jan 2024 22:47:01 +0100 Subject: [PATCH 070/143] Add missing docstring --- src/imitation/algorithms/preference_comparisons.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 50c380270..e5c7b9c14 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -883,6 +883,17 @@ def _query(self, query_id): def add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: + """Add missing RGB channels if needed. + If less than three channels are present, multiplies the last channel + until all three channels exist. + + Args: + frames: a stack of frames with potentially missing channels; + expected shape (batch, height, width, channels). + + Returns: + a stack of frames with exactly three channels. + """ if frames.shape[-1] < 3: missing_channels = 3 - frames.shape[-1] frames = np.concatenate( @@ -1352,6 +1363,7 @@ def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: rendered_img_info = info["rendered_img"] if isinstance(rendered_img_info, (str, bytes, os.PathLike)): os.remove(rendered_img_info) + del info["rendered_img"] elif isinstance(rendered_img_info, np.ndarray): del info["rendered_img"] From 4876d6012d94771c3fd41ff951585c6e9167d38b Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 24 Jan 2024 17:26:53 +0100 Subject: [PATCH 071/143] Cleanup changes made to docs --- docs/algorithms/airl.rst | 2 +- docs/algorithms/gail.rst | 2 +- docs/algorithms/preference_comparisons.rst | 2 - docs/tutorials/10_train_custom_env.ipynb | 4 +- docs/tutorials/3_train_gail.ipynb | 6 +- docs/tutorials/4_train_airl.ipynb | 6 +- .../5_train_preference_comparisons.ipynb | 555 +----------------- ...rain_preference_comparisons_with_cnn.ipynb | 9 +- 8 files changed, 24 insertions(+), 562 deletions(-) diff --git a/docs/algorithms/airl.rst b/docs/algorithms/airl.rst index eba60b439..3eea0f47e 100644 --- a/docs/algorithms/airl.rst +++ b/docs/algorithms/airl.rst @@ -42,7 +42,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl` "seals:seals/CartPole-v0", rng=np.random.default_rng(SEED), n_envs=8, - post_wrappers={"RolloutInfoWrapper": lambda env, _: RolloutInfoWrapper(env)}, + post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # to compute rollouts ) expert = load_policy( "ppo-huggingface", diff --git a/docs/algorithms/gail.rst b/docs/algorithms/gail.rst index 40087fd69..748d8b2f6 100644 --- a/docs/algorithms/gail.rst +++ b/docs/algorithms/gail.rst @@ -39,7 +39,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail` "seals:seals/CartPole-v0", rng=np.random.default_rng(SEED), n_envs=8, - post_wrappers={"RolloutInfoWrapper": lambda env, _: RolloutInfoWrapper(env)}, + post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], # to compute rollouts ) expert = load_policy( "ppo-huggingface", diff --git a/docs/algorithms/preference_comparisons.rst b/docs/algorithms/preference_comparisons.rst index 643f80b24..f1ae00680 100644 --- a/docs/algorithms/preference_comparisons.rst +++ b/docs/algorithms/preference_comparisons.rst @@ -47,7 +47,6 @@ For a more detailed example, refer to :doc:`../tutorials/5_train_preference_comp fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng) gatherer = preference_comparisons.SyntheticGatherer(rng=rng) - querent = preference_comparisons.PreferenceQuerent() preference_model = preference_comparisons.PreferenceModel(reward_net) reward_trainer = preference_comparisons.BasicRewardTrainer( preference_model=preference_model, @@ -85,7 +84,6 @@ For a more detailed example, refer to :doc:`../tutorials/5_train_preference_comp reward_net, num_iterations=5, # Set to 60 for better performance fragmenter=fragmenter, - preference_querent=querent, preference_gatherer=gatherer, reward_trainer=reward_trainer, initial_epoch_multiplier=4, diff --git a/docs/tutorials/10_train_custom_env.ipynb b/docs/tutorials/10_train_custom_env.ipynb index 92a3b4247..9175721cb 100644 --- a/docs/tutorials/10_train_custom_env.ipynb +++ b/docs/tutorials/10_train_custom_env.ipynb @@ -136,12 +136,12 @@ "\n", "# Create a vectorized environment for training with `imitation`\n", "\n", - "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)}`\n", + "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`\n", "venv = make_vec_env(\n", " \"custom/ObservationMatching-v0\",\n", " rng=np.random.default_rng(),\n", " n_envs=4,\n", - " post_wrappers={\"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)},\n", + " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", ")\n", "\n", "\n", diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb index 835b50381..051335fea 100644 --- a/docs/tutorials/3_train_gail.ipynb +++ b/docs/tutorials/3_train_gail.ipynb @@ -37,9 +37,9 @@ " \"seals:seals/CartPole-v0\",\n", " rng=np.random.default_rng(SEED),\n", " n_envs=8,\n", - " post_wrappers={\n", - " \"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)\n", - " }, # needed for computing rollouts later\n", + " post_wrappers=[\n", + " lambda env, _: RolloutInfoWrapper(env)\n", + " ], # needed for computing rollouts later\n", ")\n", "expert = load_policy(\n", " \"ppo-huggingface\",\n", diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb index 3af6107f6..258358541 100644 --- a/docs/tutorials/4_train_airl.ipynb +++ b/docs/tutorials/4_train_airl.ipynb @@ -41,9 +41,9 @@ " \"seals:seals/CartPole-v0\",\n", " rng=np.random.default_rng(SEED),\n", " n_envs=8,\n", - " post_wrappers={\n", - " \"RolloutInfoWrapper\": lambda env, _: RolloutInfoWrapper(env)\n", - " }, # needed for computing rollouts later\n", + " post_wrappers=[\n", + " lambda env, _: RolloutInfoWrapper(env)\n", + " ], # needed for computing rollouts later\n", ")\n", "expert = load_policy(\n", " \"ppo-huggingface\",\n", diff --git a/docs/tutorials/5_train_preference_comparisons.ipynb b/docs/tutorials/5_train_preference_comparisons.ipynb index abd84ec0b..a84dcfad9 100644 --- a/docs/tutorials/5_train_preference_comparisons.ipynb +++ b/docs/tutorials/5_train_preference_comparisons.ipynb @@ -19,13 +19,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2023-07-02T18:16:31.401489014Z", - "start_time": "2023-07-02T18:16:26.574553894Z" - } - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "import random\n", @@ -50,7 +45,6 @@ " warning_threshold=0,\n", " rng=rng,\n", ")\n", - "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -97,7 +91,6 @@ " reward_net,\n", " num_iterations=5, # Set to 60 for better performance\n", " fragmenter=fragmenter,\n", - " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=100,\n", @@ -118,499 +111,9 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2023-07-02T18:17:18.887827019Z", - "start_time": "2023-07-02T18:16:31.404169166Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Query schedule: [20, 51, 41, 34, 29, 25]\n", - "Collecting 40 fragments (4000 transitions)\n", - "Requested 4000 transitions but only 0 in buffer. Sampling 4000 additional transitions.\n", - "Creating fragment pairs\n", - "Gathering preferences\n", - "Dataset now contains 20 comparisons\n" - ] - }, - { - "data": { - "text/plain": "Training reward model: 0%| | 0/3 [00:00" - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "learner = PPO(\n", " seed=0,\n", @@ -697,22 +181,9 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "ExecuteTime": { - "end_time": "2023-07-02T18:17:21.983164120Z", - "start_time": "2023-07-02T18:17:20.917257800Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-1057.3085665\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from stable_baselines3.common.evaluation import evaluate_policy\n", "\n", diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index f72b9ca50..4b6efb82f 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -24,12 +24,7 @@ "cell_type": "code", "execution_count": null, "id": "93187e19", - "metadata": { - "ExecuteTime": { - "end_time": "2023-07-02T18:29:25.396097810Z", - "start_time": "2023-07-02T18:29:18.742766046Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "import torch as th\n", @@ -80,7 +75,6 @@ ").to(device)\n", "\n", "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)\n", - "querent = preference_comparisons.PreferenceQuerent()\n", "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", @@ -114,7 +108,6 @@ " reward_net,\n", " num_iterations=2,\n", " fragmenter=fragmenter,\n", - " preference_querent=querent,\n", " preference_gatherer=gatherer,\n", " reward_trainer=reward_trainer,\n", " fragment_length=10,\n", From 356defadf845bc7ffc67fc84ca637bd18ca4560e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Thu, 15 Feb 2024 16:28:31 +0100 Subject: [PATCH 072/143] Remove benchmark_summary.md --- docs/main-concepts/benchmark_summary.md | 69 ------------------------- 1 file changed, 69 deletions(-) delete mode 100644 docs/main-concepts/benchmark_summary.md diff --git a/docs/main-concepts/benchmark_summary.md b/docs/main-concepts/benchmark_summary.md deleted file mode 100644 index a2480f292..000000000 --- a/docs/main-concepts/benchmark_summary.md +++ /dev/null @@ -1,69 +0,0 @@ -# Benchmark Summary - -This is a summary of the sacred runs in `benchmark_runs` generated by `sacred_output_to_markdown_summary.py`. -## Scores - -The scores are normalized based on the performance of a random agent as the baseline and the expert as the maximum possible score as explained [in this blog post](https://araffin.github.io/post/rliable/): -> `(score - random_score) / (expert_score - random_score)` - -Aggregate scores and confidence intervals are computed using the [rliable library](https://agarwl.github.io/rliable/). -### AIRL -Environment | Score (mean/std)| Normalized Score (mean/std) | N - --- | --- | --- | --- -seals/Ant-v1 | 2485.889 / 533.471 | 0.981 / 0.184 | 10 -seals/HalfCheetah-v1 | 938.450 / 804.871 | 0.627 / 0.412 | 10 -seals/Hopper-v1 | 183.780 / 93.295 | 0.921 / 0.373 | 10 -seals/Swimmer-v1 | 286.699 / 7.763 | 0.970 / 0.027 | 10 -seals/Walker2d-v1 | 1154.921 / 659.564 | 0.461 / 0.264 | 10 - -#### Aggregate Normalized scores -Metric | Value | 95% CI - --- | --- | --- -Mean | 0.792 | [0.709, 0.792] -IQM | 0.918 | [0.871, 0.974] - -### BC -Environment | Score (mean/std)| Normalized Score (mean/std) | N - --- | --- | --- | --- -seals/Ant-v1 | 2090.551 / 180.340 | 0.844 / 0.062 | 10 -seals/HalfCheetah-v1 | 1516.476 / 37.487 | 0.923 / 0.019 | 10 -seals/Hopper-v1 | 204.271 / 0.609 | 1.003 / 0.002 | 10 -seals/Swimmer-v1 | 276.242 / 9.328 | 0.935 / 0.032 | 10 -seals/Walker2d-v1 | 2393.254 / 37.641 | 0.956 / 0.015 | 10 - -#### Aggregate Normalized scores -Metric | Value | 95% CI - --- | --- | --- -Mean | 0.932 | [0.922, 0.932] -IQM | 0.941 | [0.941, 0.949] - -### DAGGER -Environment | Score (mean/std)| Normalized Score (mean/std) | N - --- | --- | --- | --- -seals/Ant-v1 | 2302.527 / 108.315 | 0.957 / 0.052 | 10 -seals/HalfCheetah-v1 | 1615.004 / 8.262 | 1.017 / 0.008 | 10 -seals/Hopper-v1 | 204.789 / 1.599 | 1.011 / 0.012 | 10 -seals/Swimmer-v1 | 283.776 / 6.524 | 0.988 / 0.024 | 10 -seals/Walker2d-v1 | 2419.748 / 52.215 | 1.002 / 0.026 | 10 - -#### Aggregate Normalized scores -Metric | Value | 95% CI - --- | --- | --- -Mean | 0.995 | [0.987, 0.998] -IQM | 1.004 | [1.003, 1.008] - -### GAIL -Environment | Score (mean/std)| Normalized Score (mean/std) | N - --- | --- | --- | --- -seals/Ant-v1 | 2527.566 / 148.034 | 0.995 / 0.051 | 10 -seals/HalfCheetah-v1 | 1595.129 / 37.374 | 0.963 / 0.019 | 10 -seals/Hopper-v1 | 187.105 / 14.298 | 0.935 / 0.057 | 10 -seals/Swimmer-v1 | 249.949 / 74.295 | 0.845 / 0.254 | 10 -seals/Walker2d-v1 | 2399.196 / 89.949 | 0.959 / 0.036 | 10 - -#### Aggregate Normalized scores -Metric | Value | 95% CI - --- | --- | --- -Mean | 0.939 | [0.900, 0.944] -IQM | 0.957 | [0.965, 0.970] - From 4c7031e0095bdb713fa499ab1e6a9aa50e101794 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 13:45:19 -0600 Subject: [PATCH 073/143] Initial commit of Zooniverse preference comparisons --- .../algorithms/preference_comparisons.py | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e5c7b9c14..c864aa13b 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -54,6 +54,14 @@ from imitation.util import logger as imit_logger from imitation.util import networks, util +from panoptes_client import ( + Panoptes, + Project, + Workflow, + Classification, + SubjectSet, + Subject +) class TrajectoryGenerator(abc.ABC): """Generator of trajectories with optional training logic.""" @@ -882,6 +890,73 @@ def _query(self, query_id): ) +class ZooniverseQuerent(PrefCollectQuerent): + """Sends queries to the Zooniverse interface.""" + + def __init__( + self, + pref_collect_address: str, + zoo_project_id: int, + zoo_workflow_id: int, + linked_subject_set_id: int, + retired_subject_set_id: int, + experiment_id: int, + video_output_dir: AnyPath, + video_fps: str = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ): + super().__init__(pref_collect_address, video_output_dir, video_fps, rng, custom_logger) + self.zoo_project_id = zoo_project_id + self.zoo_workflow_id = zoo_workflow_id + self.linked_subject_set_id = linked_subject_set_id + self.retired_subject_set_id = retired_subject_set_id + self.experiment_id = experiment_id + + # Create video directory + os.makedirs(self.video_output_dir, exist_ok=True) + + # Authenticate with Zooniverse + panoptes_username = os.environ["PANOPTES_USERNAME"] + panoptes_password = os.environ["PANOPTES_PASSWORD"] + Panoptes.connect(username=panoptes_username, password=panoptes_password) + + def _query(self, query_id): + # Find project and workflow + project = Project.find(self.zoo_project_id) + workflow = Workflow.find(self.zoo_workflow_id) + + # Find subject sets + linked_subject_set = SubjectSet.find(self.linked_subject_set_id) + + # Create subject for this query_id + subject = Subject() + subject.links.project = project + + output_file_name = os.path.join( + self.video_output_dir, f"{query_id}" + "{}.webm" + ) + + subject.add_location(open(output_file_name.format("left"), "rb")) + subject.add_location(open(output_file_name.format("right"), "rb")) + + subject.metadata["query_id"] = f"{query_id}" + subject.metadata["#left_video"] = output_file_name.format("left") + subject.metadata["#right_video"] = output_file_name.format("right") + subject.metadata["#video_fps"] = self.video_fps + subject.metadata["#zoo_project_id"] = self.zoo_project_id + subject.metadata["#zoo_workflow_id"] = self.zoo_workflow_id + subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id + subject.metadata["#retired_subject_set_id"] = self.retired_subject_set_id + subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id + subject.metadata["#experiment_id"] = self.experiment_id + + subject.save() + + # Add the subject to the linked subject set + linked_subject_set.add(subject) + + def add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: """Add missing RGB channels if needed. If less than three channels are present, multiplies the last channel @@ -1355,6 +1430,111 @@ def _gather_preference(self, query_id: str) -> float: return answered_query["label"] +class ZooniverseGatherer(PrefCollectGatherer): + """Gathers preferences from Zooniverse interface.""" + + def __init__( + self, + pref_collect_address: str, + zoo_project_id: int, + zoo_workflow_id: int, + linked_subject_set_id: int, + retired_subject_set_id: int, + experiment_id: int, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ) -> None: + """Initializes the preference gatherer. + + Args: + pref_collect_address: Network address to PrefCollect instance. + wait_for_user: Waits for user to input their preferences. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + """ + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger) + self.zoo_project_id = zoo_project_id + self.zoo_workflow_id = zoo_workflow_id + self.linked_subject_set_id = linked_subject_set_id + self.retired_subject_set_id = retired_subject_set_id + self.experiment_id = experiment_id + + # Authenticate with Zooniverse + panoptes_username = os.environ["PANOPTES_USERNAME"] + panoptes_password = os.environ["PANOPTES_PASSWORD"] + Panoptes.connect(username=panoptes_username, password=panoptes_password) + + self._process_zoo_classifications(self, last_id=0) + + # Define annotation to label map + self.annotation_to_label = { + "Left is better.": 1, + "Right is better.": 0, + "Indifferent.": .5, + "Incomparable.": -1 + } + + def _process_zoo_classifications(self, last_id=0): + + # The default last_id is 0 meaning process all classifications the project has + # recieved for the specified workflow. + # TODO: make last_id trackable to avoid processing all classifications each time + # the gatherer is called. + + # Access classifications from the last_id + classifications = Classification.where(last_id=last_id, scope='project', + project_id=pid, workflow_id=wid) + + # Find workflow + self.workflow = Workflow.find(self.zoo_workflow_id) + + self.subject_to_query = {} + self.subject_to_annotations = {} + for c in classifications: + d = c.raw + # Extract subject id + sid = int(d["links"]["subjects"][0]) + # Get subject status + status = self.workflow.subject_workflow_status(sid) + # Check that subject is retired + if status.raw["retirement_reason"] is not None: + label = self.annotation_to_label[d["annotations"][0]["value"]] + try: + # Add label for this classification for the subject + self.subject_to_annotations[sid].append(label) + except KeyError: + # Get query_id for this subject and add it to map + subject = Subject.find(sid) + self.subject_to_query[sid] = subject.raw["metadata"]["query_id"] + # Create map entry for this subject + self.subject_to_annotations[sid] = [label] + + def _gather_preference(self, query_id: str) -> float: + + # Find subject sets + linked_subject_set = SubjectSet.find(self.linked_subject_set_id) + retired_subject_set = SubjectSet.find(self.retired_subject_set_id) + + # Get subject_id corresponding to query_id + subject_id = self.subject_to_query[query_id] + + # Get reduced_label for subject_id aggregated from each annotation for that subject + reduced_label = self._reduce_annotations(self.subject_to_annotations[subject_id]) + + # Remove this subject from the subject set linked to the workflow + linked_subject_set.remove([subject_id]) + + # Add subject to the subject set for completed subjects + retired_subject_set.add([subject_id]) + + return reduced_label + + def _reduce_annotations(self, annotations): + count = Counter(annotations) + return count.most_common(1)[0][0] + + def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: """Removes rendered images of the provided trajectories list.""" for traj in trajectories: From 821d1a322b26b3bda6c91a8be74bf0bf38c85bb5 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 14:23:37 -0600 Subject: [PATCH 074/143] Add querent_kwargs to ZooniverseGatherer --- src/imitation/algorithms/preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index c864aa13b..8f758ce37 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1444,6 +1444,7 @@ def __init__( wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None ) -> None: """Initializes the preference gatherer. @@ -1453,7 +1454,7 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger) + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, querent_kwargs) self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id From b68d61dce052985027a6df943b46870834390418 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 14:55:53 -0600 Subject: [PATCH 075/143] Refactor Zooniverse elements handling --- src/imitation/algorithms/preference_comparisons.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 8f758ce37..40c4fcf74 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1440,7 +1440,6 @@ def __init__( zoo_workflow_id: int, linked_subject_set_id: int, retired_subject_set_id: int, - experiment_id: int, wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, From 1cf4ca2b1b372e379f138e9abf05fcb75c627a98 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 15:12:13 -0600 Subject: [PATCH 076/143] Override PrefCollectGatherer self.querent --- src/imitation/algorithms/preference_comparisons.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 40c4fcf74..1b1b9337d 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1454,6 +1454,9 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, querent_kwargs) + self.querent = ZooniverseQuerent( + **querent_kwargs, + ) self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id From a2856393adcce2831e60b1f1d1550e94db7e6a19 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 15:20:39 -0600 Subject: [PATCH 077/143] Remove querent_kwargs from super().__init__ --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 1b1b9337d..14a495038 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1453,7 +1453,7 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, querent_kwargs) + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger) self.querent = ZooniverseQuerent( **querent_kwargs, ) From ec126d8e44660b23781d9e81dbdb8f0a93d3ec95 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 15:28:31 -0600 Subject: [PATCH 078/143] Handle querent_kwargs requirements --- src/imitation/algorithms/preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 14a495038..a6695568a 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1453,7 +1453,8 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger) + super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, super_querent_kwargs) self.querent = ZooniverseQuerent( **querent_kwargs, ) From f0946f04d1494e3f54138f014246d5996953b200 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:28:36 -0600 Subject: [PATCH 079/143] Handle querent_kwargs requirements --- src/imitation/algorithms/preference_comparisons.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a6695568a..4cd6d592e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1453,11 +1453,12 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, super_querent_kwargs) + #super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, **querent_kwargs) self.querent = ZooniverseQuerent( - **querent_kwargs, + **querent_kwargs ) + self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id From 16416ecd64cea3725de6ed665de333d399e2ed87 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:29:55 -0600 Subject: [PATCH 080/143] Pass pref_collect_address to ZooniverseQuerent --- src/imitation/algorithms/preference_comparisons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 4cd6d592e..1732cf8e4 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1456,6 +1456,7 @@ def __init__( #super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, **querent_kwargs) self.querent = ZooniverseQuerent( + pref_collect_address, **querent_kwargs ) From 75db9396af7abb48c57e36230aef1046901d37a5 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:31:06 -0600 Subject: [PATCH 081/143] Handle querent_kwargs requirements --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 1732cf8e4..3a27e0c1c 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1453,8 +1453,8 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - #super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, **querent_kwargs) + super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} + super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, **super_querent_kwargs) self.querent = ZooniverseQuerent( pref_collect_address, **querent_kwargs From ba48d73d851b3d894cc7aedbd445a8c8b9a6a963 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:33:57 -0600 Subject: [PATCH 082/143] ZooniverseGatherer fix super().__init__ call. --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 3a27e0c1c..c311f25af 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1453,8 +1453,8 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super_querent_kwargs = {"video_output_dir": querent_kwargs["video_output_dir"]} - super().__init__(pref_collect_address, wait_for_user, rng, custom_logger, **super_querent_kwargs) + video_output_dir = querent_kwargs["video_output_dir"] + super().__init__(pref_collect_address, video_output_dir, wait_for_user, rng, custom_logger) self.querent = ZooniverseQuerent( pref_collect_address, **querent_kwargs From 3634e608daa45d44bd18c816d42ebfe71634d0e8 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:35:20 -0600 Subject: [PATCH 083/143] ZooniverseGatherer fix super().__init__ call. --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index c311f25af..614142fd1 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1454,7 +1454,7 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ video_output_dir = querent_kwargs["video_output_dir"] - super().__init__(pref_collect_address, video_output_dir, wait_for_user, rng, custom_logger) + super().__init__(pref_collect_address, video_output_dir, wait_for_user, rng, custom_logger, querent_kwargs=None) self.querent = ZooniverseQuerent( pref_collect_address, **querent_kwargs From d04b5056b5ab7319b27ca261c4857fc5ccc7de43 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:38:47 -0600 Subject: [PATCH 084/143] ZooniverseGatherer fix super().__init__ call. --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 614142fd1..8f8fa6d06 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1454,7 +1454,7 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ video_output_dir = querent_kwargs["video_output_dir"] - super().__init__(pref_collect_address, video_output_dir, wait_for_user, rng, custom_logger, querent_kwargs=None) + super().__init__(pref_collect_address, video_output_dir, querent_kwargs=None) self.querent = ZooniverseQuerent( pref_collect_address, **querent_kwargs From 5bd2cbd71b090a207f0ea50a6b7ce0095bff19f7 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:40:36 -0600 Subject: [PATCH 085/143] ZooniverseGatherer fix super().__init__ call. --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 8f8fa6d06..091275a1d 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1454,7 +1454,7 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. """ video_output_dir = querent_kwargs["video_output_dir"] - super().__init__(pref_collect_address, video_output_dir, querent_kwargs=None) + super().__init__(pref_collect_address, querent_kwargs={"video_output_dir": querent_kwargs["video_output_dir"]}) self.querent = ZooniverseQuerent( pref_collect_address, **querent_kwargs From 9423e9b01367d667a10fdc80dfd47c21237a5ba3 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:42:21 -0600 Subject: [PATCH 086/143] Remove experiment_id attr from ZooniverseGatherer. --- src/imitation/algorithms/preference_comparisons.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 091275a1d..9caa75bdf 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1464,7 +1464,6 @@ def __init__( self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id self.retired_subject_set_id = retired_subject_set_id - self.experiment_id = experiment_id # Authenticate with Zooniverse panoptes_username = os.environ["PANOPTES_USERNAME"] From 90760fdb6b164a868882a134870f927ad8bcd44c Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:43:17 -0600 Subject: [PATCH 087/143] Remove self from method call. --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 9caa75bdf..a0df9355f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1470,7 +1470,7 @@ def __init__( panoptes_password = os.environ["PANOPTES_PASSWORD"] Panoptes.connect(username=panoptes_username, password=panoptes_password) - self._process_zoo_classifications(self, last_id=0) + self._process_zoo_classifications(last_id=0) # Define annotation to label map self.annotation_to_label = { From f00ad8ffb8a20901ade7cd6671632d51dc671b66 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 19:44:31 -0600 Subject: [PATCH 088/143] Fix incorrect var names --- src/imitation/algorithms/preference_comparisons.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a0df9355f..9d04a3274 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1488,8 +1488,12 @@ def _process_zoo_classifications(self, last_id=0): # the gatherer is called. # Access classifications from the last_id - classifications = Classification.where(last_id=last_id, scope='project', - project_id=pid, workflow_id=wid) + classifications = Classification.where( + last_id=last_id, + scope='project', + project_id=self.zoo_project_id, + workflow_id=self.zoo_workflow_id + ) # Find workflow self.workflow = Workflow.find(self.zoo_workflow_id) From 684550d466b3f50a8272ac5fa89039ccce9b0ef1 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 29 Feb 2024 20:08:00 -0600 Subject: [PATCH 089/143] Remove repeated makedirs call. --- src/imitation/algorithms/preference_comparisons.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 9d04a3274..0c759a6a6 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -912,9 +912,6 @@ def __init__( self.linked_subject_set_id = linked_subject_set_id self.retired_subject_set_id = retired_subject_set_id self.experiment_id = experiment_id - - # Create video directory - os.makedirs(self.video_output_dir, exist_ok=True) # Authenticate with Zooniverse panoptes_username = os.environ["PANOPTES_USERNAME"] From 4abc06d12e1cba6e63e31d70b5d277cb6ad9db4a Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 11:01:30 -0600 Subject: [PATCH 090/143] Fix output_file_name --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 0c759a6a6..8d8546b00 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -931,7 +931,7 @@ def _query(self, query_id): subject.links.project = project output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "{}.webm" + self.video_output_dir, f"{query_id}" + "-{}.webm" ) subject.add_location(open(output_file_name.format("left"), "rb")) From 4430c3211ce86009586a4e66cad115040c5840ee Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 11:03:16 -0600 Subject: [PATCH 091/143] Make video_fps an attribute --- src/imitation/algorithms/preference_comparisons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 8d8546b00..d36be5e79 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -912,6 +912,7 @@ def __init__( self.linked_subject_set_id = linked_subject_set_id self.retired_subject_set_id = retired_subject_set_id self.experiment_id = experiment_id + self.video_fps = video_fps # Authenticate with Zooniverse panoptes_username = os.environ["PANOPTES_USERNAME"] From ab23fdabda086fa6b04b0f0b04942ecd0f2e030e Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 11:10:29 -0600 Subject: [PATCH 092/143] Override Querent __call__ to write .mp4 --- .../algorithms/preference_comparisons.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index d36be5e79..84c8e3068 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -919,6 +919,32 @@ def __init__( panoptes_password = os.environ["PANOPTES_PASSWORD"] Panoptes.connect(username=panoptes_username, password=panoptes_password) + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + identified_queries = super().__call__(queries) + + # Save fragment videos and submit queries + for query_id, query in identified_queries.items(): + output_file_name = os.path.join( + self.video_output_dir, + f"{query_id}" + "-{}.mp4", + ) + write_fragment_video( + query[0], + frames_per_second=self.frames_per_second, + output_path=output_file_name.format("left"), + ) + write_fragment_video( + query[1], + frames_per_second=self.frames_per_second, + output_path=output_file_name.format("right"), + ) + self._query(query_id) + + return identified_queries + def _query(self, query_id): # Find project and workflow project = Project.find(self.zoo_project_id) @@ -932,7 +958,7 @@ def _query(self, query_id): subject.links.project = project output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "-{}.webm" + self.video_output_dir, f"{query_id}" + "-{}.mp4" ) subject.add_location(open(output_file_name.format("left"), "rb")) From 5db471f5632e4d80ca07a8045787308f01157918 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 11:19:20 -0600 Subject: [PATCH 093/143] Call PreferenceQuerent not PrefCollectQuerent --- src/imitation/algorithms/preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 84c8e3068..a2ed633f9 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -923,7 +923,8 @@ def __call__( self, queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, TrajectoryWithRewPair]: - identified_queries = super().__call__(queries) + # Call PreferenceQuerent not PrefCollectQuerent + identified_queries = super(PrefCollectQuerent, self).__call__(queries) # Save fragment videos and submit queries for query_id, query in identified_queries.items(): From 26da52117e05665e2654a854d9f24e08300f2167 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 11:52:26 -0600 Subject: [PATCH 094/143] Write .webm --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a2ed633f9..af0fa946f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -930,7 +930,7 @@ def __call__( for query_id, query in identified_queries.items(): output_file_name = os.path.join( self.video_output_dir, - f"{query_id}" + "-{}.mp4", + f"{query_id}" + "-{}.webm", ) write_fragment_video( query[0], @@ -959,7 +959,7 @@ def _query(self, query_id): subject.links.project = project output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "-{}.mp4" + self.video_output_dir, f"{query_id}" + "-{}.webm" ) subject.add_location(open(output_file_name.format("left"), "rb")) From 3b51a525a85e7024568fb4ee9cbb88d9d9c062e7 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 14:16:49 -0600 Subject: [PATCH 095/143] Write .mp4 --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index af0fa946f..a2ed633f9 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -930,7 +930,7 @@ def __call__( for query_id, query in identified_queries.items(): output_file_name = os.path.join( self.video_output_dir, - f"{query_id}" + "-{}.webm", + f"{query_id}" + "-{}.mp4", ) write_fragment_video( query[0], @@ -959,7 +959,7 @@ def _query(self, query_id): subject.links.project = project output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "-{}.webm" + self.video_output_dir, f"{query_id}" + "-{}.mp4" ) subject.add_location(open(output_file_name.format("left"), "rb")) From 316f4b335fb73cd25a5b405900040e84ab5ecc89 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 14:45:37 -0600 Subject: [PATCH 096/143] Minor refactor --- src/imitation/algorithms/preference_comparisons.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a2ed633f9..e0ff165b4 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1490,13 +1490,6 @@ def __init__( self.linked_subject_set_id = linked_subject_set_id self.retired_subject_set_id = retired_subject_set_id - # Authenticate with Zooniverse - panoptes_username = os.environ["PANOPTES_USERNAME"] - panoptes_password = os.environ["PANOPTES_PASSWORD"] - Panoptes.connect(username=panoptes_username, password=panoptes_password) - - self._process_zoo_classifications(last_id=0) - # Define annotation to label map self.annotation_to_label = { "Left is better.": 1, @@ -1505,6 +1498,13 @@ def __init__( "Incomparable.": -1 } + # Authenticate with Zooniverse + panoptes_username = os.environ["PANOPTES_USERNAME"] + panoptes_password = os.environ["PANOPTES_PASSWORD"] + Panoptes.connect(username=panoptes_username, password=panoptes_password) + + self._process_zoo_classifications(last_id=0) + def _process_zoo_classifications(self, last_id=0): # The default last_id is 0 meaning process all classifications the project has From 6392ea4f234a3dc8026ee42e4e5880d1e96841bd Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:31:18 -0600 Subject: [PATCH 097/143] Fix annotation_to_label mapping --- src/imitation/algorithms/preference_comparisons.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e0ff165b4..4af4e9502 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1492,10 +1492,10 @@ def __init__( # Define annotation to label map self.annotation_to_label = { - "Left is better.": 1, - "Right is better.": 0, - "Indifferent.": .5, - "Incomparable.": -1 + 0: 1, + 1: 0, + 2: .5, + 3: -1 } # Authenticate with Zooniverse @@ -1527,6 +1527,7 @@ def _process_zoo_classifications(self, last_id=0): self.subject_to_annotations = {} for c in classifications: d = c.raw + print(d) # Extract subject id sid = int(d["links"]["subjects"][0]) # Get subject status From 92f334126e6bd2ac6cf63e53cb85abfcbbde078b Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:31:42 -0600 Subject: [PATCH 098/143] Fix annotation_to_label mapping --- src/imitation/algorithms/preference_comparisons.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 4af4e9502..b09f60311 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1527,7 +1527,6 @@ def _process_zoo_classifications(self, last_id=0): self.subject_to_annotations = {} for c in classifications: d = c.raw - print(d) # Extract subject id sid = int(d["links"]["subjects"][0]) # Get subject status From 9044f5ee9d55273e941672c35dc8545a5a581805 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:37:13 -0600 Subject: [PATCH 099/143] Invert subjet_to_query map --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index b09f60311..e8ecb3f8b 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1523,7 +1523,7 @@ def _process_zoo_classifications(self, last_id=0): # Find workflow self.workflow = Workflow.find(self.zoo_workflow_id) - self.subject_to_query = {} + self.query_to_subject = {} self.subject_to_annotations = {} for c in classifications: d = c.raw @@ -1540,7 +1540,7 @@ def _process_zoo_classifications(self, last_id=0): except KeyError: # Get query_id for this subject and add it to map subject = Subject.find(sid) - self.subject_to_query[sid] = subject.raw["metadata"]["query_id"] + self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid # Create map entry for this subject self.subject_to_annotations[sid] = [label] From 359a27032e480d58fbee657bba7fb8d2c419fc74 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:42:05 -0600 Subject: [PATCH 100/143] Invert subjet_to_query map --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e8ecb3f8b..d5755a9a3 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1551,7 +1551,7 @@ def _gather_preference(self, query_id: str) -> float: retired_subject_set = SubjectSet.find(self.retired_subject_set_id) # Get subject_id corresponding to query_id - subject_id = self.subject_to_query[query_id] + subject_id = self.query_to_subject[query_id] # Get reduced_label for subject_id aggregated from each annotation for that subject reduced_label = self._reduce_annotations(self.subject_to_annotations[subject_id]) From 04d9e239973627a8f32ebeae919a4bb9a6ce6b40 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:52:58 -0600 Subject: [PATCH 101/143] Process classifications with each gather call --- src/imitation/algorithms/preference_comparisons.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index d5755a9a3..c79883a0e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1498,13 +1498,13 @@ def __init__( 3: -1 } + self.query_to_subject = None + # Authenticate with Zooniverse panoptes_username = os.environ["PANOPTES_USERNAME"] panoptes_password = os.environ["PANOPTES_PASSWORD"] Panoptes.connect(username=panoptes_username, password=panoptes_password) - self._process_zoo_classifications(last_id=0) - def _process_zoo_classifications(self, last_id=0): # The default last_id is 0 meaning process all classifications the project has @@ -1546,11 +1546,17 @@ def _process_zoo_classifications(self, last_id=0): def _gather_preference(self, query_id: str) -> float: + # Without last_id trakcing this must be called each time to ensure latest + # classifications are included. This could become time consuming if many + # classifications have been submitted to the project. + self._process_zoo_classifications(last_id=0) + # Find subject sets linked_subject_set = SubjectSet.find(self.linked_subject_set_id) retired_subject_set = SubjectSet.find(self.retired_subject_set_id) # Get subject_id corresponding to query_id + print(self.query_to_subject) subject_id = self.query_to_subject[query_id] # Get reduced_label for subject_id aggregated from each annotation for that subject From 93e86bd81568accab991331edada80889e3028c3 Mon Sep 17 00:00:00 2001 From: Darryl Date: Mon, 4 Mar 2024 15:56:29 -0600 Subject: [PATCH 102/143] Fix Counter import --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index c79883a0e..8e33063e6 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -38,6 +38,7 @@ from torch import nn from torch.utils import data as data_th from tqdm.auto import tqdm +from collections import Counter from imitation.algorithms import base from imitation.data import rollout, types, wrappers @@ -1556,7 +1557,6 @@ def _gather_preference(self, query_id: str) -> float: retired_subject_set = SubjectSet.find(self.retired_subject_set_id) # Get subject_id corresponding to query_id - print(self.query_to_subject) subject_id = self.query_to_subject[query_id] # Get reduced_label for subject_id aggregated from each annotation for that subject From 6a3c3cd3f5f70a275b2d0e5fda783346c5807534 Mon Sep 17 00:00:00 2001 From: Darryl Date: Tue, 5 Mar 2024 09:48:01 -0600 Subject: [PATCH 103/143] Allow writing .gif --- src/imitation/algorithms/preference_comparisons.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 8e33063e6..983d433d8 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1035,7 +1035,10 @@ def write_fragment_video( # lists of image paths or numpy arrays clip = ImageSequenceClip(frames_list, fps=frames_per_second) moviepy_logger = None if not progress_logger else "bar" - clip.write_videofile(output_path, logger=moviepy_logger) + if output_path.endswith('.gif'): + clip.write_gif(output_path,fps=frames_per_second) + else: + clip.write_videofile(output_path, logger=moviepy_logger) class PreferenceGatherer(abc.ABC): From a28fb200b8377bd194b92aa3d02952792d626a73 Mon Sep 17 00:00:00 2001 From: Darryl Date: Tue, 5 Mar 2024 09:49:23 -0600 Subject: [PATCH 104/143] Allow writing .gif --- src/imitation/algorithms/preference_comparisons.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 983d433d8..63041a6a6 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -931,7 +931,7 @@ def __call__( for query_id, query in identified_queries.items(): output_file_name = os.path.join( self.video_output_dir, - f"{query_id}" + "-{}.mp4", + f"{query_id}" + "-{}.gif", ) write_fragment_video( query[0], @@ -960,7 +960,7 @@ def _query(self, query_id): subject.links.project = project output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "-{}.mp4" + self.video_output_dir, f"{query_id}" + "-{}.gif" ) subject.add_location(open(output_file_name.format("left"), "rb")) From fa69e575c3416254508156f1f39dc546dcb040fe Mon Sep 17 00:00:00 2001 From: Darryl Date: Tue, 5 Mar 2024 09:55:11 -0600 Subject: [PATCH 105/143] Remove fps from write_gif call --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 63041a6a6..b225f1d65 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1036,7 +1036,7 @@ def write_fragment_video( clip = ImageSequenceClip(frames_list, fps=frames_per_second) moviepy_logger = None if not progress_logger else "bar" if output_path.endswith('.gif'): - clip.write_gif(output_path,fps=frames_per_second) + clip.write_gif(output_path) else: clip.write_videofile(output_path, logger=moviepy_logger) From b7632f33e526707f9bffcd23ff5040ca47e77bf7 Mon Sep 17 00:00:00 2001 From: Darryl Date: Tue, 5 Mar 2024 10:00:57 -0600 Subject: [PATCH 106/143] Add ffmpeg and logger to write_gif call --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index b225f1d65..bab5daaa4 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1036,7 +1036,7 @@ def write_fragment_video( clip = ImageSequenceClip(frames_list, fps=frames_per_second) moviepy_logger = None if not progress_logger else "bar" if output_path.endswith('.gif'): - clip.write_gif(output_path) + clip.write_gif(output_path, logger=moviepy_logger) else: clip.write_videofile(output_path, logger=moviepy_logger) From 8e5142700e3eab61440e6c68925da45aa87e7aef Mon Sep 17 00:00:00 2001 From: Darryl Date: Tue, 5 Mar 2024 10:03:08 -0600 Subject: [PATCH 107/143] Add ffmpeg and logger to write_gif call --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index bab5daaa4..138c9db5d 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1036,7 +1036,7 @@ def write_fragment_video( clip = ImageSequenceClip(frames_list, fps=frames_per_second) moviepy_logger = None if not progress_logger else "bar" if output_path.endswith('.gif'): - clip.write_gif(output_path, logger=moviepy_logger) + clip.write_gif(output_path, program='ffmpeg', logger=moviepy_logger) else: clip.write_videofile(output_path, logger=moviepy_logger) From 354ea10c8ca1f6fb1dc2a185ea65e4debdb5097e Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Fri, 8 Mar 2024 13:22:41 -0600 Subject: [PATCH 108/143] Zoo authenticate for each query and gather call to avoid timeout --- .../algorithms/preference_comparisons.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 138c9db5d..fe662648d 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -914,11 +914,6 @@ def __init__( self.retired_subject_set_id = retired_subject_set_id self.experiment_id = experiment_id self.video_fps = video_fps - - # Authenticate with Zooniverse - panoptes_username = os.environ["PANOPTES_USERNAME"] - panoptes_password = os.environ["PANOPTES_PASSWORD"] - Panoptes.connect(username=panoptes_username, password=panoptes_password) def __call__( self, @@ -948,6 +943,12 @@ def __call__( return identified_queries def _query(self, query_id): + + # Authenticate with Zooniverse + panoptes_username = os.environ["PANOPTES_USERNAME"] + panoptes_password = os.environ["PANOPTES_PASSWORD"] + Panoptes.connect(username=panoptes_username, password=panoptes_password) + # Find project and workflow project = Project.find(self.zoo_project_id) workflow = Workflow.find(self.zoo_workflow_id) @@ -1504,11 +1505,6 @@ def __init__( self.query_to_subject = None - # Authenticate with Zooniverse - panoptes_username = os.environ["PANOPTES_USERNAME"] - panoptes_password = os.environ["PANOPTES_PASSWORD"] - Panoptes.connect(username=panoptes_username, password=panoptes_password) - def _process_zoo_classifications(self, last_id=0): # The default last_id is 0 meaning process all classifications the project has @@ -1549,6 +1545,11 @@ def _process_zoo_classifications(self, last_id=0): self.subject_to_annotations[sid] = [label] def _gather_preference(self, query_id: str) -> float: + + # Authenticate with Zooniverse + panoptes_username = os.environ["PANOPTES_USERNAME"] + panoptes_password = os.environ["PANOPTES_PASSWORD"] + Panoptes.connect(username=panoptes_username, password=panoptes_password) # Without last_id trakcing this must be called each time to ensure latest # classifications are included. This could become time consuming if many From 1bca7bf8cbf8b7588d16de8dfc27a30537706db2 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Fri, 8 Mar 2024 14:27:20 -0600 Subject: [PATCH 109/143] Do not move retired subjects to a separate subject set. This avoids a Zoo error when no subejcts are linked. --- src/imitation/algorithms/preference_comparisons.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index fe662648d..e80bade36 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -900,7 +900,6 @@ def __init__( zoo_project_id: int, zoo_workflow_id: int, linked_subject_set_id: int, - retired_subject_set_id: int, experiment_id: int, video_output_dir: AnyPath, video_fps: str = 20, @@ -911,7 +910,6 @@ def __init__( self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id - self.retired_subject_set_id = retired_subject_set_id self.experiment_id = experiment_id self.video_fps = video_fps @@ -974,8 +972,6 @@ def _query(self, query_id): subject.metadata["#zoo_project_id"] = self.zoo_project_id subject.metadata["#zoo_workflow_id"] = self.zoo_workflow_id subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id - subject.metadata["#retired_subject_set_id"] = self.retired_subject_set_id - subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id subject.metadata["#experiment_id"] = self.experiment_id subject.save() @@ -1469,7 +1465,6 @@ def __init__( zoo_project_id: int, zoo_workflow_id: int, linked_subject_set_id: int, - retired_subject_set_id: int, wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, @@ -1493,7 +1488,6 @@ def __init__( self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id - self.retired_subject_set_id = retired_subject_set_id # Define annotation to label map self.annotation_to_label = { @@ -1556,9 +1550,8 @@ def _gather_preference(self, query_id: str) -> float: # classifications have been submitted to the project. self._process_zoo_classifications(last_id=0) - # Find subject sets + # Find linked subject set linked_subject_set = SubjectSet.find(self.linked_subject_set_id) - retired_subject_set = SubjectSet.find(self.retired_subject_set_id) # Get subject_id corresponding to query_id subject_id = self.query_to_subject[query_id] @@ -1569,9 +1562,6 @@ def _gather_preference(self, query_id: str) -> float: # Remove this subject from the subject set linked to the workflow linked_subject_set.remove([subject_id]) - # Add subject to the subject set for completed subjects - retired_subject_set.add([subject_id]) - return reduced_label def _reduce_annotations(self, annotations): From 04a07752b3be341ef409c5eb5a2db96f5b87551a Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Fri, 8 Mar 2024 14:50:27 -0600 Subject: [PATCH 110/143] Add None annotation to label map. --- src/imitation/algorithms/preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e80bade36..57eac66ac 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1494,7 +1494,8 @@ def __init__( 0: 1, 1: 0, 2: .5, - 3: -1 + 3: -1, + None: None } self.query_to_subject = None From 28d5f09093f5fb596c55ef2ead4a1bf294384626 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Sat, 9 Mar 2024 20:43:42 -0600 Subject: [PATCH 111/143] Make last_id trackable. --- .../algorithms/preference_comparisons.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 57eac66ac..be7d21ea6 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1498,18 +1498,14 @@ def __init__( None: None } - self.query_to_subject = None - - def _process_zoo_classifications(self, last_id=0): + self.query_to_subject = {} + self.subject_to_annotations = {} - # The default last_id is 0 meaning process all classifications the project has - # recieved for the specified workflow. - # TODO: make last_id trackable to avoid processing all classifications each time - # the gatherer is called. + def _process_zoo_classifications(self): # Access classifications from the last_id classifications = Classification.where( - last_id=last_id, + last_id=self.last_id, scope='project', project_id=self.zoo_project_id, workflow_id=self.zoo_workflow_id @@ -1518,8 +1514,6 @@ def _process_zoo_classifications(self, last_id=0): # Find workflow self.workflow = Workflow.find(self.zoo_workflow_id) - self.query_to_subject = {} - self.subject_to_annotations = {} for c in classifications: d = c.raw # Extract subject id @@ -1538,6 +1532,7 @@ def _process_zoo_classifications(self, last_id=0): self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid # Create map entry for this subject self.subject_to_annotations[sid] = [label] + self.last_id = d['id'] def _gather_preference(self, query_id: str) -> float: @@ -1549,7 +1544,7 @@ def _gather_preference(self, query_id: str) -> float: # Without last_id trakcing this must be called each time to ensure latest # classifications are included. This could become time consuming if many # classifications have been submitted to the project. - self._process_zoo_classifications(last_id=0) + self._process_zoo_classifications(last_id=self.last_id) # Find linked subject set linked_subject_set = SubjectSet.find(self.linked_subject_set_id) @@ -1560,12 +1555,10 @@ def _gather_preference(self, query_id: str) -> float: # Get reduced_label for subject_id aggregated from each annotation for that subject reduced_label = self._reduce_annotations(self.subject_to_annotations[subject_id]) - # Remove this subject from the subject set linked to the workflow - linked_subject_set.remove([subject_id]) - return reduced_label def _reduce_annotations(self, annotations): + # Aggregate Zooniverse classifications count = Counter(annotations) return count.most_common(1)[0][0] From 50af8bc1af4bc5cc255cefa792dfaa81f6338912 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Sat, 9 Mar 2024 20:46:03 -0600 Subject: [PATCH 112/143] Make last_id trackable. --- src/imitation/algorithms/preference_comparisons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index be7d21ea6..36a4dcdfc 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1498,6 +1498,7 @@ def __init__( None: None } + self.last_id = 0 self.query_to_subject = {} self.subject_to_annotations = {} From 4640623f24e73a7dc178308572dd5887e77b1454 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Sat, 9 Mar 2024 20:48:29 -0600 Subject: [PATCH 113/143] Make last_id trackable. --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 36a4dcdfc..a98a35d07 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1545,7 +1545,7 @@ def _gather_preference(self, query_id: str) -> float: # Without last_id trakcing this must be called each time to ensure latest # classifications are included. This could become time consuming if many # classifications have been submitted to the project. - self._process_zoo_classifications(last_id=self.last_id) + self._process_zoo_classifications(last_id) # Find linked subject set linked_subject_set = SubjectSet.find(self.linked_subject_set_id) From 0261e99a8e26317982b15c53f4e766f56417604f Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Sat, 9 Mar 2024 20:52:54 -0600 Subject: [PATCH 114/143] Fix last_id NameError --- src/imitation/algorithms/preference_comparisons.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index a98a35d07..3075fca81 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1542,10 +1542,7 @@ def _gather_preference(self, query_id: str) -> float: panoptes_password = os.environ["PANOPTES_PASSWORD"] Panoptes.connect(username=panoptes_username, password=panoptes_password) - # Without last_id trakcing this must be called each time to ensure latest - # classifications are included. This could become time consuming if many - # classifications have been submitted to the project. - self._process_zoo_classifications(last_id) + self._process_zoo_classifications() # Find linked subject set linked_subject_set = SubjectSet.find(self.linked_subject_set_id) From f39849f25545c398f4088d0092ff28b3cbe3f67d Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Sat, 9 Mar 2024 21:03:25 -0600 Subject: [PATCH 115/143] Fix UnboundLocalError --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 3075fca81..7a98c376f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1533,7 +1533,7 @@ def _process_zoo_classifications(self): self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid # Create map entry for this subject self.subject_to_annotations[sid] = [label] - self.last_id = d['id'] + self.last_id = d['id'] def _gather_preference(self, query_id: str) -> float: From 67b1dd87cbf1cf7a55cc0d3b3bc049e4d5756543 Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 21 Mar 2024 13:54:55 -0500 Subject: [PATCH 116/143] Handle deletion of subjects on panoptes FE --- src/imitation/algorithms/preference_comparisons.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7a98c376f..472244ff8 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1489,6 +1489,9 @@ def __init__( self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id + # Find workflow + self.workflow = Workflow.find(self.zoo_workflow_id) + # Define annotation to label map self.annotation_to_label = { 0: 1, @@ -1512,17 +1515,18 @@ def _process_zoo_classifications(self): workflow_id=self.zoo_workflow_id ) - # Find workflow - self.workflow = Workflow.find(self.zoo_workflow_id) + # get linked subjects and their statuses + statuses = self.workflow.subject_workflow_statuses(self.linked_subject_set_id) + linked_subject_statuses = {s.raw['links']['subject']: s.raw['retirement_reason'] for s in statuses} for c in classifications: d = c.raw # Extract subject id sid = int(d["links"]["subjects"][0]) # Get subject status - status = self.workflow.subject_workflow_status(sid) - # Check that subject is retired - if status.raw["retirement_reason"] is not None: + status = linked_subject_statuses[sid] + # Check that subject is linked to workflow and retired + if sid in set(linked_subjects) and status is not None: label = self.annotation_to_label[d["annotations"][0]["value"]] try: # Add label for this classification for the subject From 55c0e525a15831dab1fb87898ebd8abd75a05e5c Mon Sep 17 00:00:00 2001 From: Darryl Wright Date: Thu, 21 Mar 2024 14:04:13 -0500 Subject: [PATCH 117/143] Handle deletion of subjects on panoptes FE --- .../algorithms/preference_comparisons.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 472244ff8..b3d427554 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1523,20 +1523,22 @@ def _process_zoo_classifications(self): d = c.raw # Extract subject id sid = int(d["links"]["subjects"][0]) - # Get subject status - status = linked_subject_statuses[sid] - # Check that subject is linked to workflow and retired - if sid in set(linked_subjects) and status is not None: - label = self.annotation_to_label[d["annotations"][0]["value"]] - try: - # Add label for this classification for the subject - self.subject_to_annotations[sid].append(label) - except KeyError: - # Get query_id for this subject and add it to map - subject = Subject.find(sid) - self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid - # Create map entry for this subject - self.subject_to_annotations[sid] = [label] + # Check that the subject is linked to workflow + if sid in set(linked_subjects): + # Get subject status + status = linked_subject_statuses[sid] + # Check subject isretired + if status is not None: + label = self.annotation_to_label[d["annotations"][0]["value"]] + try: + # Add label for this classification for the subject + self.subject_to_annotations[sid].append(label) + except KeyError: + # Get query_id for this subject and add it to map + subject = Subject.find(sid) + self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid + # Create map entry for this subject + self.subject_to_annotations[sid] = [label] self.last_id = d['id'] def _gather_preference(self, query_id: str) -> float: From cd45613771dce0727aa46e55da1f449d15a06ad4 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 17:55:58 +0100 Subject: [PATCH 118/143] Extract video writing from PrefCollectGatherer to new parent class --- .../algorithms/preference_comparisons.py | 181 ++++++++++-------- .../algorithms/test_preference_comparisons.py | 10 +- 2 files changed, 108 insertions(+), 83 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e5c7b9c14..ad2d48f58 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -820,21 +820,17 @@ def __call__( return {str(uuid.uuid4()): query for query in queries} -class PrefCollectQuerent(PreferenceQuerent): - """Sends queries to a preference collection web service via HTTP requests.""" - +class VideoBasedQuerent(PreferenceQuerent): def __init__( self, - pref_collect_address: str, video_output_dir: str, video_fps: int = 20, rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): - """Initializes the PrefCollect querent. + """Initializes the querent. Args: - pref_collect_address: end point of the PrefCollect web service. video_output_dir: path to the video clip directory. video_fps: frames per second of the generated videos. rng: random number generator, if applicable. @@ -842,7 +838,6 @@ def __init__( """ super().__init__(custom_logger=custom_logger) self.rng = rng - self.query_endpoint = pref_collect_address + "/preferences/query/" self.video_output_dir = video_output_dir self.frames_per_second = video_fps @@ -854,87 +849,115 @@ def __call__( queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, TrajectoryWithRewPair]: identified_queries = super().__call__(queries) - - # Save fragment videos and submit queries for query_id, query in identified_queries.items(): - output_file_name = os.path.join( - self.video_output_dir, - f"{query_id}" + "-{}.webm", - ) - write_fragment_video( - query[0], - frames_per_second=self.frames_per_second, - output_path=output_file_name.format("left"), - ) - write_fragment_video( - query[1], - frames_per_second=self.frames_per_second, - output_path=output_file_name.format("right"), - ) - self._query(query_id) - + self._write_query_videos(query_id, query) return identified_queries - def _query(self, query_id): - requests.put( - self.query_endpoint + query_id, - json={"uuid": "{}".format(query_id)}, + def _write_query_videos(self, query_id, query): + output_file_name = os.path.join( + self.video_output_dir, + f"{query_id}" + "-{}.webm", + ) + self._write_fragment_video( + query[0], + output_path=output_file_name.format("left"), + ) + self._write_fragment_video( + query[1], + output_path=output_file_name.format("right"), ) + def _write_fragment_video( + self, + fragment: TrajectoryWithRew, + output_path: AnyPath, + progress_logger: bool = True, + ) -> None: + """Write fragment video clip.""" + frames_list: List[Union[os.PathLike, np.ndarray]] = [] + # Create fragment videos from environment's render images if available + if fragment.infos is not None and "rendered_img" in fragment.infos[0]: + for i in range(len(fragment.infos)): + frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] + if isinstance(frame, np.ndarray): + frame = self._add_missing_rgb_channels(frame) + frames_list.append(frame) + # Create fragment video from observations if possible + else: + if isinstance(fragment.obs, np.ndarray): + frames_list = [ + frame for frame in self._add_missing_rgb_channels(fragment.obs[1:]) + ] + else: + # TODO add support for DictObs + raise ValueError( + "Unsupported observation type " + f"for writing fragment video: {type(fragment.obs)}", + ) + # Note: `ImageSeqeuenceClip` handily accepts both + # lists of image paths or numpy arrays + clip = ImageSequenceClip(frames_list, fps=self.frames_per_second) + moviepy_logger = None if not progress_logger else "bar" + clip.write_videofile(output_path, logger=moviepy_logger) -def add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: - """Add missing RGB channels if needed. - If less than three channels are present, multiplies the last channel - until all three channels exist. + @staticmethod + def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: + """Add missing RGB channels if needed. + If less than three channels are present, multiplies the last channel + until all three channels exist. - Args: - frames: a stack of frames with potentially missing channels; - expected shape (batch, height, width, channels). + Args: + frames: a stack of frames with potentially missing channels; + expected shape (batch, height, width, channels). - Returns: - a stack of frames with exactly three channels. - """ - if frames.shape[-1] < 3: - missing_channels = 3 - frames.shape[-1] - frames = np.concatenate( - [frames] + missing_channels * [frames[..., -1][..., None]], - axis=-1, - ) - return frames - - -def write_fragment_video( - fragment: TrajectoryWithRew, - frames_per_second: int, - output_path: AnyPath, - progress_logger: bool = True, -) -> None: - """Write fragment video clip.""" - frames_list: List[Union[os.PathLike, np.ndarray]] = [] - # Create fragment videos from environment's render images if available - if fragment.infos is not None and "rendered_img" in fragment.infos[0]: - for i in range(len(fragment.infos)): - frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] - if isinstance(frame, np.ndarray): - frame = add_missing_rgb_channels(frame) - frames_list.append(frame) - # Create fragment video from observations if possible - else: - if isinstance(fragment.obs, np.ndarray): - frames_list = [ - frame for frame in add_missing_rgb_channels(fragment.obs[1:]) - ] - else: - # TODO add support for DictObs - raise ValueError( - "Unsupported observation type " - f"for writing fragment video: {type(fragment.obs)}", + Returns: + a stack of frames with exactly three channels. + """ + if frames.shape[-1] < 3: + missing_channels = 3 - frames.shape[-1] + frames = np.concatenate( + [frames] + missing_channels * [frames[..., -1][..., None]], + axis=-1, ) - # Note: `ImageSeqeuenceClip` handily accepts both - # lists of image paths or numpy arrays - clip = ImageSequenceClip(frames_list, fps=frames_per_second) - moviepy_logger = None if not progress_logger else "bar" - clip.write_videofile(output_path, logger=moviepy_logger) + return frames + + +class PrefCollectQuerent(VideoBasedQuerent): + """Sends queries to a REST web service.""" + + def __init__( + self, + pref_collect_address: str, + video_output_dir: str, + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + ): + """Initializes the querent. + + Args: + pref_collect_address: end point of the PrefCollect web service. + video_output_dir: path to the video clip directory. + video_fps: frames per second of the generated videos. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + """ + super().__init__(video_output_dir, video_fps, rng, custom_logger) + self.query_endpoint = pref_collect_address + "/preferences/query/" + + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + identified_queries = super().__call__(queries) + for query_id in identified_queries.keys(): + self._query(query_id) + + def _query(self, query_id): + requests.put( + self.query_endpoint + query_id, + json={"uuid": "{}".format(query_id)}, + ) class PreferenceGatherer(abc.ABC): diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index e869231da..498a97794 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -29,9 +29,9 @@ PrefCollectQuerent, PreferenceGatherer, PreferenceQuerent, + VideoBasedQuerent, SyntheticGatherer, remove_rendered_images, - write_fragment_video, ) from imitation.data import types from imitation.data.types import TrajectoryWithRew, TrajectoryWithRewPair @@ -1230,12 +1230,14 @@ def fragment(request, empty_trajectory_with_rew): # utils @pytest.mark.parametrize("codec", ["webm", "mp4"]) def test_write_fragment_video(fragment, codec): - video_path = f"video.{codec}" + output_dir = "video" + video_based_querent = VideoBasedQuerent(video_output_dir=output_dir) + video_path = f"{output_dir}.{codec}" if isinstance(fragment.obs, types.DictObs): with pytest.raises(ValueError): - write_fragment_video(fragment, frames_per_second=5, output_path=video_path) + video_based_querent._write_fragment_video(fragment, output_path=video_path) else: - write_fragment_video(fragment, frames_per_second=5, output_path=video_path) + video_based_querent._write_fragment_video(fragment, output_path=video_path) assert os.path.isfile(video_path) os.remove(video_path) From 389b7199d7a8eade60274db294cf042764a44c4e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 18:04:19 +0100 Subject: [PATCH 119/143] Extract handling of asynchronous preference collection to new parent class --- .../algorithms/preference_comparisons.py | 65 ++++++++++++------- .../algorithms/test_preference_comparisons.py | 12 ++-- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index ad2d48f58..7c3ce1d48 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1321,33 +1321,15 @@ def is_running_pytest_test(self) -> bool: return "PYTEST_CURRENT_TEST" in os.environ -class PrefCollectGatherer(PreferenceGatherer): - """Gathers preferences from PrefCollect interface.""" - +class AsynchronousHumanGatherer(PreferenceGatherer, abc.ABC): def __init__( - self, - pref_collect_address: str, - wait_for_user: bool = True, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None, + self, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, ) -> None: - """Initializes the preference gatherer. - - Args: - pref_collect_address: Network address to PrefCollect instance. - wait_for_user: Waits for user to input their preferences. - rng: random number generator, if applicable. - custom_logger: Where to log to; if None (default), creates a new logger. - querent_kwargs: Keyword arguments passed to the querent. - """ - super().__init__(rng, custom_logger) - querent_kwargs = querent_kwargs if querent_kwargs else {} - self.querent = PrefCollectQuerent( - pref_collect_address=pref_collect_address, - **querent_kwargs, - ) - self.query_endpoint = pref_collect_address + "/preferences/query/" + super().__init__(rng, custom_logger, querent_kwargs) self.pending_queries = {} self.wait_for_user = wait_for_user @@ -1373,6 +1355,39 @@ def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: return gathered_queries, np.array(gathered_preferences, dtype=np.float32) + @abc.abstractmethod + def _gather_preference(self, query_id: str) -> float: + raise NotImplementedError + + +class PrefCollectGatherer(AsynchronousHumanGatherer): + """Gathers preferences from a REST interface.""" + + def __init__( + self, + collection_service_address: str, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, + ) -> None: + """Initializes the preference gatherer. + + Args: + collection_service_address: Network address of the collection service's REST interface. + wait_for_user: Waits for user to input their preferences. + rng: random number generator, if applicable. + custom_logger: Where to log to; if None (default), creates a new logger. + querent_kwargs: Keyword arguments passed to the querent. + """ + super().__init__(wait_for_user, rng, custom_logger) + querent_kwargs = querent_kwargs if querent_kwargs else {} + self.querent = PrefCollectQuerent( + pref_collect_address=collection_service_address, + **querent_kwargs, + ) + self.query_endpoint = collection_service_address + "/preferences/query/" + def _gather_preference(self, query_id: str) -> float: answered_query = requests.get(self.query_endpoint + query_id).json() return answered_query["label"] diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 498a97794..9982be273 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1283,7 +1283,7 @@ def test_returns_none_for_unanswered_query(requests_mock): answer = None gatherer = PrefCollectGatherer( - pref_collect_address=address, + collection_service_address=address, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1303,7 +1303,7 @@ def test_returns_preference_for_answered_query(requests_mock): answer = 1.0 gatherer = PrefCollectGatherer( - pref_collect_address=address, + collection_service_address=address, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1319,7 +1319,7 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", + collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1334,7 +1334,7 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", + collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1349,7 +1349,7 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", + collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1366,7 +1366,7 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): gatherer = PrefCollectGatherer( - pref_collect_address="https://test.de", + collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, ) From 9e57c81df1b67c415a9197959306eed43e8bf87d Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 18:05:06 +0100 Subject: [PATCH 120/143] Adapt SynchronousHumanGatherer to use new VideoBasedQuerent --- .../algorithms/preference_comparisons.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7c3ce1d48..d50122a8e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1132,11 +1132,10 @@ def __init__( rng: random number generator """ super().__init__(custom_logger=custom_logger, rng=rng) + self.querent = VideoBasedQuerent(video_output_dir=video_dir, video_fps=frames_per_second) self.video_dir = video_dir - os.makedirs(video_dir, exist_ok=True) self.video_width = video_width self.video_height = video_height - self.frames_per_second = frames_per_second def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: """Displays each pair of fragments and asks for a preference. @@ -1150,22 +1149,12 @@ def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape (b, ), where b is the length of `fragment_pairs` """ + queries = [] preferences = np.zeros(len(self.pending_queries), dtype=np.float32) for i, (query_id, query) in enumerate(self.pending_queries.items()): - write_fragment_video( - query[0], - frames_per_second=self.frames_per_second, - output_path=os.path.join(self.video_dir, f"{query_id}-left.webm"), - ) - write_fragment_video( - query[1], - frames_per_second=self.frames_per_second, - output_path=os.path.join(self.video_dir, f"{query_id}-right.webm"), - ) if self._display_videos_and_gather_preference(query_id): + queries.append(query) preferences[i] = 1 - - queries = list(self.pending_queries.values()) self.pending_queries.clear() return queries, preferences From b7e53827bbd3e3bbb447afe6ae1c4cb182eb1e02 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 18:07:31 +0100 Subject: [PATCH 121/143] Rename pref collect gatherer/querent to REST gatherer/querent --- .../algorithms/preference_comparisons.py | 14 ++++++------- .../config/train_preference_comparisons.py | 2 +- .../algorithms/test_preference_comparisons.py | 20 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index d50122a8e..326223cdd 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -922,12 +922,12 @@ def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: return frames -class PrefCollectQuerent(VideoBasedQuerent): +class RESTQuerent(VideoBasedQuerent): """Sends queries to a REST web service.""" def __init__( self, - pref_collect_address: str, + collection_service_address: str, video_output_dir: str, video_fps: int = 20, rng: Optional[np.random.Generator] = None, @@ -936,14 +936,14 @@ def __init__( """Initializes the querent. Args: - pref_collect_address: end point of the PrefCollect web service. + collection_service_address: Network address of the collection service's REST interface. video_output_dir: path to the video clip directory. video_fps: frames per second of the generated videos. rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(video_output_dir, video_fps, rng, custom_logger) - self.query_endpoint = pref_collect_address + "/preferences/query/" + self.query_endpoint = collection_service_address + "/preferences/query/" def __call__( self, @@ -1349,7 +1349,7 @@ def _gather_preference(self, query_id: str) -> float: raise NotImplementedError -class PrefCollectGatherer(AsynchronousHumanGatherer): +class RESTGatherer(AsynchronousHumanGatherer): """Gathers preferences from a REST interface.""" def __init__( @@ -1371,8 +1371,8 @@ def __init__( """ super().__init__(wait_for_user, rng, custom_logger) querent_kwargs = querent_kwargs if querent_kwargs else {} - self.querent = PrefCollectQuerent( - pref_collect_address=collection_service_address, + self.querent = RESTQuerent( + collection_service_address=collection_service_address, **querent_kwargs, ) self.query_endpoint = collection_service_address + "/preferences/query/" diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index bd96c329b..7807fe0c8 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -92,7 +92,7 @@ def synch_human_preferences(): @train_preference_comparisons_ex.named_config def human_preferences(): - gatherer_cls = preference_comparisons.PrefCollectGatherer + gatherer_cls = preference_comparisons.RESTGatherer gatherer_kwargs = dict( pref_collect_address="http://127.0.0.1:8000", wait_for_user=True, diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 9982be273..aa7bc44bf 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -25,8 +25,8 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons from imitation.algorithms.preference_comparisons import ( - PrefCollectGatherer, - PrefCollectQuerent, + RESTGatherer, + RESTQuerent, PreferenceGatherer, PreferenceQuerent, VideoBasedQuerent, @@ -1158,7 +1158,7 @@ def test_returned_queries_have_uuid(): # PrefCollectQuerent def test_sends_put_request_for_each_query(requests_mock): address = "https://test.de" - querent = PrefCollectQuerent(pref_collect_address=address, video_output_dir="video") + querent = RESTQuerent(collection_service_address=address, video_output_dir="video") query_id = "1234" requests_mock.put(f"{address}/preferences/query/{query_id}") @@ -1184,7 +1184,7 @@ def empty_trajectory_with_rew(): def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew): address = "https://test.de" queries = [(empty_trajectory_with_rew, empty_trajectory_with_rew)] - querent = PrefCollectQuerent(pref_collect_address=address, video_output_dir="video") + querent = RESTQuerent(collection_service_address=address, video_output_dir="video") identified_queries = querent(queries) for query_id, _ in identified_queries.items(): file = os.path.join(querent.video_output_dir, query_id + "-{}.webm") @@ -1282,7 +1282,7 @@ def test_returns_none_for_unanswered_query(requests_mock): query_id = "1234" answer = None - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address=address, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1302,7 +1302,7 @@ def test_returns_preference_for_answered_query(requests_mock): query_id = "1234" answer = 1.0 - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address=address, querent_kwargs={"video_output_dir": "videos"}, ) @@ -1318,7 +1318,7 @@ def test_returns_preference_for_answered_query(requests_mock): def test_keeps_pending_query_for_unanswered_query(): - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, @@ -1333,7 +1333,7 @@ def test_keeps_pending_query_for_unanswered_query(): def test_deletes_pending_query_for_answered_query(): - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, @@ -1348,7 +1348,7 @@ def test_deletes_pending_query_for_answered_query(): def test_gathers_valid_preference(): - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, @@ -1365,7 +1365,7 @@ def test_gathers_valid_preference(): def test_ignores_incomparable_answer(): - gatherer = PrefCollectGatherer( + gatherer = RESTGatherer( collection_service_address="https://test.de", wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, From a951dc5152dd985e25cff681f57b11ac3cea200d Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 21:54:49 +0100 Subject: [PATCH 122/143] Adjust ZooniverseGatherer/Querent to new base classes --- .../algorithms/preference_comparisons.py | 97 ++++++++----------- 1 file changed, 43 insertions(+), 54 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index f10ef995b..391f9384f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -831,11 +831,12 @@ def __call__( class VideoBasedQuerent(PreferenceQuerent): def __init__( - self, - video_output_dir: str, - video_fps: int = 20, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + self, + video_output_dir: str, + video_type="webm", + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initializes the querent. @@ -848,6 +849,7 @@ def __init__( super().__init__(custom_logger=custom_logger) self.rng = rng self.video_output_dir = video_output_dir + self.video_type = video_type self.frames_per_second = video_fps # Create video directory @@ -865,7 +867,7 @@ def __call__( def _write_query_videos(self, query_id, query): output_file_name = os.path.join( self.video_output_dir, - f"{query_id}" + "-{}.webm", + f"{query_id}" + "-{}" + f".{self.video_type}", ) self._write_fragment_video( query[0], @@ -938,49 +940,35 @@ class ZooniverseQuerent(VideoBasedQuerent): """Sends queries to the Zooniverse interface.""" def __init__( - self, - pref_collect_address: str, - zoo_project_id: int, - zoo_workflow_id: int, - linked_subject_set_id: int, - experiment_id: int, - video_output_dir: AnyPath, - video_fps: str = 20, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + self, + zoo_project_id: int, + zoo_workflow_id: int, + linked_subject_set_id: int, + experiment_id: int, + video_output_dir: AnyPath, + video_fps: str = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): - super().__init__(pref_collect_address, video_output_dir, video_fps, rng, custom_logger) + super().__init__( + video_output_dir=video_output_dir, + video_type="gif", + video_fps=video_fps, + rng=rng, + custom_logger=custom_logger + ) self.zoo_project_id = zoo_project_id self.zoo_workflow_id = zoo_workflow_id self.linked_subject_set_id = linked_subject_set_id self.experiment_id = experiment_id - self.video_fps = video_fps def __call__( self, queries: Sequence[TrajectoryWithRewPair], ) -> Dict[str, TrajectoryWithRewPair]: - # Call PreferenceQuerent not PrefCollectQuerent identified_queries = super().__call__(queries) - - # Save fragment videos and submit queries - for query_id, query in identified_queries.items(): - output_file_name = os.path.join( - self.video_output_dir, - f"{query_id}" + "-{}.gif", - ) - write_fragment_video( - query[0], - frames_per_second=self.frames_per_second, - output_path=output_file_name.format("left"), - ) - write_fragment_video( - query[1], - frames_per_second=self.frames_per_second, - output_path=output_file_name.format("right"), - ) + for query_id in identified_queries.keys(): self._query(query_id) - return identified_queries def _query(self, query_id): @@ -1011,7 +999,7 @@ def _query(self, query_id): subject.metadata["query_id"] = f"{query_id}" subject.metadata["#left_video"] = output_file_name.format("left") subject.metadata["#right_video"] = output_file_name.format("right") - subject.metadata["#video_fps"] = self.video_fps + subject.metadata["#video_fps"] = self.frames_per_second subject.metadata["#zoo_project_id"] = self.zoo_project_id subject.metadata["#zoo_workflow_id"] = self.zoo_workflow_id subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id @@ -1043,7 +1031,7 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(video_output_dir, video_fps, rng, custom_logger) + super().__init__(video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger) self.query_endpoint = collection_service_address + "/preferences/query/" def __call__( @@ -1053,6 +1041,7 @@ def __call__( identified_queries = super().__call__(queries) for query_id in identified_queries.keys(): self._query(query_id) + return identified_queries def _query(self, query_id): requests.put( @@ -1486,17 +1475,10 @@ def _gather_preference(self, query_id: str) -> float: class ZooniverseGatherer(AsynchronousHumanGatherer): """Gathers preferences from Zooniverse interface.""" - def __init__( - self, - pref_collect_address: str, - zoo_project_id: int, - zoo_workflow_id: int, - linked_subject_set_id: int, - wait_for_user: bool = True, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None - ) -> None: + def __init__(self, zoo_project_id: int, zoo_workflow_id: int, linked_subject_set_id: int, + wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None) -> None: """Initializes the preference gatherer. Args: @@ -1505,11 +1487,18 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - video_output_dir = querent_kwargs["video_output_dir"] - super().__init__(pref_collect_address, querent_kwargs={"video_output_dir": querent_kwargs["video_output_dir"]}) + super().__init__( + wait_for_user=wait_for_user, + rng=rng, + custom_logger=custom_logger, + ) self.querent = ZooniverseQuerent( - pref_collect_address, - **querent_kwargs + zoo_project_id=zoo_project_id, + zoo_workflow_id=zoo_workflow_id, + linked_subject_set_id=linked_subject_set_id, + rng=rng, + custom_logger=custom_logger, + **querent_kwargs, ) self.zoo_project_id = zoo_project_id From fb947069c9b99b18b303071f2645069406824937 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 21:56:57 +0100 Subject: [PATCH 123/143] Remove Zooniverse classes --- .../algorithms/preference_comparisons.py | 186 ------------------ 1 file changed, 186 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 391f9384f..8384de422 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -936,81 +936,6 @@ def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: return frames -class ZooniverseQuerent(VideoBasedQuerent): - """Sends queries to the Zooniverse interface.""" - - def __init__( - self, - zoo_project_id: int, - zoo_workflow_id: int, - linked_subject_set_id: int, - experiment_id: int, - video_output_dir: AnyPath, - video_fps: str = 20, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - ): - super().__init__( - video_output_dir=video_output_dir, - video_type="gif", - video_fps=video_fps, - rng=rng, - custom_logger=custom_logger - ) - self.zoo_project_id = zoo_project_id - self.zoo_workflow_id = zoo_workflow_id - self.linked_subject_set_id = linked_subject_set_id - self.experiment_id = experiment_id - - def __call__( - self, - queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, TrajectoryWithRewPair]: - identified_queries = super().__call__(queries) - for query_id in identified_queries.keys(): - self._query(query_id) - return identified_queries - - def _query(self, query_id): - - # Authenticate with Zooniverse - panoptes_username = os.environ["PANOPTES_USERNAME"] - panoptes_password = os.environ["PANOPTES_PASSWORD"] - Panoptes.connect(username=panoptes_username, password=panoptes_password) - - # Find project and workflow - project = Project.find(self.zoo_project_id) - workflow = Workflow.find(self.zoo_workflow_id) - - # Find subject sets - linked_subject_set = SubjectSet.find(self.linked_subject_set_id) - - # Create subject for this query_id - subject = Subject() - subject.links.project = project - - output_file_name = os.path.join( - self.video_output_dir, f"{query_id}" + "-{}.gif" - ) - - subject.add_location(open(output_file_name.format("left"), "rb")) - subject.add_location(open(output_file_name.format("right"), "rb")) - - subject.metadata["query_id"] = f"{query_id}" - subject.metadata["#left_video"] = output_file_name.format("left") - subject.metadata["#right_video"] = output_file_name.format("right") - subject.metadata["#video_fps"] = self.frames_per_second - subject.metadata["#zoo_project_id"] = self.zoo_project_id - subject.metadata["#zoo_workflow_id"] = self.zoo_workflow_id - subject.metadata["#linked_subject_set_id"] = self.linked_subject_set_id - subject.metadata["#experiment_id"] = self.experiment_id - - subject.save() - - # Add the subject to the linked subject set - linked_subject_set.add(subject) - - class RESTQuerent(VideoBasedQuerent): """Sends queries to a REST web service.""" @@ -1472,117 +1397,6 @@ def _gather_preference(self, query_id: str) -> float: return answered_query["label"] -class ZooniverseGatherer(AsynchronousHumanGatherer): - """Gathers preferences from Zooniverse interface.""" - - def __init__(self, zoo_project_id: int, zoo_workflow_id: int, linked_subject_set_id: int, - wait_for_user: bool = True, rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None) -> None: - """Initializes the preference gatherer. - - Args: - pref_collect_address: Network address to PrefCollect instance. - wait_for_user: Waits for user to input their preferences. - rng: random number generator, if applicable. - custom_logger: Where to log to; if None (default), creates a new logger. - """ - super().__init__( - wait_for_user=wait_for_user, - rng=rng, - custom_logger=custom_logger, - ) - self.querent = ZooniverseQuerent( - zoo_project_id=zoo_project_id, - zoo_workflow_id=zoo_workflow_id, - linked_subject_set_id=linked_subject_set_id, - rng=rng, - custom_logger=custom_logger, - **querent_kwargs, - ) - - self.zoo_project_id = zoo_project_id - self.zoo_workflow_id = zoo_workflow_id - self.linked_subject_set_id = linked_subject_set_id - - # Find workflow - self.workflow = Workflow.find(self.zoo_workflow_id) - - # Define annotation to label map - self.annotation_to_label = { - 0: 1, - 1: 0, - 2: .5, - 3: -1, - None: None - } - - self.last_id = 0 - self.query_to_subject = {} - self.subject_to_annotations = {} - - def _process_zoo_classifications(self): - - # Access classifications from the last_id - classifications = Classification.where( - last_id=self.last_id, - scope='project', - project_id=self.zoo_project_id, - workflow_id=self.zoo_workflow_id - ) - - # get linked subjects and their statuses - statuses = self.workflow.subject_workflow_statuses(self.linked_subject_set_id) - linked_subject_statuses = {s.raw['links']['subject']: s.raw['retirement_reason'] for s in statuses} - - for c in classifications: - d = c.raw - # Extract subject id - sid = int(d["links"]["subjects"][0]) - # Check that the subject is linked to workflow - if sid in set(linked_subjects): - # Get subject status - status = linked_subject_statuses[sid] - # Check subject isretired - if status is not None: - label = self.annotation_to_label[d["annotations"][0]["value"]] - try: - # Add label for this classification for the subject - self.subject_to_annotations[sid].append(label) - except KeyError: - # Get query_id for this subject and add it to map - subject = Subject.find(sid) - self.query_to_subject[subject.raw["metadata"]["query_id"]] = sid - # Create map entry for this subject - self.subject_to_annotations[sid] = [label] - self.last_id = d['id'] - - def _gather_preference(self, query_id: str) -> float: - - # Authenticate with Zooniverse - panoptes_username = os.environ["PANOPTES_USERNAME"] - panoptes_password = os.environ["PANOPTES_PASSWORD"] - Panoptes.connect(username=panoptes_username, password=panoptes_password) - - self._process_zoo_classifications() - - # Find linked subject set - linked_subject_set = SubjectSet.find(self.linked_subject_set_id) - - # Get subject_id corresponding to query_id - subject_id = self.query_to_subject[query_id] - - # Get reduced_label for subject_id aggregated from each annotation for that subject - reduced_label = self._reduce_annotations(self.subject_to_annotations[subject_id]) - - return reduced_label - - def _reduce_annotations(self, annotations): - # Aggregate Zooniverse classifications - count = Counter(annotations) - return count.most_common(1)[0][0] - - def remove_rendered_images(trajectories: Sequence[TrajectoryWithRew]) -> None: """Removes rendered images of the provided trajectories list.""" for traj in trajectories: From ee40083ef47216ce4b663760b0d2c7dea3fd3dbd Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 21:59:54 +0100 Subject: [PATCH 124/143] Add empty _query method to VideoBasedQuerent --- src/imitation/algorithms/preference_comparisons.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 8384de422..7ffbfd827 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -862,6 +862,7 @@ def __call__( identified_queries = super().__call__(queries) for query_id, query in identified_queries.items(): self._write_query_videos(query_id, query) + self._query(query_id) return identified_queries def _write_query_videos(self, query_id, query): @@ -935,6 +936,9 @@ def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: ) return frames + def _query(self, query_id): + pass + class RESTQuerent(VideoBasedQuerent): """Sends queries to a REST web service.""" @@ -959,15 +963,6 @@ def __init__( super().__init__(video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger) self.query_endpoint = collection_service_address + "/preferences/query/" - def __call__( - self, - queries: Sequence[TrajectoryWithRewPair], - ) -> Dict[str, TrajectoryWithRewPair]: - identified_queries = super().__call__(queries) - for query_id in identified_queries.keys(): - self._query(query_id) - return identified_queries - def _query(self, query_id): requests.put( self.query_endpoint + query_id, From cb1be271875139547bc6c03ca734241242f31955 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 22:54:25 +0100 Subject: [PATCH 125/143] Split video writing into smaller methods --- .../algorithms/preference_comparisons.py | 84 +++++++++++-------- 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7ffbfd827..3d43304b0 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -870,14 +870,11 @@ def _write_query_videos(self, query_id, query): self.video_output_dir, f"{query_id}" + "-{}" + f".{self.video_type}", ) - self._write_fragment_video( - query[0], - output_path=output_file_name.format("left"), - ) - self._write_fragment_video( - query[1], - output_path=output_file_name.format("right"), - ) + for i, alternative in enumerate(("left", "right")): + self._write_fragment_video( + fragment=query[i], + output_path=output_file_name.format(alternative), + ) def _write_fragment_video( self, @@ -886,37 +883,42 @@ def _write_fragment_video( progress_logger: bool = True, ) -> None: """Write fragment video clip.""" - frames_list: List[Union[os.PathLike, np.ndarray]] = [] - # Create fragment videos from environment's render images if available - if fragment.infos is not None and "rendered_img" in fragment.infos[0]: - for i in range(len(fragment.infos)): - frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] - if isinstance(frame, np.ndarray): - frame = self._add_missing_rgb_channels(frame) - frames_list.append(frame) - # Create fragment video from observations if possible - else: - if isinstance(fragment.obs, np.ndarray): - frames_list = [ - frame for frame in self._add_missing_rgb_channels(fragment.obs[1:]) - ] - else: - # TODO add support for DictObs - raise ValueError( - "Unsupported observation type " - f"for writing fragment video: {type(fragment.obs)}", - ) - # Note: `ImageSeqeuenceClip` handily accepts both - # lists of image paths or numpy arrays - clip = ImageSequenceClip(frames_list, fps=self.frames_per_second) - moviepy_logger = None if not progress_logger else "bar" - if output_path.endswith('.gif'): - clip.write_gif(output_path, program='ffmpeg', logger=moviepy_logger) - else: - clip.write_videofile(output_path, logger=moviepy_logger) + frames = self._get_frames(fragment) + self._write(frames, output_path, progress_logger) + + def _get_frames(self, fragment): + if self._rendered_image_of_observation_is_available(fragment): + return self._get_frames_for_each_observation(fragment) + elif self._observation_type_allows_rendering(fragment.obs): + return self._render_frames_for_each_observation(fragment) + else: # TODO add support for DictObs + raise ValueError( + "Unsupported observation type " + f"for writing fragment video: {type(fragment.obs)}", + ) + + @staticmethod + def _rendered_image_of_observation_is_available(fragment): + return fragment.infos is not None and "rendered_img" in fragment.infos[0] + + def _get_frames_for_each_observation(self, fragment): + frames: List[Union[os.PathLike, np.ndarray]] = [] + for i in range(len(fragment.infos)): + frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] + if isinstance(frame, np.ndarray): + frame = self._maybe_add_missing_rgb_channels(frame) + frames.append(frame) + return frames + + @staticmethod + def _observation_type_allows_rendering(observation): + return isinstance(observation, np.ndarray) + + def _render_frames_for_each_observation(self, fragment): + return [frame for frame in self._maybe_add_missing_rgb_channels(fragment.obs[1:])] @staticmethod - def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: + def _maybe_add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: """Add missing RGB channels if needed. If less than three channels are present, multiplies the last channel until all three channels exist. @@ -936,7 +938,15 @@ def _add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: ) return frames + def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path, progress_logger): + clip = ImageSequenceClip(frames, fps=self.frames_per_second) # accepts list of image paths and numpy arrays + if output_path.endswith('.gif'): + clip.write_gif(output_path, program='ffmpeg', logger="bar" if progress_logger else None) + else: + clip.write_videofile(output_path, logger="bar" if progress_logger else None) + def _query(self, query_id): + """Override this method in subclasses to specify query behavior.""" pass From 0646dc73915a356bb54e43990c5450e3425cc8d0 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Wed, 27 Mar 2024 23:09:23 +0100 Subject: [PATCH 126/143] Remove unused imports --- src/imitation/algorithms/preference_comparisons.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 3d43304b0..e24bc053e 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -38,7 +38,6 @@ from torch import nn from torch.utils import data as data_th from tqdm.auto import tqdm -from collections import Counter from imitation.algorithms import base from imitation.data import rollout, types, wrappers @@ -55,14 +54,6 @@ from imitation.util import logger as imit_logger from imitation.util import networks, util -from panoptes_client import ( - Panoptes, - Project, - Workflow, - Classification, - SubjectSet, - Subject -) class TrajectoryGenerator(abc.ABC): """Generator of trajectories with optional training logic.""" From 09ddc0ece014e4b4a4384e198e5cbe18a57badd4 Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 8 Apr 2024 20:59:22 +0200 Subject: [PATCH 127/143] Remove support for observations in video writing --- .../algorithms/preference_comparisons.py | 38 ++----------------- 1 file changed, 3 insertions(+), 35 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index e24bc053e..7116e27f5 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -880,12 +880,10 @@ def _write_fragment_video( def _get_frames(self, fragment): if self._rendered_image_of_observation_is_available(fragment): return self._get_frames_for_each_observation(fragment) - elif self._observation_type_allows_rendering(fragment.obs): - return self._render_frames_for_each_observation(fragment) - else: # TODO add support for DictObs + else: raise ValueError( - "Unsupported observation type " - f"for writing fragment video: {type(fragment.obs)}", + "No rendered images contained in info dict. " + "Please apply `RenderImageWrapper` to your environment.", ) @staticmethod @@ -896,39 +894,9 @@ def _get_frames_for_each_observation(self, fragment): frames: List[Union[os.PathLike, np.ndarray]] = [] for i in range(len(fragment.infos)): frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] - if isinstance(frame, np.ndarray): - frame = self._maybe_add_missing_rgb_channels(frame) frames.append(frame) return frames - @staticmethod - def _observation_type_allows_rendering(observation): - return isinstance(observation, np.ndarray) - - def _render_frames_for_each_observation(self, fragment): - return [frame for frame in self._maybe_add_missing_rgb_channels(fragment.obs[1:])] - - @staticmethod - def _maybe_add_missing_rgb_channels(frames: np.ndarray) -> np.ndarray: - """Add missing RGB channels if needed. - If less than three channels are present, multiplies the last channel - until all three channels exist. - - Args: - frames: a stack of frames with potentially missing channels; - expected shape (batch, height, width, channels). - - Returns: - a stack of frames with exactly three channels. - """ - if frames.shape[-1] < 3: - missing_channels = 3 - frames.shape[-1] - frames = np.concatenate( - [frames] + missing_channels * [frames[..., -1][..., None]], - axis=-1, - ) - return frames - def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path, progress_logger): clip = ImageSequenceClip(frames, fps=self.frames_per_second) # accepts list of image paths and numpy arrays if output_path.endswith('.gif'): From 5d9a160b858b5804d3dc80e11c6809cbdcb81818 Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 8 Apr 2024 21:45:54 +0200 Subject: [PATCH 128/143] Transfer video data via REST request and simplify VideoBasedQuerent --- .../algorithms/preference_comparisons.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7116e27f5..2bc29f431 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -853,19 +853,17 @@ def __call__( identified_queries = super().__call__(queries) for query_id, query in identified_queries.items(): self._write_query_videos(query_id, query) - self._query(query_id) return identified_queries def _write_query_videos(self, query_id, query): - output_file_name = os.path.join( - self.video_output_dir, - f"{query_id}" + "-{}" + f".{self.video_type}", - ) for i, alternative in enumerate(("left", "right")): self._write_fragment_video( fragment=query[i], - output_path=output_file_name.format(alternative), + output_path=self._create_query_video_path(query_id, alternative), ) + + def _create_query_video_path(self, query_id: str, alternative: str): + return pathlib.Path(self.video_output_dir) / f"{query_id}-{alternative}.{self.video_type}" def _write_fragment_video( self, @@ -904,10 +902,6 @@ def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path, prog else: clip.write_videofile(output_path, logger="bar" if progress_logger else None) - def _query(self, query_id): - """Override this method in subclasses to specify query behavior.""" - pass - class RESTQuerent(VideoBasedQuerent): """Sends queries to a REST web service.""" @@ -931,13 +925,33 @@ def __init__( """ super().__init__(video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger) self.query_endpoint = collection_service_address + "/preferences/query/" + + + def __call__( + self, + queries: Sequence[TrajectoryWithRewPair], + ) -> Dict[str, TrajectoryWithRewPair]: + identified_queries = super().__call__(queries) + for query_id, query in identified_queries.items(): + self._query(query_id) + return identified_queries def _query(self, query_id): + video_data = self._load_video_data() requests.put( self.query_endpoint + query_id, json={"uuid": "{}".format(query_id)}, + data=video_data, ) + def _load_video_data(self, query_id): + video_data = [] + for alternative in ("left", "right"): + video_path = self._create_query_video_path(query_id, alternative) + with open(video_path, 'rb') as video_file: + video_data.append(video_file.read()) + return video_data + class PreferenceGatherer(abc.ABC): """Base class for gathering preference comparisons between trajectory fragments.""" From a3175ac7fea3741eba0ebf0949a8fe3c6956f49c Mon Sep 17 00:00:00 2001 From: rk1a Date: Mon, 29 Apr 2024 00:05:06 +0200 Subject: [PATCH 129/143] Fix tests as far as possible --- .../algorithms/preference_comparisons.py | 34 +++++++++--------- .../config/train_preference_comparisons.py | 2 +- .../algorithms/test_preference_comparisons.py | 35 ++++++++----------- 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 2bc29f431..953515a33 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -875,7 +875,7 @@ def _write_fragment_video( frames = self._get_frames(fragment) self._write(frames, output_path, progress_logger) - def _get_frames(self, fragment): + def _get_frames(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: if self._rendered_image_of_observation_is_available(fragment): return self._get_frames_for_each_observation(fragment) else: @@ -885,22 +885,23 @@ def _get_frames(self, fragment): ) @staticmethod - def _rendered_image_of_observation_is_available(fragment): + def _rendered_image_of_observation_is_available(fragment: TrajectoryWithRew) -> bool: return fragment.infos is not None and "rendered_img" in fragment.infos[0] - def _get_frames_for_each_observation(self, fragment): - frames: List[Union[os.PathLike, np.ndarray]] = [] + def _get_frames_for_each_observation(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: + frames: list[Union[os.PathLike, np.ndarray]] = [] for i in range(len(fragment.infos)): frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] frames.append(frame) return frames - def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path, progress_logger): + def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path: os.PathLike, progress_logger: bool): clip = ImageSequenceClip(frames, fps=self.frames_per_second) # accepts list of image paths and numpy arrays - if output_path.endswith('.gif'): - clip.write_gif(output_path, program='ffmpeg', logger="bar" if progress_logger else None) + if output_path.suffix == '.gif': + clip.write_gif(str(output_path), program='ffmpeg', logger="bar" if progress_logger else None) else: - clip.write_videofile(output_path, logger="bar" if progress_logger else None) + print(output_path) + clip.write_videofile(str(output_path), logger="bar" if progress_logger else None) class RESTQuerent(VideoBasedQuerent): @@ -936,20 +937,21 @@ def __call__( self._query(query_id) return identified_queries - def _query(self, query_id): - video_data = self._load_video_data() + def _query(self, query_id: str): + video_data = self._load_video_data(query_id) requests.put( self.query_endpoint + query_id, - json={"uuid": "{}".format(query_id)}, - data=video_data, + json={"uuid": "{}".format(query_id), **video_data}, ) - def _load_video_data(self, query_id): - video_data = [] + def _load_video_data(self, query_id: str) -> dict[str, bytes]: + import base64 + video_data = {} for alternative in ("left", "right"): video_path = self._create_query_video_path(query_id, alternative) - with open(video_path, 'rb') as video_file: - video_data.append(video_file.read()) + if video_path.exists(): + with open(video_path, 'rb') as video_file: + video_data[alternative] = base64.b64encode(video_file.read()).decode('utf-8') return video_data diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 7807fe0c8..065dc6528 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -94,7 +94,7 @@ def synch_human_preferences(): def human_preferences(): gatherer_cls = preference_comparisons.RESTGatherer gatherer_kwargs = dict( - pref_collect_address="http://127.0.0.1:8000", + collection_service_address="http://127.0.0.1:8000", wait_for_user=True, querent_kwargs=dict( video_output_dir="../pref_collect/videofiles", diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index aa7bc44bf..de8f949d8 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1169,21 +1169,21 @@ def test_sends_put_request_for_each_query(requests_mock): @pytest.fixture -def empty_trajectory_with_rew(): +def empty_trajectory_with_rew_and_render_imgs(): num_frames = 10 frame_shape = (200, 200) return types.TrajectoryWithRew( obs=np.zeros((num_frames, *frame_shape, 3), np.uint8), acts=np.zeros((num_frames - 1,), np.uint8), - infos=np.array([{} for _ in range(num_frames - 1)]), + infos=np.array([{"rendered_img": np.zeros((*frame_shape, 3), np.uint8)} for _ in range(num_frames - 1)]), rews=np.zeros((num_frames - 1,)), terminal=True, ) -def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew): +def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew_and_render_imgs): address = "https://test.de" - queries = [(empty_trajectory_with_rew, empty_trajectory_with_rew)] + queries = [(empty_trajectory_with_rew_and_render_imgs, empty_trajectory_with_rew_and_render_imgs)] querent = RESTQuerent(collection_service_address=address, video_output_dir="video") identified_queries = querent(queries) for query_id, _ in identified_queries.items(): @@ -1194,34 +1194,29 @@ def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew): @pytest.fixture( - params=["obs_only", "dictobs", "with_render_images", "with_render_image_paths"], + params=["obs_only", "with_render_images", "with_render_image_paths"], ) -def fragment(request, empty_trajectory_with_rew): - obs = empty_trajectory_with_rew.obs - infos = empty_trajectory_with_rew.infos - if request.param == "dictobs": - obs = types.DictObs({"obs": empty_trajectory_with_rew.obs}) - elif request.param == "with_render_images": +def fragment(request, empty_trajectory_with_rew_and_render_imgs): + obs = empty_trajectory_with_rew_and_render_imgs.obs + infos = empty_trajectory_with_rew_and_render_imgs.infos + if request.param == "with_render_images": infos = np.array( - [{"rendered_img": frame} for frame in empty_trajectory_with_rew.obs[1:]], + [{"rendered_img": frame} for frame in empty_trajectory_with_rew_and_render_imgs.obs[1:]], ) elif request.param == "with_render_image_paths": tmp_dir = tempfile.mkdtemp() infos = [] for frame in obs[1:]: - unique_file_path = os.path.join( - tmp_dir, - str(uuid.uuid4()) + ".png", - ) + unique_file_path = str(pathlib.Path(tmp_dir) / (str(uuid.uuid4()) + ".png")) imageio.imwrite(unique_file_path, frame) infos.append({"rendered_img": unique_file_path}) infos = np.array(infos) yield types.TrajectoryWithRew( obs=obs, - acts=empty_trajectory_with_rew.acts, + acts=empty_trajectory_with_rew_and_render_imgs.acts, infos=infos, terminal=True, - rews=empty_trajectory_with_rew.rews, + rews=empty_trajectory_with_rew_and_render_imgs.rews, ) if request.param == "with_render_image_paths": shutil.rmtree(tmp_dir) @@ -1232,8 +1227,8 @@ def fragment(request, empty_trajectory_with_rew): def test_write_fragment_video(fragment, codec): output_dir = "video" video_based_querent = VideoBasedQuerent(video_output_dir=output_dir) - video_path = f"{output_dir}.{codec}" - if isinstance(fragment.obs, types.DictObs): + video_path = pathlib.Path(f"{output_dir}.{codec}") + if "rendered_img" not in fragment.infos[0]: with pytest.raises(ValueError): video_based_querent._write_fragment_video(fragment, output_path=video_path) else: From 7851d0f85b034ad75e21a8546a177d6f09a8edf4 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 7 May 2024 13:09:53 +0200 Subject: [PATCH 130/143] Refactor Gatherer classes --- .../algorithms/preference_comparisons.py | 167 ++++++------------ .../config/train_preference_comparisons.py | 2 +- .../algorithms/test_preference_comparisons.py | 2 +- 3 files changed, 57 insertions(+), 114 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 953515a33..d47a55a19 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -981,33 +981,44 @@ def __init__( self.logger = custom_logger or imit_logger.configure() self.pending_queries: Dict = {} - @abc.abstractmethod def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - """Gathers the probabilities that fragment 1 is preferred in `queries`. + """Gathers preference probabilities for completed queries. Returns: - TODO return value - A numpy array with shape (b, ), where b is the length of the input - (i.e. batch size). Each item in the array is the probability that - fragment 1 is preferred over fragment 2 for the corresponding - pair of fragments. + * A list of length b with queries for which preferences have been gathered. + + * A numpy array with shape (b, ). + Each item in the array is the probability that fragment 1 is preferred + over fragment 2 for the corresponding query in the list above. Note that for human feedback, these probabilities are simply 0 or 1 (or 0.5 in case of indifference), but synthetic models may yield other probabilities. """ # noqa: DAR202 - def query(self, queries: Sequence[TrajectoryWithRewPair]) -> None: - identified_queries = self.querent(queries) - self._add(identified_queries) + gathered_queries = [] + gathered_preferences = [] - def _add(self, new_queries: Dict[str, TrajectoryWithRewPair]) -> None: - """Adds queries to pending queries. + for query_id, query in list(self.pending_queries.items()): + preference = self._gather_preference(query_id) - Args: - new_queries: pairs of trajectory fragments - """ - self.pending_queries = {**self.pending_queries, **new_queries} + if preference is not None: + # Preference for this query has been provided + if 0 <= preference <= 1: + gathered_queries.append(query) + gathered_preferences.append(preference) + # else: fragments were incomparable + del self.pending_queries[query_id] + + return gathered_queries, np.array(gathered_preferences, dtype=np.float32) + + @abc.abstractmethod + def _gather_preference(self, query_id: str) -> float: + raise NotImplementedError + + def query(self, queries: Sequence[TrajectoryWithRewPair]) -> None: + identified_queries = self.querent(queries) + self.pending_queries = {**self.pending_queries, **identified_queries} class SyntheticGatherer(PreferenceGatherer): @@ -1057,54 +1068,42 @@ def __init__( if self.sample and self.rng is None: raise ValueError("If `sample` is True, then `rng` must be provided.") - def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - """Computes probability fragment 1 is preferred over fragment 2.""" - queries = list(self.pending_queries.values()) - # Clear pending queries because the oracle will have answered all - self.pending_queries.clear() + def _gather_preference(self, query_id): + query = self.pending_queries[query_id] - returns1, returns2 = self._reward_sums(queries) + return_1, return_2 = ( + np.array(rollout.discounted_sum(query[0].rews, self.discount_factor), dtype=np.float32), + np.array(rollout.discounted_sum(query[1].rews, self.discount_factor), dtype=np.float32) + ) if self.temperature == 0: - return queries, (np.sign(returns1 - returns2) + 1) / 2 + return (np.sign(return_1 - return_2) + 1) / 2 - returns1 /= self.temperature - returns2 /= self.temperature + return_1 /= self.temperature + return_2 /= self.temperature # clip the returns to avoid overflows in the softmax below - returns_diff = np.clip(returns2 - returns1, -self.threshold, self.threshold) + returns_diff = np.clip(return_2 - return_1, -self.threshold, self.threshold) # Instead of computing exp(rews1) / (exp(rews1) + exp(rews2)) directly, # we divide enumerator and denominator by exp(rews1) to prevent overflows: - choice_probs = 1 / (1 + np.exp(returns_diff)) + choice_probability = 1 / (1 + np.exp(returns_diff)) # Compute the mean binary entropy. This metric helps estimate # how good we can expect the performance of the learned reward # model to be at predicting preferences. entropy = -( - special.xlogy(choice_probs, choice_probs) - + special.xlogy(1 - choice_probs, 1 - choice_probs) + special.xlogy(choice_probability, choice_probability) + + special.xlogy(1 - choice_probability, 1 - choice_probability) ).mean() self.logger.record("entropy", entropy) if self.sample: assert self.rng is not None - return queries, self.rng.binomial(n=1, p=choice_probs).astype(np.float32) - - return queries, choice_probs + return self.rng.binomial(n=1, p=choice_probability).astype(np.float32) - def _reward_sums(self, fragment_pairs) -> Tuple[np.ndarray, np.ndarray]: - rews1, rews2 = zip( - *[ - ( - rollout.discounted_sum(f1.rews, self.discount_factor), - rollout.discounted_sum(f2.rews, self.discount_factor), - ) - for f1, f2 in fragment_pairs - ], - ) - return np.array(rews1, dtype=np.float32), np.array(rews2, dtype=np.float32) + return choice_probability -class SynchronousHumanGatherer(PreferenceGatherer): +class CommandLineGatherer(PreferenceGatherer): """Queries for human preferences using the command line or a notebook.""" def __init__( @@ -1132,29 +1131,11 @@ def __init__( self.video_width = video_width self.video_height = video_height - def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - """Displays each pair of fragments and asks for a preference. - - It iteratively requests user feedback for each pair of fragments. If in the - command line, it will pop out a video player for each fragment. If in a - notebook, it will display the videos. Either way, it will request 1 or 2 to - indicate which is preferred. - - Returns: - A numpy array of 1 if fragment 1 is preferred and 0 otherwise, with shape - (b, ), where b is the length of `fragment_pairs` - """ - queries = [] - preferences = np.zeros(len(self.pending_queries), dtype=np.float32) - for i, (query_id, query) in enumerate(self.pending_queries.items()): - if self._display_videos_and_gather_preference(query_id): - queries.append(query) - preferences[i] = 1 - self.pending_queries.clear() - return queries, preferences - - def _display_videos_and_gather_preference(self, query_id: uuid.UUID) -> bool: + def _gather_preference(self, query_id): """Displays the videos of the two fragments. + If in the command line, it will pop out a video player for each fragment. + If in a notebook, it will display the videos. + Either way, it will request 1 or 2 to indicate which is preferred. Args: query_id: the id of the fragment pair to be displayed. @@ -1179,11 +1160,11 @@ def _display_videos_and_gather_preference(self, query_id: uuid.UUID) -> bool: pref = input("Please enter 1 or 2 or q or r: ") if pref == "q": - raise KeyboardInterrupt + return None elif pref == "1": - return True + return 1. elif pref == "2": - return False + return 0. # should never be hit assert False @@ -1256,7 +1237,7 @@ def _display_in_windows( ret2, frame2 = cap2.read() elif key == "1" or key == "2": cv2.destroyAllWindows() - return key == "1" + return 1.0 if key == "1" else 0.0 cv2.destroyAllWindows() raise KeyboardInterrupt @@ -1297,54 +1278,16 @@ def _display_videos_in_notebook( def _in_ipython(self) -> bool: try: - return self.is_running_pytest_test() or get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa + return self._is_running_pytest_test() or get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore[name-defined] # noqa except NameError: return False - def is_running_pytest_test(self) -> bool: + @staticmethod + def _is_running_pytest_test() -> bool: return "PYTEST_CURRENT_TEST" in os.environ -class AsynchronousHumanGatherer(PreferenceGatherer, abc.ABC): - def __init__( - self, - wait_for_user: bool = True, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None, - ) -> None: - super().__init__(rng, custom_logger, querent_kwargs) - self.pending_queries = {} - self.wait_for_user = wait_for_user - - def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - # TODO: create user-independent (automated) waiting policy - if self.wait_for_user: - print("Waiting for user to provide preferences. Press enter to continue.") - input() - - gathered_queries = [] - gathered_preferences = [] - - for query_id, query in list(self.pending_queries.items()): - preference = self._gather_preference(query_id) - - if preference is not None: - # Preference for this query has been provided - if 0 <= preference <= 1: - gathered_queries.append(query) - gathered_preferences.append(preference) - # else: fragments were incomparable - del self.pending_queries[query_id] - - return gathered_queries, np.array(gathered_preferences, dtype=np.float32) - - @abc.abstractmethod - def _gather_preference(self, query_id: str) -> float: - raise NotImplementedError - - -class RESTGatherer(AsynchronousHumanGatherer): +class RESTGatherer(PreferenceGatherer): """Gathers preferences from a REST interface.""" def __init__( diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 065dc6528..308c58937 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -72,7 +72,7 @@ def train_defaults(): @train_preference_comparisons_ex.named_config def synch_human_preferences(): - gatherer_cls = preference_comparisons.SynchronousHumanGatherer + gatherer_cls = preference_comparisons.CommandLineGatherer gatherer_kwargs = dict(video_dir="videos") querent_cls = preference_comparisons.PreferenceQuerent querent_kwargs = dict() diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index de8f949d8..8558b3f37 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1380,7 +1380,7 @@ def test_ignores_incomparable_answer(): @patch("IPython.display.display") def test_synchronous_human_gatherer(mock_display, mock_input): del mock_display # unused - gatherer = preference_comparisons.SynchronousHumanGatherer( + gatherer = preference_comparisons.CommandLineGatherer( video_dir=pathlib.Path("."), ) From df491291257986c9109f4d53b5d1b33489809b22 Mon Sep 17 00:00:00 2001 From: rk1a Date: Thu, 16 May 2024 23:49:09 +0200 Subject: [PATCH 131/143] Add test for video loading method --- .../algorithms/preference_comparisons.py | 111 +++++++++++------- .../algorithms/test_preference_comparisons.py | 73 ++++++++++-- tests/data/test_wrappers.py | 10 +- 3 files changed, 140 insertions(+), 54 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 953515a33..f288b0386 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -5,6 +5,7 @@ """ import abc +import base64 import math import os import pathlib @@ -822,12 +823,12 @@ def __call__( class VideoBasedQuerent(PreferenceQuerent): def __init__( - self, - video_output_dir: str, - video_type="webm", - video_fps: int = 20, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + self, + video_output_dir: Union[str, os.PathLike], + video_type="webm", + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initializes the querent. @@ -861,21 +862,26 @@ def _write_query_videos(self, query_id, query): fragment=query[i], output_path=self._create_query_video_path(query_id, alternative), ) - + def _create_query_video_path(self, query_id: str, alternative: str): - return pathlib.Path(self.video_output_dir) / f"{query_id}-{alternative}.{self.video_type}" + return ( + pathlib.Path(self.video_output_dir) + / f"{query_id}-{alternative}.{self.video_type}" + ) def _write_fragment_video( - self, - fragment: TrajectoryWithRew, - output_path: AnyPath, - progress_logger: bool = True, + self, + fragment: TrajectoryWithRew, + output_path: AnyPath, + progress_logger: bool = True, ) -> None: """Write fragment video clip.""" frames = self._get_frames(fragment) self._write(frames, output_path, progress_logger) - def _get_frames(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: + def _get_frames( + self, fragment: TrajectoryWithRew, + ) -> list[Union[os.PathLike, np.ndarray]]: if self._rendered_image_of_observation_is_available(fragment): return self._get_frames_for_each_observation(fragment) else: @@ -885,23 +891,40 @@ def _get_frames(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np ) @staticmethod - def _rendered_image_of_observation_is_available(fragment: TrajectoryWithRew) -> bool: + def _rendered_image_of_observation_is_available( + fragment: TrajectoryWithRew, + ) -> bool: return fragment.infos is not None and "rendered_img" in fragment.infos[0] - def _get_frames_for_each_observation(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: + def _get_frames_for_each_observation( + self, fragment: TrajectoryWithRew, + ) -> list[Union[os.PathLike, np.ndarray]]: frames: list[Union[os.PathLike, np.ndarray]] = [] for i in range(len(fragment.infos)): frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] frames.append(frame) return frames - def _write(self, frames: List[Union[os.PathLike, np.ndarray]], output_path: os.PathLike, progress_logger: bool): - clip = ImageSequenceClip(frames, fps=self.frames_per_second) # accepts list of image paths and numpy arrays - if output_path.suffix == '.gif': - clip.write_gif(str(output_path), program='ffmpeg', logger="bar" if progress_logger else None) + def _write( + self, + frames: List[Union[os.PathLike, np.ndarray]], + output_path: Union[str, bytes, os.PathLike], + progress_logger: bool, + ): + clip = ImageSequenceClip( + frames, fps=self.frames_per_second, + ) # accepts list of image paths and numpy arrays + output_path = pathlib.Path(output_path) + if output_path.suffix == ".gif": + clip.write_gif( + str(output_path), + program="ffmpeg", + logger="bar" if progress_logger else None, + ) else: - print(output_path) - clip.write_videofile(str(output_path), logger="bar" if progress_logger else None) + clip.write_videofile( + str(output_path), logger="bar" if progress_logger else None, + ) class RESTQuerent(VideoBasedQuerent): @@ -924,9 +947,10 @@ def __init__( rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ - super().__init__(video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger) + super().__init__( + video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger, + ) self.query_endpoint = collection_service_address + "/preferences/query/" - def __call__( self, @@ -944,14 +968,19 @@ def _query(self, query_id: str): json={"uuid": "{}".format(query_id), **video_data}, ) - def _load_video_data(self, query_id: str) -> dict[str, bytes]: - import base64 + def _load_video_data(self, query_id: str) -> dict[str, str]: video_data = {} for alternative in ("left", "right"): video_path = self._create_query_video_path(query_id, alternative) if video_path.exists(): - with open(video_path, 'rb') as video_file: - video_data[alternative] = base64.b64encode(video_file.read()).decode('utf-8') + with open(video_path, "rb") as video_file: + video_data[alternative] = base64.b64encode( + video_file.read(), + ).decode("utf-8") + else: + raise RuntimeError( + f"Video to be loaded does not exist at {video_path}.", + ) return video_data @@ -1109,7 +1138,7 @@ class SynchronousHumanGatherer(PreferenceGatherer): def __init__( self, - video_dir: pathlib.Path, + video_dir: Union[str, os.PathLike], video_width: int = 500, video_height: int = 500, frames_per_second: int = 25, @@ -1127,7 +1156,9 @@ def __init__( rng: random number generator """ super().__init__(custom_logger=custom_logger, rng=rng) - self.querent = VideoBasedQuerent(video_output_dir=video_dir, video_fps=frames_per_second) + self.querent = VideoBasedQuerent( + video_output_dir=video_dir, video_fps=frames_per_second, + ) self.video_dir = video_dir self.video_width = video_width self.video_height = video_height @@ -1307,11 +1338,11 @@ def is_running_pytest_test(self) -> bool: class AsynchronousHumanGatherer(PreferenceGatherer, abc.ABC): def __init__( - self, - wait_for_user: bool = True, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None, + self, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, ) -> None: super().__init__(rng, custom_logger, querent_kwargs) self.pending_queries = {} @@ -1348,12 +1379,12 @@ class RESTGatherer(AsynchronousHumanGatherer): """Gathers preferences from a REST interface.""" def __init__( - self, - collection_service_address: str, - wait_for_user: bool = True, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None, + self, + collection_service_address: str, + wait_for_user: bool = True, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + querent_kwargs: Optional[Mapping] = None, ) -> None: """Initializes the preference gatherer. diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index de8f949d8..408314688 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1,5 +1,7 @@ """Tests for the preference comparisons reward learning implementation.""" +import base64 +import binascii import math import os import pathlib @@ -25,12 +27,12 @@ import imitation.testing.reward_nets as testing_reward_nets from imitation.algorithms import preference_comparisons from imitation.algorithms.preference_comparisons import ( - RESTGatherer, - RESTQuerent, PreferenceGatherer, PreferenceQuerent, - VideoBasedQuerent, + RESTGatherer, + RESTQuerent, SyntheticGatherer, + VideoBasedQuerent, remove_rendered_images, ) from imitation.data import types @@ -1169,21 +1171,71 @@ def test_sends_put_request_for_each_query(requests_mock): @pytest.fixture -def empty_trajectory_with_rew_and_render_imgs(): +def empty_trajectory_with_rew_and_render_imgs() -> TrajectoryWithRew: num_frames = 10 frame_shape = (200, 200) return types.TrajectoryWithRew( obs=np.zeros((num_frames, *frame_shape, 3), np.uint8), acts=np.zeros((num_frames - 1,), np.uint8), - infos=np.array([{"rendered_img": np.zeros((*frame_shape, 3), np.uint8)} for _ in range(num_frames - 1)]), + infos=np.array( + [ + {"rendered_img": np.zeros((*frame_shape, 3), np.uint8)} + for _ in range(num_frames - 1) + ], + ), rews=np.zeros((num_frames - 1,)), terminal=True, ) -def test_prefcollectquerent_call_creates_all_videos(empty_trajectory_with_rew_and_render_imgs): +def is_base64(data): + """Checks if the data is base64 encoded.""" + try: + base64.b64decode(data) + return True + except binascii.Error: + return False + + +def test_load_video_data(empty_trajectory_with_rew_and_render_imgs): + address = "https://test.de" + video_dir = "video" + querent = RESTQuerent( + collection_service_address=address, video_output_dir=video_dir, + ) + + # Setup query with saved videos + query_id = "1234" + frames = querent._get_frames_for_each_observation( + empty_trajectory_with_rew_and_render_imgs, + ) + output_path = querent._create_query_video_path(query_id, "left") + querent._write(frames, output_path, progress_logger=False) + output_path = querent._create_query_video_path(query_id, "right") + querent._write(frames, output_path, progress_logger=False) + + # Load videos and check their encoding + video_data = querent._load_video_data(query_id) + for alternative in ("left", "right"): + assert is_base64(video_data[alternative]) + + # Check that loading video data of non-existent query fails + with pytest.raises(RuntimeError): + querent._load_video_data("0") + + shutil.rmtree(video_dir) + + +def test_prefcollectquerent_call_creates_all_videos( + empty_trajectory_with_rew_and_render_imgs, +): address = "https://test.de" - queries = [(empty_trajectory_with_rew_and_render_imgs, empty_trajectory_with_rew_and_render_imgs)] + queries = [ + ( + empty_trajectory_with_rew_and_render_imgs, + empty_trajectory_with_rew_and_render_imgs, + ), + ] querent = RESTQuerent(collection_service_address=address, video_output_dir="video") identified_queries = querent(queries) for query_id, _ in identified_queries.items(): @@ -1201,7 +1253,10 @@ def fragment(request, empty_trajectory_with_rew_and_render_imgs): infos = empty_trajectory_with_rew_and_render_imgs.infos if request.param == "with_render_images": infos = np.array( - [{"rendered_img": frame} for frame in empty_trajectory_with_rew_and_render_imgs.obs[1:]], + [ + {"rendered_img": frame} + for frame in empty_trajectory_with_rew_and_render_imgs.obs[1:] + ], ) elif request.param == "with_render_image_paths": tmp_dir = tempfile.mkdtemp() @@ -1247,7 +1302,7 @@ def test_remove_rendered_images(fragment): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + def gather(self) -> np.ndarray: return np.zeros(shape=(1,)) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 82133fd65..119ea609b 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,10 +1,10 @@ """Tests for `imitation.data.wrappers`.""" +from pathlib import Path from typing import List, Sequence, Type -import imageio -from pathlib import Path import gymnasium as gym +import imageio import numpy as np import pytest from stable_baselines3.common.vec_env import DummyVecEnv @@ -281,7 +281,8 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() -@pytest.mark.parametrize("scale_factor", [0.1, 0.5, 1.]) + +@pytest.mark.parametrize("scale_factor", [0.1, 0.5, 1.0]) def test_writes_rendered_img_to_info(scale_factor): env = gym.make("CartPole-v0", render_mode="rgb_array") wrapped_env = RenderImageInfoWrapper(env, scale_factor=scale_factor) @@ -290,7 +291,7 @@ def test_writes_rendered_img_to_info(scale_factor): _, _, _, _, info = wrapped_env.step(wrapped_env.action_space.sample()) assert "rendered_img" in info assert isinstance(info["rendered_img"], np.ndarray) - if scale_factor == 1.: + if scale_factor == 1.0: assert np.allclose(info["rendered_img"], rendered_img) assert int(scale_factor * rendered_img.shape[0]) == info["rendered_img"].shape[0] assert int(scale_factor * rendered_img.shape[1]) == info["rendered_img"].shape[1] @@ -320,4 +321,3 @@ def test_rendered_img_file_cache(): assert (imageio.imread(rendered_img_path) == wrapped_env.render()).all() wrapped_env.close() assert not Path(wrapped_env.file_cache).exists() - From 8a6e210404060daafc69f72f99b5f4557f39814f Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 10:56:31 +0200 Subject: [PATCH 132/143] Fix bug in SyntheticGatherer --- src/imitation/algorithms/preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index d47a55a19..907babdd1 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1098,7 +1098,7 @@ def _gather_preference(self, query_id): if self.sample: assert self.rng is not None - return self.rng.binomial(n=1, p=choice_probability).astype(np.float32) + return self.rng.binomial(n=1, p=choice_probability) return choice_probability From f934886a72f07b7dc0664849b0e0f2c3887c8e76 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 10:57:08 +0200 Subject: [PATCH 133/143] Adapt ConcretePreferenceGatherer --- tests/algorithms/test_preference_comparisons.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 8558b3f37..26754571f 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1247,8 +1247,8 @@ def test_remove_rendered_images(fragment): class ConcretePreferenceGatherer(PreferenceGatherer): """A concrete preference gatherer for unit testing purposes only.""" - def gather(self) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: - return np.zeros(shape=(1,)) + def _gather_preference(self, query_id) -> Tuple[Sequence[TrajectoryWithRewPair], np.ndarray]: + return 0. def test_adds_queries_to_pending_queries(): @@ -1365,7 +1365,6 @@ def test_ignores_incomparable_answer(): wait_for_user=False, querent_kwargs={"video_output_dir": "videos"}, ) - # incomparable preference value = -1 gatherer._gather_preference = MagicMock(return_value=-1.0) gatherer.pending_queries = {"1234": Mock()} From 040af8b067f767939e26e63ed821f175f019ba9a Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 11:12:13 +0200 Subject: [PATCH 134/143] Improve variable naming --- tests/algorithms/test_preference_comparisons.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 26754571f..78e3df5b8 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1238,9 +1238,9 @@ def test_write_fragment_video(fragment, codec): def test_remove_rendered_images(fragment): - trajs = [fragment] - remove_rendered_images(trajs) - assert not any("rendered_img" in info for traj in trajs for info in traj.infos) + trajectories = [fragment] + remove_rendered_images(trajectories) + assert not any("rendered_img" in info for trajectory in trajectories for info in trajectory.infos) # PreferenceGatherer From 891e7f7ecbb8e113aa609f52e67497d1f896d106 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 14:16:09 +0200 Subject: [PATCH 135/143] Rename test for CommandLineGatherer --- tests/algorithms/test_preference_comparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 78e3df5b8..3f33f9420 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1377,7 +1377,7 @@ def test_ignores_incomparable_answer(): # SynchronousHumanGatherer @patch("builtins.input") @patch("IPython.display.display") -def test_synchronous_human_gatherer(mock_display, mock_input): +def test_command_line_gatherer(mock_display, mock_input, fragment): del mock_display # unused gatherer = preference_comparisons.CommandLineGatherer( video_dir=pathlib.Path("."), From 4b9d1f9084ccf4e98bded2e3882734578132c846 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 14:16:15 +0200 Subject: [PATCH 136/143] Fix test --- .../algorithms/test_preference_comparisons.py | 51 +------------------ 1 file changed, 1 insertion(+), 50 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 3f33f9420..9d8fcd340 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1384,56 +1384,7 @@ def test_command_line_gatherer(mock_display, mock_input, fragment): ) # these inputs are designed solely to pass the test. they aren't tested for anything - trajectory_pairs = [ - ( - types.TrajectoryWithRew( - np.zeros( - ( - 2, - 200, - 200, - 3, - ), - np.uint8, - ), - np.array([1]), - np.array( - [ - { - "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py", - ), - }, - ], - ), - True, - np.array([1.0]), - ), - types.TrajectoryWithRew( - np.zeros( - ( - 2, - 200, - 200, - 3, - ), - np.uint8, - ), - np.array([1]), # act - np.array( # info - [ - { - "video_path": pathlib.Path( - "tests/algorithms/test_preference_comparisons.py", - ), - }, - ], - ), - True, # done - np.array([1.0]), # reward - ), - ), - ] + trajectory_pairs = [(fragment, fragment),] gatherer.query(trajectory_pairs) # this is the actual test From 2c632d8ce193c4243d6026d2e6f1cf6427b1cacc Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 15:05:57 +0200 Subject: [PATCH 137/143] Add documentation --- src/imitation/algorithms/preference_comparisons.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 907babdd1..eda3de030 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -821,6 +821,8 @@ def __call__( class VideoBasedQuerent(PreferenceQuerent): + """Writes videos for each query to the local file system for later use by child querent (and gatherer) classes.""" + def __init__( self, video_output_dir: str, From eebb31dcb44f41e042bb9d7a540e6edf061ee5ee Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 15:07:35 +0200 Subject: [PATCH 138/143] Add documentation --- src/imitation/algorithms/preference_comparisons.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index eda3de030..f53990496 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -835,6 +835,7 @@ def __init__( Args: video_output_dir: path to the video clip directory. + video_type: specifies the video format, e.g. 'webm' or 'mp4' video_fps: frames per second of the generated videos. rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. From f2f722a31d2d2a280be0712213eeb616f9940789 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 15:18:52 +0200 Subject: [PATCH 139/143] Make method static --- src/imitation/algorithms/preference_comparisons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index f53990496..54127850b 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -891,7 +891,8 @@ def _get_frames(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np def _rendered_image_of_observation_is_available(fragment: TrajectoryWithRew) -> bool: return fragment.infos is not None and "rendered_img" in fragment.infos[0] - def _get_frames_for_each_observation(self, fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: + @staticmethod + def _get_frames_for_each_observation(fragment: TrajectoryWithRew) -> list[Union[os.PathLike, np.ndarray]]: frames: list[Union[os.PathLike, np.ndarray]] = [] for i in range(len(fragment.infos)): frame: Union[os.PathLike, np.ndarray] = fragment.infos[i]["rendered_img"] From 4c16550406dbae37b653f547d38688860592cd82 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 15:19:04 +0200 Subject: [PATCH 140/143] Remove whitespace --- src/imitation/algorithms/preference_comparisons.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 54127850b..e8a8427a3 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -930,7 +930,6 @@ def __init__( """ super().__init__(video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger) self.query_endpoint = collection_service_address + "/preferences/query/" - def __call__( self, From 7b2a260c6d137f1f58b4a0aa6018ec373a3788c5 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Fri, 17 May 2024 17:52:10 +0200 Subject: [PATCH 141/143] Refine rest interface and add documentation --- .../algorithms/preference_comparisons.py | 48 ++++++++++++++----- .../config/train_preference_comparisons.py | 2 +- .../algorithms/test_preference_comparisons.py | 32 ++++++------- tests/scripts/test_scripts.py | 2 +- 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index c163f6102..49649d58f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -12,6 +12,7 @@ import pickle import re import uuid +from urllib.parse import urljoin from collections import defaultdict from typing import ( Any, @@ -825,12 +826,12 @@ class VideoBasedQuerent(PreferenceQuerent): """Writes videos for each query to the local file system for later use by child querent (and gatherer) classes.""" def __init__( - self, - video_output_dir: Union[str, os.PathLike], - video_type="webm", - video_fps: int = 20, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + self, + video_output_dir: Union[str, os.PathLike], + video_type="webm", + video_fps: int = 20, + rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initializes the querent. @@ -930,7 +931,18 @@ def _write( class RESTQuerent(VideoBasedQuerent): - """Sends queries to a REST web service.""" + """Sends queries to a REST web service. + + The queries are sent via PUT request as json payload to `collection_service_address`/`query_id` + in the following form: + + { + "uuid": "1234", + "left": b64 encoded video, + "right" b64 encoded video + } + + """ def __init__( self, @@ -952,7 +964,7 @@ def __init__( super().__init__( video_output_dir, video_fps=video_fps, rng=rng, custom_logger=custom_logger, ) - self.query_endpoint = collection_service_address + "/preferences/query/" + self.query_endpoint = collection_service_address def __call__( self, @@ -966,7 +978,7 @@ def __call__( def _query(self, query_id: str): video_data = self._load_video_data(query_id) requests.put( - self.query_endpoint + query_id, + urljoin(self.query_endpoint, query_id), json={"uuid": "{}".format(query_id), **video_data}, ) @@ -1321,7 +1333,18 @@ def _is_running_pytest_test() -> bool: class RESTGatherer(PreferenceGatherer): - """Gathers preferences from a REST interface.""" + """Gathers preferences from a REST web service. + + The queries are gathered via GET request to `collection_service_address`/`query_id` + and returns a preference as json payload: + + { + "label": float + } + + The float value ranges from 0.0 (preferring left) to 1.0 (preferring right), + with 0.5 indicating indifference. -1.0 indicates that the query was incomparable. + """ def __init__( self, @@ -1346,10 +1369,11 @@ def __init__( collection_service_address=collection_service_address, **querent_kwargs, ) - self.query_endpoint = collection_service_address + "/preferences/query/" + self.query_endpoint = collection_service_address def _gather_preference(self, query_id: str) -> float: - answered_query = requests.get(self.query_endpoint + query_id).json() + query_url = urljoin(self.query_endpoint, query_id) + answered_query = requests.get(query_url).json() return answered_query["label"] diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 308c58937..33a417cbb 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -97,7 +97,7 @@ def human_preferences(): collection_service_address="http://127.0.0.1:8000", wait_for_user=True, querent_kwargs=dict( - video_output_dir="../pref_collect/videofiles", + video_output_dir="./videos", video_fps=20, ), ) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 21603a471..fa45981f4 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1161,17 +1161,17 @@ def test_returned_queries_have_uuid(): def test_sends_put_request_for_each_query(requests_mock): address = "https://test.de" querent = RESTQuerent(collection_service_address=address, video_output_dir="video") + querent._load_video_data = MagicMock() query_id = "1234" - requests_mock.put(f"{address}/preferences/query/{query_id}") + requests_mock.put(f"{address}/{query_id}") querent._query(query_id) assert requests_mock.last_request.method == "PUT" - assert requests_mock.last_request.text == f'{{"uuid": "{query_id}"}}' @pytest.fixture -def empty_trajectory_with_rew_and_render_imgs() -> TrajectoryWithRew: +def empty_trajectory_with_rew_and_render_images() -> TrajectoryWithRew: num_frames = 10 frame_shape = (200, 200) return types.TrajectoryWithRew( @@ -1197,7 +1197,7 @@ def is_base64(data): return False -def test_load_video_data(empty_trajectory_with_rew_and_render_imgs): +def test_load_video_data(empty_trajectory_with_rew_and_render_images): address = "https://test.de" video_dir = "video" querent = RESTQuerent( @@ -1207,7 +1207,7 @@ def test_load_video_data(empty_trajectory_with_rew_and_render_imgs): # Setup query with saved videos query_id = "1234" frames = querent._get_frames_for_each_observation( - empty_trajectory_with_rew_and_render_imgs, + empty_trajectory_with_rew_and_render_images, ) output_path = querent._create_query_video_path(query_id, "left") querent._write(frames, output_path, progress_logger=False) @@ -1227,13 +1227,13 @@ def test_load_video_data(empty_trajectory_with_rew_and_render_imgs): def test_prefcollectquerent_call_creates_all_videos( - empty_trajectory_with_rew_and_render_imgs, + empty_trajectory_with_rew_and_render_images, ): address = "https://test.de" queries = [ ( - empty_trajectory_with_rew_and_render_imgs, - empty_trajectory_with_rew_and_render_imgs, + empty_trajectory_with_rew_and_render_images, + empty_trajectory_with_rew_and_render_images, ), ] querent = RESTQuerent(collection_service_address=address, video_output_dir="video") @@ -1248,14 +1248,14 @@ def test_prefcollectquerent_call_creates_all_videos( @pytest.fixture( params=["obs_only", "with_render_images", "with_render_image_paths"], ) -def fragment(request, empty_trajectory_with_rew_and_render_imgs): - obs = empty_trajectory_with_rew_and_render_imgs.obs - infos = empty_trajectory_with_rew_and_render_imgs.infos +def fragment(request, empty_trajectory_with_rew_and_render_images): + obs = empty_trajectory_with_rew_and_render_images.obs + infos = empty_trajectory_with_rew_and_render_images.infos if request.param == "with_render_images": infos = np.array( [ {"rendered_img": frame} - for frame in empty_trajectory_with_rew_and_render_imgs.obs[1:] + for frame in empty_trajectory_with_rew_and_render_images.obs[1:] ], ) elif request.param == "with_render_image_paths": @@ -1268,10 +1268,10 @@ def fragment(request, empty_trajectory_with_rew_and_render_imgs): infos = np.array(infos) yield types.TrajectoryWithRew( obs=obs, - acts=empty_trajectory_with_rew_and_render_imgs.acts, + acts=empty_trajectory_with_rew_and_render_images.acts, infos=infos, terminal=True, - rews=empty_trajectory_with_rew_and_render_imgs.rews, + rews=empty_trajectory_with_rew_and_render_images.rews, ) if request.param == "with_render_image_paths": shutil.rmtree(tmp_dir) @@ -1338,7 +1338,7 @@ def test_returns_none_for_unanswered_query(requests_mock): ) requests_mock.get( - f"{address}/preferences/query/{query_id}", + f"{address}/{query_id}", json={"query_id": query_id, "label": answer}, ) @@ -1358,7 +1358,7 @@ def test_returns_preference_for_answered_query(requests_mock): ) requests_mock.get( - f"{address}/preferences/query/{query_id}", + f"{address}/{query_id}", json={"query_id": query_id, "label": answer}, ) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 0fa60ae2b..493d307b5 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -187,7 +187,7 @@ def test_train_preference_comparisons_with_collected_preferences(tmpdir): ) with requests_mock.Mocker() as m: - request_matcher = re.compile(f"{address}/preferences/query/") + request_matcher = re.compile(f"{address}/") m.put(url=request_matcher) m.get( From 38c7df9d796106f40576584a5aaebae1643311f4 Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 28 May 2024 15:36:03 +0200 Subject: [PATCH 142/143] Fix bug and integration test --- .../algorithms/preference_comparisons.py | 16 +++++----------- tests/scripts/test_scripts.py | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 49649d58f..d4500f8ea 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -1001,12 +1001,8 @@ def _load_video_data(self, query_id: str) -> dict[str, str]: class PreferenceGatherer(abc.ABC): """Base class for gathering preference comparisons between trajectory fragments.""" - def __init__( - self, - rng: Optional[np.random.Generator] = None, - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - querent_kwargs: Optional[Mapping] = None, - ) -> None: + def __init__(self, rng: Optional[np.random.Generator] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None) -> None: """Initializes the preference gatherer. Args: @@ -1018,9 +1014,7 @@ def __init__( # as an argument nevertheless because that means we can always # pass in a seed in training scripts (without worrying about whether # the PreferenceGatherer we use needs one). - del rng - querent_kwargs = querent_kwargs or {} - self.querent = PreferenceQuerent(**querent_kwargs) + self.querent = PreferenceQuerent(rng, custom_logger) self.logger = custom_logger or imit_logger.configure() self.pending_queries: Dict = {} @@ -1168,7 +1162,7 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. rng: random number generator """ - super().__init__(custom_logger=custom_logger, rng=rng) + super().__init__(rng=rng, custom_logger=custom_logger) self.querent = VideoBasedQuerent( video_output_dir=video_dir, video_fps=frames_per_second, ) @@ -1363,7 +1357,7 @@ def __init__( custom_logger: Where to log to; if None (default), creates a new logger. querent_kwargs: Keyword arguments passed to the querent. """ - super().__init__(wait_for_user, rng, custom_logger) + super().__init__(rng, custom_logger) querent_kwargs = querent_kwargs if querent_kwargs else {} self.querent = RESTQuerent( collection_service_address=collection_service_address, diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 493d307b5..e4f0cade3 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -182,7 +182,7 @@ def test_train_preference_comparisons_with_collected_preferences(tmpdir): gatherer_kwargs=dict( wait_for_user=False, querent_kwargs=dict(video_output_dir=tmpdir), - pref_collect_address=address, + collection_service_address=address, ), ) From fcfa92a1e84b6ac2cb86ecf7bb02fb70435dd15e Mon Sep 17 00:00:00 2001 From: "marvin.schweizer" Date: Tue, 28 May 2024 15:36:25 +0200 Subject: [PATCH 143/143] Delete videos after test --- tests/algorithms/test_preference_comparisons.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index fa45981f4..c5a72fd94 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -1434,8 +1434,9 @@ def test_ignores_incomparable_answer(): @patch("IPython.display.display") def test_command_line_gatherer(mock_display, mock_input, fragment): del mock_display # unused + video_dir = "videos" gatherer = preference_comparisons.CommandLineGatherer( - video_dir=pathlib.Path("."), + video_dir=pathlib.Path(video_dir), ) # these inputs are designed solely to pass the test. they aren't tested for anything @@ -1449,3 +1450,5 @@ def test_command_line_gatherer(mock_display, mock_input, fragment): gatherer.query(trajectory_pairs) mock_input.return_value = "2" assert gatherer.gather()[1] == np.array([0.0]) + + shutil.rmtree(video_dir)