forked from facebookresearch/mtrl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
99 lines (84 loc) · 3.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""This is the main entry point for the code."""
import os
import shutil
import time

import hydra
import wandb

from mtrl.app.run import run
from mtrl.utils import config as config_utils
from mtrl.utils.types import ConfigType
from toolbox.aws import shutdown
from toolbox.printing import print_visible
def start_wandb(config, seed: int = -1):
    """Open a Weights & Biases run describing one seed of the experiment.

    Args:
        config: Experiment config. The ``setup``, ``agent``, ``experiment``
            and ``replay_buffer`` sections are read.
        seed: Seed recorded for the run. A negative value means "use
            ``config.setup.seed_ref`` instead".
    """
    run_seed = config.setup.seed_ref if seed < 0 else seed
    experiment_name = config.setup.name

    # Pull the nested sections out once so the table below stays readable.
    agent_cfg = config.agent
    multitask_cfg = agent_cfg.multitask
    actor_cfg = agent_cfg.actor
    optimizer_cfg = agent_cfg.optimizers
    experiment_cfg = config.experiment
    pal_cfg = multitask_cfg.pal_cfg
    task_count = multitask_cfg.num_envs

    # One "lr/<component>" entry per optimizer, in the order they are logged.
    learning_rates = {
        f"lr/{component}": getattr(optimizer_cfg, component).lr
        for component in ("actor", "critic", "alpha", "decoder", "encoder")
    }

    hyperparameters = {
        "num_tasks": task_count,
        "agent": agent_cfg.name,
        "encoder": agent_cfg.encoder.type_to_select,
        "seed": run_seed,
        "agent/encoder_feature_dim": agent_cfg.encoder_feature_dim,
        "agent/num_layers": agent_cfg.num_layers,
        "agent/num_filters": agent_cfg.num_filters,
        "actor/num_layer": actor_cfg.num_layers,
        "actor/hidden_dim": actor_cfg.hidden_dim,
        "num_train_steps": experiment_cfg.num_train_steps,
        "eval_freq": experiment_cfg.eval_freq,
        "num_eval_episodes": experiment_cfg.num_eval_episodes,
        **learning_rates,
        "batch_size": config.replay_buffer.batch_size,
        "pal/dim": pal_cfg.pal_dim,
        "pal/shared": pal_cfg.shared_projection,
        "pal/residual": pal_cfg.residual_mode,
    }

    # Runs are grouped by experiment name; the run name also carries the seed.
    wandb.init(
        project=f"MTRL{task_count}",
        name=f"{experiment_name}_{run_seed}",
        group=experiment_name,
        config=hyperparameters,
    )
def launch_one_seed(config, seed: int, time_start: float = -1):
    """Run the experiment for a single seed inside its own wandb run.

    Args:
        config: Experiment config, forwarded to ``start_wandb`` and ``run``.
        seed: Seed used for this run.
        time_start: ``time.time()`` timestamp from which the five-minute
            failure window is measured. A negative value (the default) means
            "start counting now".

    Raises:
        Exception: Any exception raised less than five minutes after
            ``time_start`` is re-raised unchanged. Later exceptions are
            printed and swallowed after closing the wandb run.
    """
    if time_start < 0:
        time_start = time.time()
    try:
        print_visible("Starting seed " + str(seed))
        # Resolve "logs" under the directory reached by applying four ".."
        # components to this file's real path.
        mtrl_path = os.path.normpath(
            os.path.join(os.path.realpath(__file__), "../../../../")
        )
        logs_path = os.path.normpath(os.path.join(mtrl_path, "logs"))
        print_visible("Repo Path: " + mtrl_path)
        print_visible("Logs Path: " + logs_path)
        # Clear the logs directory before the run. The previous shell command
        # (rm -r -f " <path>") had a stray space inside the quotes, so it
        # targeted a non-existent path and silently removed nothing.
        # NOTE(review): this delete is now effective -- confirm that the
        # "Logs Path" printed above is the directory you intend to wipe.
        shutil.rmtree(logs_path, ignore_errors=True)
        start_wandb(config, seed=seed)
        run(config, seed=seed)
    except Exception as e:
        # If it has been running for less than 5 minutes, then it is probably
        # a bug: propagate it with its original traceback.
        if time.time() - time_start < 5 * 60:
            raise
        # Otherwise, it is probably a timeout: report it and fall through to
        # close the wandb run.
        # NOTE: the instance shutdown call is currently disabled here.
        print(f"Exception: {e}")
        print("Timeout, shutting down")
    wandb.finish()
@hydra.main(config_path="config", config_name="config")
def launch(config: ConfigType) -> None:
    """Hydra entry point: run ``config.setup.num_seeds`` consecutive seeds.

    Seeds start at ``config.setup.seed_ref`` (read before the config is
    processed) and increase by one per run. Every run receives the same start
    timestamp, taken once before the first seed.

    Args:
        config: Config object supplied by Hydra.
    """
    first_seed = config.setup.seed_ref
    config = config_utils.process_config(config)
    started_at = time.time()
    for offset in range(config.setup.num_seeds):
        launch_one_seed(config, seed=first_seed + offset, time_start=started_at)
# Invoke the Hydra-decorated launcher only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    launch()