05_dark_key_door.py
from argparse import ArgumentParser

import wandb

from amago.envs.builtin.toy_gym import RoomKeyDoor
from amago.envs import AMAGOEnv
from amago.cli_utils import *
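
# Meta-RL example: the agent gets `--meta_horizon` total timesteps to explore
# the same hidden room layout over multiple short episodes. In the default
# "dark" version the goal location is not part of the observation, so solving
# the task requires memory across episodes; `--light_room_observation` reveals
# the goal and removes the need for memory.
#
# Example usage (shared flags like --trials and --run_name are assumed to come
# from `add_common_cli` and may vary by amago version):
#   python 05_dark_key_door.py --run_name dark_kd --trials 1 --meta_horizon 500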


def add_cli(parser):
    parser.add_argument(
        "--meta_horizon",
        type=int,
        default=500,
        help="Total meta-adaptation timestep budget for the agent to explore the same room layout.",
    )
    parser.add_argument(
        "--room_size",
        type=int,
        default=8,
        help="Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
    )
    parser.add_argument(
        "--episode_length",
        type=int,
        default=50,
        help="Maximum length of a single episode in the environment.",
    )
    parser.add_argument(
        "--light_room_observation",
        action="store_true",
        help="Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
    )
    parser.add_argument(
        "--randomize_actions",
        action="store_true",
        help="Randomize the agent's action space to make the task harder.",
    )
    return parser


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    config = {}
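    # `config` collects overrides filled in by the `switch_*` helpers below;
    # `use_config` then applies them together with any files passed via --configs.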
    tstep_encoder_type = switch_tstep_encoder(
        config, arch="ff", n_layers=2, d_hidden=128, d_output=64
    )
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
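    # reward_multiplier rescales the environment's sparse rewards before the
    # agent trains on them (exact handling is inside amago's agent).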
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=100.0)
    # the fancier exploration schedule mentioned in the appendix can help
    # when the domain is a true meta-RL problem and the "horizon" time limit
    # (above) is actually relevant for resetting the task.
    exploration_type = switch_exploration(
        config, "bilevel", steps_anneal=500_000, rollout_horizon=args.meta_horizon
    )
    use_config(config, args.configs)

    group_name = f"{args.run_name}_dark_key_door"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
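        # Wrap the toy environment for amago. `dark=True` (the default here)
        # hides the goal from the observation, which is what makes memory
        # across episodes of the meta-rollout necessary.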
        make_train_env = lambda: AMAGOEnv(
            env=RoomKeyDoor(
                size=args.room_size,
                max_episode_steps=args.episode_length,
                meta_rollout_horizon=args.meta_horizon,
                dark=not args.light_room_observation,
                randomize_actions=args.randomize_actions,
            ),
            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}",
        )
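        # max_seq_len / traj_save_len match the meta-horizon so the sequence
        # model's context can span the entire meta-rollout.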
        experiment = create_experiment_from_cli(
            args,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.meta_horizon,
            traj_save_len=args.meta_horizon,
            group_name=group_name,
            run_name=run_name,
            val_timesteps_per_epoch=args.meta_horizon * 4,
            exploration_wrapper_type=exploration_type,
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()