-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
executable file
·51 lines (46 loc) · 1.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import argparse
import os
import numpy as np
import random
from agents.smfg import SMFG
from agents.m3rl import M3RL
from agents.ebpg import EBPG
from agents.utils import get_config
# Accepted values for --algo and --env. Defined ahead of the parser so the same
# lists can be handed to argparse as ``choices`` (invalid values are then
# rejected at parse time with a usage message).
algo_choices = ["smfu", "tirl", "m3rl", "ebpg"]
env_choices = ["edrp", "mtfg"]

# Command-line interface of the training script.
arg_parser = argparse.ArgumentParser()
# NOTE: ``store_false`` with ``default=True`` means the value is True unless
# the flag is passed; passing --w_leader sets it to False.
arg_parser.add_argument('--w_leader', action="store_false", default=True, help='Whether consider leader incentives')
arg_parser.add_argument('--algo', type=str, default="ebpg", choices=algo_choices, help='smfu, tirl, m3rl, ebpg')
# One 0/1 switch per regularizer index listed in the help text. ``nargs='+'``
# with ``type=int`` keeps a command-line value the same kind of object as the
# default (a list of ints) instead of a single raw string.
arg_parser.add_argument('--regularizer', nargs='+', type=int, default=[0, 1, 0, 1, 0, 0, 0],  # one-hot: 0-turn off, 1-turn on
                        help='0-no reg, 1-L1 Norm, 2-L2 Norm, 3-Weight Clip, 4-Dropout, 5-Batch Normalization, 6-Entropy')
# Numeric options are parsed as int: without ``type=int`` argparse would hand
# back strings for any value supplied on the command line.
arg_parser.add_argument('--follower_number', type=int, default=300, help='Follower number')
arg_parser.add_argument('--seed', type=int, default=0, help='Random seed')
arg_parser.add_argument('--episode_length', type=int, default=25, help='Maximum episode length')
arg_parser.add_argument('--episode_number', type=int, default=2000, help='Maximum number of training episodes')
arg_parser.add_argument('--record_dir', type=str, default='results', help='Directory to store result data')
arg_parser.add_argument('--env', type=str, default='edrp', choices=env_choices, help='edrp, mtfg')
def seed_everything(seed, use_cuda=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``torch.manual_seed`` is called unconditionally so the CPU torch generator
    is seeded even when CUDA is in use; seeding only the CUDA generators would
    leave CPU-side torch randomness unseeded.

    Args:
        seed: Integer seed applied to every generator.
        use_cuda: When true, the CUDA generators are additionally seeded
            through ``torch.cuda.manual_seed_all``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def main():
    """Parse the command line, seed the RNGs, then build and train the agent."""
    device_id = 0
    # Expose only this GPU to the process; set before the CUDA availability
    # query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(device_id)

    args = arg_parser.parse_args()

    # Explicit checks rather than ``assert``: asserts are stripped when Python
    # runs with -O, which would let an unsupported value through.
    if args.env not in env_choices:
        arg_parser.error("--env must be one of {}".format(env_choices))
    if args.algo not in algo_choices:
        arg_parser.error("--algo must be one of {}".format(algo_choices))

    # A seed supplied on the command line may arrive as a string; the torch
    # and NumPy seeding functions need an integer.
    args.seed = int(args.seed)
    args.cuda = torch.cuda.is_available()
    args.cuda_id = device_id
    seed_everything(args.seed, use_cuda=args.cuda)

    config = get_config(args)

    # "m3rl" and "ebpg" have dedicated classes; every other accepted
    # algorithm name is handled by SMFG.
    agent_classes = {"m3rl": M3RL, "ebpg": EBPG}
    agent_cls = agent_classes.get(args.algo, SMFG)
    smfg = agent_cls(args, config)
    smfg.train()


if __name__ == '__main__':
    main()