diff --git a/amago/cli_utils.py b/amago/cli_utils.py
index 1c30104..7c99156 100644
--- a/amago/cli_utils.py
+++ b/amago/cli_utils.py
@@ -73,7 +73,7 @@ def add_common_cli(parser: ArgumentParser) -> ArgumentParser:
     )
     # main learning schedule
     parser.add_argument(
-        "--grads_per_epoch",
+        "--batches_per_epoch",
         type=int,
         default=1000,
         help="Gradient updates per training epoch.",
@@ -82,7 +82,7 @@ def add_common_cli(parser: ArgumentParser) -> ArgumentParser:
         "--timesteps_per_epoch",
         type=int,
         default=1000,
-        help="Timesteps of environment interaction per epoch *per actor*. The update:data ratio is defined by `grads_per_epoch / (timesteps_per_epoch * parallel_actors)`.",
+        help="Timesteps of environment interaction per epoch *per actor*. The update:data ratio is defined by `batches_per_epoch / (timesteps_per_epoch * parallel_actors)`.",
     )
     parser.add_argument(
         "--val_interval",
@@ -326,7 +326,7 @@ def create_experiment_from_cli(
         epochs=cli.epochs,
         parallel_actors=cli.parallel_actors,
         train_timesteps_per_epoch=cli.timesteps_per_epoch,
-        train_grad_updates_per_epoch=cli.grads_per_epoch,
+        train_batches_per_epoch=cli.batches_per_epoch,
         start_learning_at_epoch=cli.start_learning_at_epoch,
         val_interval=cli.val_interval,
         ckpt_interval=cli.ckpt_interval,
@@ -353,7 +353,7 @@ def make_experiment_learn_only(experiment: amago.Experiment) -> amago.Experiment
 def make_experiment_collect_only(experiment: amago.Experiment) -> amago.Experiment:
     experiment.start_collecting_at_epoch = 0
     experiment.start_learning_at_epoch = float("inf")
-    experiment.train_grad_updates_per_epoch = 0
+    experiment.train_batches_per_epoch = 0
     experiment.val_checks_per_epoch = 0
     experiment.ckpt_interval = None
     experiment.always_save_latest = False
diff --git a/amago/envs/builtin/babyai.py b/amago/envs/builtin/babyai.py
index 9019ac2..8f95a74 100644
--- a/amago/envs/builtin/babyai.py
+++ b/amago/envs/builtin/babyai.py
@@ -10,7 +10,11 @@
 
 from amago.envs import AMAGOEnv
 from amago.hindsight import GoalSeq
-from amago.envs.env_utils import space_convert, DiscreteActionWrapper
+from amago.envs.env_utils import (
+    space_convert,
+    DiscreteActionWrapper,
+    AMAGO_ENV_LOG_PREFIX,
+)
 
 
 BANNED_BABYAI_TASKS = [
@@ -181,6 +185,9 @@ def step(self, action):
         next_obs, reward, terminated, truncated, info = self.env.step(action)
         done = False
         if terminated or truncated:
+            info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.current_episode} Success"] = (
+                reward > 0.0
+            )
             self.current_episode += 1
             if self.current_episode > self.k_episodes:
                 done = True
diff --git a/amago/envs/builtin/gym_envs.py b/amago/envs/builtin/gym_envs.py
index 6649589..9fc118b 100644
--- a/amago/envs/builtin/gym_envs.py
+++ b/amago/envs/builtin/gym_envs.py
@@ -6,6 +6,7 @@
 import numpy as np
 
 from amago.envs import AMAGOEnv
+from amago.envs.env_utils import AMAGO_ENV_LOG_PREFIX
 from amago.hindsight import GoalSeq
 
 
@@ -130,11 +131,11 @@ def __init__(
 
     def reset(self, *args, **kwargs):
         self.current_map = [[t for t in row] for row in generate_random_map(self.size)]
-        self.action_mapping = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
+        act_map = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
         if self.hard_mode and random.random() < 0.5:
-            temp = self.action_mapping[1]
-            self.action_mapping[1] = self.action_mapping[3]
-            self.action_mapping[3] = temp
+            # flip two control directions for an extra challenge
+            act_map[1], act_map[3] = act_map[3], act_map[1]
+        self.action_mapping = act_map
         self.current_k = 0
         return self.soft_reset()
 
@@ -181,13 +182,14 @@ def step(self, action):
         else:
             reward = 0.0
             soft_reset = False
-
         self.x = next_x
         self.y = next_y
         if soft_reset:
-            self.current_k += 1
             next_state, info = self.soft_reset()
+            success = on == "G"
+            info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.current_k} Success"] = success
+            self.current_k += 1
         else:
             next_state, info = self.make_obs(False), {}
diff --git a/amago/envs/builtin/tmaze.py b/amago/envs/builtin/tmaze.py
index e613b90..f255d28 100644
--- a/amago/envs/builtin/tmaze.py
+++ b/amago/envs/builtin/tmaze.py
@@ -58,9 +58,9 @@ def __init__(
         )
         self.bias_x, self.bias_y = 1, 2
         self.tmaze_map[self.bias_y, self.bias_x : -self.bias_x] = True  # corridor
-        self.tmaze_map[[self.bias_y - 1, self.bias_y + 1], -self.bias_x - 1] = (
-            True  # goal candidates
-        )
+        self.tmaze_map[
+            [self.bias_y - 1, self.bias_y + 1], -self.bias_x - 1
+        ] = True  # goal candidates
 
         obs_dim = 2 if self.ambiguous_position else 3
         if self.expose_goal:  # test Markov policies
diff --git a/amago/envs/env_utils.py b/amago/envs/env_utils.py
index cf185e1..be0dd61 100644
--- a/amago/envs/env_utils.py
+++ b/amago/envs/env_utils.py
@@ -6,7 +6,7 @@
 from uuid import uuid4
 from dataclasses import dataclass
 from collections import defaultdict
-from typing import Optional, Type, Callable, Iterable
+from typing import Optional, Type, Callable, Iterable, Any
 
 import gymnasium as gym
 import numpy as np
@@ -327,7 +327,36 @@ def add_score(self, env_name, score):
 SuccessHistory = ReturnHistory
 
 
+class SpecialMetricHistory:
+    log_prefix = "AMAGO_LOG_METRIC"
+
+    def __init__(self, env_name):
+        self.data = {}
+
+    def add_score(self, env_name: str, key: str, value: Any):
+        if key.startswith(self.log_prefix):
+            key = key[len(self.log_prefix) :].strip()
+        if env_name not in self.data:
+            self.data[env_name] = {}
+        if key not in self.data[env_name]:
+            self.data[env_name][key] = [value]
+        else:
+            self.data[env_name][key].append(value)
+
+
+AMAGO_ENV_LOG_PREFIX = SpecialMetricHistory.log_prefix
+
+
 class SequenceWrapper(gym.Wrapper):
+    """
+    A wrapper that handles automatic resets, saving trajectory files to disk,
+    and rollout metrics. The total return is logged for every env, and when
+    the goal-conditioned / relabeling system is used, meaningful success rates
+    are logged as well. Any entry of the gym env's `info` dict whose key begins
+    with "AMAGO_LOG_METRIC" (`amago.envs.env_utils.AMAGO_ENV_LOG_PREFIX`) is
+    also logged.
+ """ + def __init__( self, env, @@ -388,6 +417,7 @@ def reset_stats(self): # stores all of the success/return histories self.return_history = ReturnHistory(self.env_name) self.success_history = SuccessHistory(self.env_name) + self.special_history = SpecialMetricHistory(self.env_name) def reset(self, seed=None) -> Timestep: timestep, info = self.env.reset(seed=seed) @@ -407,6 +437,10 @@ def step(self, action): self.total_return += reward self.active_traj.add_timestep(timestep) self.since_last_save += 1 + for info_key, info_val in info.items(): + if info_key.startswith(self.special_history.log_prefix): + self.special_history.add_score(self.env.env_name, info_key, info_val) + if timestep.terminal: self.return_history.add_score(self.env.env_name, self.total_return) success = ( @@ -415,6 +449,7 @@ def step(self, action): else info["success"] ) self.success_history.add_score(self.env.env_name, success) + save = ( self.save_every is not None and self.since_last_save > self.save_this_time ) diff --git a/amago/learning.py b/amago/learning.py index ac67f80..70ef456 100644 --- a/amago/learning.py +++ b/amago/learning.py @@ -23,6 +23,7 @@ from amago.envs.env_utils import ( ReturnHistory, SuccessHistory, + SpecialMetricHistory, ExplorationWrapper, EpsilonGreedy, SequenceWrapper, @@ -73,7 +74,7 @@ class Experiment: start_learning_at_epoch: int = 0 start_collecting_at_epoch: int = 0 train_timesteps_per_epoch: int = 1000 - train_grad_updates_per_epoch: int = 1000 + train_batches_per_epoch: int = 1000 val_interval: Optional[int] = 10 val_timesteps_per_epoch: int = 10_000 log_interval: int = 250 @@ -300,7 +301,9 @@ def init_dsets(self): dset_root=self.dset_root, dset_name=self.dset_name, dset_split="train", - items_per_epoch=self.train_grad_updates_per_epoch * self.batch_size, + items_per_epoch=self.train_batches_per_epoch + * self.batch_size + * self.accelerator.num_processes, max_seq_len=self.max_seq_len, ) @@ -463,16 +466,24 @@ def get_t(_dones=None): return_history = utils.call_async_env(envs, "return_history") success_history = utils.call_async_env(envs, "success_history") + special_history = utils.call_async_env(envs, "special_history") return ( (obs_seqs, goal_seqs, rl2_seqs), hidden_state, return_history, success_history, + special_history, ) def collect_new_training_data(self): if self.train_timesteps_per_epoch > 0: - self.train_buffers, self.hidden_state, returns, successes = self.interact( + ( + self.train_buffers, + self.hidden_state, + returns, + successes, + specials, + ) = self.interact( self.train_envs, self.train_timesteps_per_epoch, buffers=self.train_buffers, @@ -480,11 +491,13 @@ def collect_new_training_data(self): def evaluate_val(self): if self.val_timesteps_per_epoch > 0: - *_, returns, successes = self.interact( + *_, returns, successes, specials = self.interact( self.val_envs, self.val_timesteps_per_epoch, ) - logs_per_process = self.policy_metrics(returns, successes) + logs_per_process = self.policy_metrics( + returns, successes, specials=specials + ) cur_return = logs_per_process["Average Total Return (Across All Env Names)"] if self.verbose: self.accelerator.print(f"Average Return : {cur_return}") @@ -499,12 +512,12 @@ def evaluate_test( ) Par = gym.vector.AsyncVectorEnv if self.async_envs else DummyAsyncVectorEnv test_envs = Par([make for _ in range(self.parallel_actors)]) - *_, returns, successes = self.interact( + *_, returns, successes, specials = self.interact( test_envs, timesteps, render=render, ) - logs = self.policy_metrics(returns, successes) + logs = 
self.policy_metrics(returns, successes, specials) self.log(logs, key="test") test_envs.close() return logs @@ -544,35 +557,46 @@ def make_figures(self, loss_info) -> dict[str, wandb.Image]: """ return {} - def policy_metrics(self, returns: ReturnHistory, successes: SuccessHistory): - return_by_env_name = {} - success_by_env_name = {} - for ret, suc in zip(returns, successes): + def policy_metrics( + self, + returns: ReturnHistory, + successes: SuccessHistory, + specials: SpecialMetricHistory, + ) -> dict: + returns_by_env_name = defaultdict(list) + success_by_env_name = defaultdict(list) + specials_by_env_name = defaultdict(lambda: defaultdict(list)) + + for ret, suc, spe in zip(returns, successes, specials): for env_name, scores in ret.data.items(): - if env_name in return_by_env_name: - return_by_env_name[env_name] += scores - else: - return_by_env_name[env_name] = scores + returns_by_env_name[env_name].extend(scores) for env_name, scores in suc.data.items(): - if env_name in success_by_env_name: - success_by_env_name[env_name] += scores - else: - success_by_env_name[env_name] = scores + success_by_env_name[env_name].extend(scores) + for env_name, specials_dict in spe.data.items(): + for special_key, special_val in specials_dict.items(): + specials_by_env_name[env_name][special_key].extend(special_val) avg_ret_per_env = { f"Average Total Return in {name}": np.array(scores).mean() - for name, scores in return_by_env_name.items() + for name, scores in returns_by_env_name.items() } avg_suc_per_env = { f"Average Success Rate in {name}": np.array(scores).mean() for name, scores in success_by_env_name.items() } + avg_special_per_env = { + f"Average {special_key} in {name}": np.array(special_vals).mean() + for name, specials_dict in specials_by_env_name.items() + for special_key, special_vals in specials_dict.items() + } avg_return_overall = { "Average Total Return (Across All Env Names)": np.array( list(avg_ret_per_env.values()) ).mean() } - return avg_ret_per_env | avg_suc_per_env | avg_return_overall + return ( + avg_ret_per_env | avg_suc_per_env | avg_return_overall | avg_special_per_env + ) def compute_loss(self, batch: Batch, log_step: bool): critic_loss, actor_loss = self.policy_aclr(batch, log_step=log_step) @@ -637,7 +661,7 @@ def make_pbar(loader, epoch_num): return tqdm( enumerate(loader), desc=f"{self.run_name} Epoch {epoch_num} Train", - total=self.train_grad_updates_per_epoch, + total=self.train_batches_per_epoch, colour="green", ) else: @@ -667,7 +691,7 @@ def make_pbar(loader, epoch_num): continue self.policy_aclr.train() for train_step, batch in make_pbar(self.train_dloader, epoch): - total_step = (epoch * self.train_grad_updates_per_epoch) + train_step + total_step = (epoch * self.train_batches_per_epoch) + train_step log_step = total_step % self.log_interval == 0 loss_dict = self.train_step(batch, log_step=log_step) if log_step: diff --git a/examples/00_kshot_frozen_lake.py b/examples/00_kshot_frozen_lake.py index 9e7e072..a699852 100644 --- a/examples/00_kshot_frozen_lake.py +++ b/examples/00_kshot_frozen_lake.py @@ -14,7 +14,6 @@ def add_cli(parser): ) parser.add_argument("--run_name", type=str, required=True) parser.add_argument("--buffer_dir", type=str, required=True) - parser.add_argument("--gpu", type=int, required=True) parser.add_argument("--log", action="store_true") parser.add_argument("--trials", type=int, default=3) parser.add_argument("--lake_size", type=int, default=5) @@ -87,10 +86,10 @@ def add_cli(parser): make_val_env=make_env, 
         max_seq_len=args.max_rollout_length,
         traj_save_len=args.max_rollout_length,
+        agent_type=amago.agent.Agent,
         dset_max_size=10_000,
         run_name=run_name,
         dset_name=run_name,
-        gpu=args.gpu,
         dset_root=args.buffer_dir,
         dloader_workers=10,
         log_to_wandb=args.log,
@@ -98,7 +97,7 @@ def add_cli(parser):
         epochs=500 if not args.hard_mode else 900,
         parallel_actors=24,
         train_timesteps_per_epoch=350,
-        train_grad_updates_per_epoch=700,
+        train_batches_per_epoch=700,
         val_interval=20,
         val_timesteps_per_epoch=args.max_rollout_length * 2,
         ckpt_interval=50,
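
Usage note (not part of the patch): with the `AMAGO_ENV_LOG_PREFIX` hook, an environment can surface custom rollout metrics without touching the trainer, since `SequenceWrapper.step()` collects every `info` key that starts with the prefix and `policy_metrics` averages it per env name. The toy env below is purely hypothetical, a minimal sketch of how that could look for a gymnasium-style env.

```python
# Hypothetical example, not part of this diff: a toy env that reports how often
# the agent chose action 1, using the AMAGO_LOG_METRIC info-key prefix.
import gymnasium as gym
import numpy as np

from amago.envs.env_utils import AMAGO_ENV_LOG_PREFIX


class CountingEnv(gym.Env):
    """Reward 1.0 for action 1; logs the fraction of action-1 choices per episode."""

    def __init__(self, horizon: int = 10):
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,))
        self.action_space = gym.spaces.Discrete(2)
        self.horizon = horizon

    def reset(self, *args, **kwargs):
        self.t = 0
        self.ones = 0
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        self.t += 1
        self.ones += int(action == 1)
        terminated = self.t >= self.horizon
        info = {}
        if terminated:
            # SequenceWrapper.step() picks this up via SpecialMetricHistory,
            # which strips the prefix before averaging the values.
            info[f"{AMAGO_ENV_LOG_PREFIX}Action-1 Rate"] = self.ones / self.t
        obs = np.zeros(1, dtype=np.float32)
        return obs, float(action == 1), terminated, False, info
```

Wrapped in the usual AMAGOEnv / SequenceWrapper pipeline, this metric should show up in validation and test logs as "Average Action-1 Rate in <env_name>".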