Merge pull request #58 from UT-Austin-RPL/fix
Fix Git Error
jakegrigsby authored Oct 8, 2024
2 parents 02d94f6 + 7014641 commit 15a6f76
Showing 7 changed files with 108 additions and 41 deletions.
8 changes: 4 additions & 4 deletions amago/cli_utils.py
@@ -73,7 +73,7 @@ def add_common_cli(parser: ArgumentParser) -> ArgumentParser:
)
# main learning schedule
parser.add_argument(
"--grads_per_epoch",
"--batches_per_epoch",
type=int,
default=1000,
help="Gradient updates per training epoch.",
@@ -82,7 +82,7 @@ def add_common_cli(parser: ArgumentParser) -> ArgumentParser:
"--timesteps_per_epoch",
type=int,
default=1000,
help="Timesteps of environment interaction per epoch *per actor*. The update:data ratio is defined by `grads_per_epoch / (timesteps_per_epoch * parallel_actors)`.",
help="Timesteps of environment interaction per epoch *per actor*. The update:data ratio is defined by `batches_per_epoch / (timesteps_per_epoch * parallel_actors)`.",
)
parser.add_argument(
"--val_interval",
@@ -326,7 +326,7 @@ def create_experiment_from_cli(
epochs=cli.epochs,
parallel_actors=cli.parallel_actors,
train_timesteps_per_epoch=cli.timesteps_per_epoch,
train_grad_updates_per_epoch=cli.grads_per_epoch,
train_batches_per_epoch=cli.batches_per_epoch,
start_learning_at_epoch=cli.start_learning_at_epoch,
val_interval=cli.val_interval,
ckpt_interval=cli.ckpt_interval,
@@ -353,7 +353,7 @@ def make_experiment_learn_only(experiment: amago.Experiment) -> amago.Experiment
def make_experiment_collect_only(experiment: amago.Experiment) -> amago.Experiment:
experiment.start_collecting_at_epoch = 0
experiment.start_learning_at_epoch = float("inf")
experiment.train_grad_updates_per_epoch = 0
experiment.train_batches_per_epoch = 0
experiment.val_checks_per_epoch = 0
experiment.ckpt_interval = None
experiment.always_save_latest = False
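
For reference, the update:data ratio described in the new help string is easy to compute by hand. The sketch below is illustrative only; it plugs in the values used by examples/00_kshot_frozen_lake.py later in this diff (700 batches, 350 timesteps per epoch, 24 parallel actors):

# Illustrative only: the update:data ratio after the grads_per_epoch -> batches_per_epoch rename.
batches_per_epoch = 700      # --batches_per_epoch
timesteps_per_epoch = 350    # --timesteps_per_epoch (per actor)
parallel_actors = 24
ratio = batches_per_epoch / (timesteps_per_epoch * parallel_actors)
print(f"update:data ratio = {ratio:.3f}")  # 700 / 8400 ≈ 0.083
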
9 changes: 8 additions & 1 deletion amago/envs/builtin/babyai.py
@@ -10,7 +10,11 @@

from amago.envs import AMAGOEnv
from amago.hindsight import GoalSeq
from amago.envs.env_utils import space_convert, DiscreteActionWrapper
from amago.envs.env_utils import (
space_convert,
DiscreteActionWrapper,
AMAGO_ENV_LOG_PREFIX,
)


BANNED_BABYAI_TASKS = [
@@ -181,6 +185,9 @@ def step(self, action):
next_obs, reward, terminated, truncated, info = self.env.step(action)
done = False
if terminated or truncated:
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.current_episode} Success"] = (
reward > 0.0
)
self.current_episode += 1
if self.current_episode > self.k_episodes:
done = True
14 changes: 8 additions & 6 deletions amago/envs/builtin/gym_envs.py
@@ -6,6 +6,7 @@
import numpy as np

from amago.envs import AMAGOEnv
from amago.envs.env_utils import AMAGO_ENV_LOG_PREFIX
from amago.hindsight import GoalSeq


@@ -130,11 +131,11 @@ def __init__(

def reset(self, *args, **kwargs):
self.current_map = [[t for t in row] for row in generate_random_map(self.size)]
self.action_mapping = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
act_map = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
if self.hard_mode and random.random() < 0.5:
temp = self.action_mapping[1]
self.action_mapping[1] = self.action_mapping[3]
self.action_mapping[3] = temp
# flip two control directions for an extra challenge
act_map[1], act_map[3] = act_map[3], act_map[1]
self.action_mapping = act_map
self.current_k = 0
return self.soft_reset()

@@ -181,13 +182,14 @@ def step(self, action):
else:
reward = 0.0
soft_reset = False

self.x = next_x
self.y = next_y

if soft_reset:
self.current_k += 1
next_state, info = self.soft_reset()
success = on == "G"
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.current_k} Success"] = success
self.current_k += 1
else:
next_state, info = self.make_obs(False), {}

6 changes: 3 additions & 3 deletions amago/envs/builtin/tmaze.py
@@ -58,9 +58,9 @@ def __init__(
)
self.bias_x, self.bias_y = 1, 2
self.tmaze_map[self.bias_y, self.bias_x : -self.bias_x] = True # corridor
self.tmaze_map[[self.bias_y - 1, self.bias_y + 1], -self.bias_x - 1] = (
True # goal candidates
)
self.tmaze_map[
[self.bias_y - 1, self.bias_y + 1], -self.bias_x - 1
] = True # goal candidates

obs_dim = 2 if self.ambiguous_position else 3
if self.expose_goal: # test Markov policies
37 changes: 36 additions & 1 deletion amago/envs/env_utils.py
@@ -6,7 +6,7 @@
from uuid import uuid4
from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, Type, Callable, Iterable
from typing import Optional, Type, Callable, Iterable, Any

import gymnasium as gym
import numpy as np
@@ -327,7 +327,36 @@ def add_score(self, env_name, score):
SuccessHistory = ReturnHistory


class SpecialMetricHistory:
log_prefix = "AMAGO_LOG_METRIC"

def __init__(self, env_name):
self.data = {}

def add_score(self, env_name: str, key: str, value: Any):
if key.startswith(self.log_prefix):
key = key[len(self.log_prefix) :].strip()
if env_name not in self.data:
self.data[env_name] = {}
if key not in self.data[env_name]:
self.data[env_name][key] = [value]
else:
self.data[env_name][key].append(value)


AMAGO_ENV_LOG_PREFIX = SpecialMetricHistory.log_prefix


class SequenceWrapper(gym.Wrapper):
"""
A wrapper that handles automatic resets, saving trajectory files to disk,
and rollout metrics. Automatically logs total return in all envs.
When using the goal-conditioned / relabeling system, this will
log meaningful total success rates. We also log any metric from the
gym env's `info` dict that begins with "AMAGO_LOG_METRIC"
(`amago.envs.env_utils.AMAGO_ENV_LOG_PREFIX`).
"""

def __init__(
self,
env,
@@ -388,6 +417,7 @@ def reset_stats(self):
# stores all of the success/return histories
self.return_history = ReturnHistory(self.env_name)
self.success_history = SuccessHistory(self.env_name)
self.special_history = SpecialMetricHistory(self.env_name)

def reset(self, seed=None) -> Timestep:
timestep, info = self.env.reset(seed=seed)
@@ -407,6 +437,10 @@ def step(self, action):
self.total_return += reward
self.active_traj.add_timestep(timestep)
self.since_last_save += 1
for info_key, info_val in info.items():
if info_key.startswith(self.special_history.log_prefix):
self.special_history.add_score(self.env.env_name, info_key, info_val)

if timestep.terminal:
self.return_history.add_score(self.env.env_name, self.total_return)
success = (
Expand All @@ -415,6 +449,7 @@ def step(self, action):
else info["success"]
)
self.success_history.add_score(self.env.env_name, success)

save = (
self.save_every is not None and self.since_last_save > self.save_this_time
)
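
With the SpecialMetricHistory plumbing above, any environment can surface an extra evaluation metric simply by writing an info key that starts with AMAGO_ENV_LOG_PREFIX, exactly as the BabyAI and frozen-lake wrappers in this PR do. A minimal sketch of that pattern follows; the wrapper class and the "Heuristic Score" metric name are hypothetical examples, not part of this PR:

# Minimal sketch: emit a custom metric that SequenceWrapper will record and
# policy_metrics will report as "Average Heuristic Score in <env_name>" at evaluation time.
import gymnasium as gym

from amago.envs.env_utils import AMAGO_ENV_LOG_PREFIX


class LogHeuristicScore(gym.Wrapper):
    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if terminated or truncated:
            # any info key starting with the prefix is collected per env_name
            info[f"{AMAGO_ENV_LOG_PREFIX}Heuristic Score"] = float(reward)
        return obs, reward, terminated, truncated, info
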
70 changes: 47 additions & 23 deletions amago/learning.py
@@ -23,6 +23,7 @@
from amago.envs.env_utils import (
ReturnHistory,
SuccessHistory,
SpecialMetricHistory,
ExplorationWrapper,
EpsilonGreedy,
SequenceWrapper,
@@ -73,7 +74,7 @@ class Experiment:
start_learning_at_epoch: int = 0
start_collecting_at_epoch: int = 0
train_timesteps_per_epoch: int = 1000
train_grad_updates_per_epoch: int = 1000
train_batches_per_epoch: int = 1000
val_interval: Optional[int] = 10
val_timesteps_per_epoch: int = 10_000
log_interval: int = 250
@@ -300,7 +301,9 @@ def init_dsets(self):
dset_root=self.dset_root,
dset_name=self.dset_name,
dset_split="train",
items_per_epoch=self.train_grad_updates_per_epoch * self.batch_size,
items_per_epoch=self.train_batches_per_epoch
* self.batch_size
* self.accelerator.num_processes,
max_seq_len=self.max_seq_len,
)

@@ -463,28 +466,38 @@ def get_t(_dones=None):

return_history = utils.call_async_env(envs, "return_history")
success_history = utils.call_async_env(envs, "success_history")
special_history = utils.call_async_env(envs, "special_history")
return (
(obs_seqs, goal_seqs, rl2_seqs),
hidden_state,
return_history,
success_history,
special_history,
)

def collect_new_training_data(self):
if self.train_timesteps_per_epoch > 0:
self.train_buffers, self.hidden_state, returns, successes = self.interact(
(
self.train_buffers,
self.hidden_state,
returns,
successes,
specials,
) = self.interact(
self.train_envs,
self.train_timesteps_per_epoch,
buffers=self.train_buffers,
)

def evaluate_val(self):
if self.val_timesteps_per_epoch > 0:
*_, returns, successes = self.interact(
*_, returns, successes, specials = self.interact(
self.val_envs,
self.val_timesteps_per_epoch,
)
logs_per_process = self.policy_metrics(returns, successes)
logs_per_process = self.policy_metrics(
returns, successes, specials=specials
)
cur_return = logs_per_process["Average Total Return (Across All Env Names)"]
if self.verbose:
self.accelerator.print(f"Average Return : {cur_return}")
Expand All @@ -499,12 +512,12 @@ def evaluate_test(
)
Par = gym.vector.AsyncVectorEnv if self.async_envs else DummyAsyncVectorEnv
test_envs = Par([make for _ in range(self.parallel_actors)])
*_, returns, successes = self.interact(
*_, returns, successes, specials = self.interact(
test_envs,
timesteps,
render=render,
)
logs = self.policy_metrics(returns, successes)
logs = self.policy_metrics(returns, successes, specials)
self.log(logs, key="test")
test_envs.close()
return logs
@@ -544,35 +557,46 @@ def make_figures(self, loss_info) -> dict[str, wandb.Image]:
"""
return {}

def policy_metrics(self, returns: ReturnHistory, successes: SuccessHistory):
return_by_env_name = {}
success_by_env_name = {}
for ret, suc in zip(returns, successes):
def policy_metrics(
self,
returns: ReturnHistory,
successes: SuccessHistory,
specials: SpecialMetricHistory,
) -> dict:
returns_by_env_name = defaultdict(list)
success_by_env_name = defaultdict(list)
specials_by_env_name = defaultdict(lambda: defaultdict(list))

for ret, suc, spe in zip(returns, successes, specials):
for env_name, scores in ret.data.items():
if env_name in return_by_env_name:
return_by_env_name[env_name] += scores
else:
return_by_env_name[env_name] = scores
returns_by_env_name[env_name].extend(scores)
for env_name, scores in suc.data.items():
if env_name in success_by_env_name:
success_by_env_name[env_name] += scores
else:
success_by_env_name[env_name] = scores
success_by_env_name[env_name].extend(scores)
for env_name, specials_dict in spe.data.items():
for special_key, special_val in specials_dict.items():
specials_by_env_name[env_name][special_key].extend(special_val)

avg_ret_per_env = {
f"Average Total Return in {name}": np.array(scores).mean()
for name, scores in return_by_env_name.items()
for name, scores in returns_by_env_name.items()
}
avg_suc_per_env = {
f"Average Success Rate in {name}": np.array(scores).mean()
for name, scores in success_by_env_name.items()
}
avg_special_per_env = {
f"Average {special_key} in {name}": np.array(special_vals).mean()
for name, specials_dict in specials_by_env_name.items()
for special_key, special_vals in specials_dict.items()
}
avg_return_overall = {
"Average Total Return (Across All Env Names)": np.array(
list(avg_ret_per_env.values())
).mean()
}
return avg_ret_per_env | avg_suc_per_env | avg_return_overall
return (
avg_ret_per_env | avg_suc_per_env | avg_return_overall | avg_special_per_env
)

def compute_loss(self, batch: Batch, log_step: bool):
critic_loss, actor_loss = self.policy_aclr(batch, log_step=log_step)
@@ -637,7 +661,7 @@ def make_pbar(loader, epoch_num):
return tqdm(
enumerate(loader),
desc=f"{self.run_name} Epoch {epoch_num} Train",
total=self.train_grad_updates_per_epoch,
total=self.train_batches_per_epoch,
colour="green",
)
else:
Expand Down Expand Up @@ -667,7 +691,7 @@ def make_pbar(loader, epoch_num):
continue
self.policy_aclr.train()
for train_step, batch in make_pbar(self.train_dloader, epoch):
total_step = (epoch * self.train_grad_updates_per_epoch) + train_step
total_step = (epoch * self.train_batches_per_epoch) + train_step
log_step = total_step % self.log_interval == 0
loss_dict = self.train_step(batch, log_step=log_step)
if log_step:
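
One note on the init_dsets change above: items_per_epoch is now scaled by accelerator.num_processes, presumably so that after accelerate shards the dataloader, each process still draws train_batches_per_epoch batches per epoch. A quick arithmetic sketch with hypothetical values:

# Hypothetical numbers, only to illustrate the new items_per_epoch computation.
train_batches_per_epoch = 700
batch_size = 32
num_processes = 2  # accelerator.num_processes
items_per_epoch = train_batches_per_epoch * batch_size * num_processes  # 44,800 sequences total
per_process_batches = items_per_epoch // (num_processes * batch_size)   # back to 700 per process
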
5 changes: 2 additions & 3 deletions examples/00_kshot_frozen_lake.py
Expand Up @@ -14,7 +14,6 @@ def add_cli(parser):
)
parser.add_argument("--run_name", type=str, required=True)
parser.add_argument("--buffer_dir", type=str, required=True)
parser.add_argument("--gpu", type=int, required=True)
parser.add_argument("--log", action="store_true")
parser.add_argument("--trials", type=int, default=3)
parser.add_argument("--lake_size", type=int, default=5)
@@ -87,18 +86,18 @@ def add_cli(parser):
make_val_env=make_env,
max_seq_len=args.max_rollout_length,
traj_save_len=args.max_rollout_length,
agent_type=amago.agent.Agent,
dset_max_size=10_000,
run_name=run_name,
dset_name=run_name,
gpu=args.gpu,
dset_root=args.buffer_dir,
dloader_workers=10,
log_to_wandb=args.log,
wandb_group_name=group_name,
epochs=500 if not args.hard_mode else 900,
parallel_actors=24,
train_timesteps_per_epoch=350,
train_grad_updates_per_epoch=700,
train_batches_per_epoch=700,
val_interval=20,
val_timesteps_per_epoch=args.max_rollout_length * 2,
ckpt_interval=50,
