forked from baidu/Senta
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lanch.py
130 lines (112 loc) · 4.61 KB
/
lanch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""lanch.py"""
import argparse
import copy
import logging
import os
import subprocess
import sys
from senta.utils.args import ArgumentGroup, print_arguments
from senta.utils import log
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
multip_g = ArgumentGroup(parser, "multiprocessing",
"start paddle training using multi-processing mode.")
multip_g.add_arg("node_ips", str, None,
"paddle trainer ips")
multip_g.add_arg("node_id", int, None,
"the trainer id of the node for multi-node distributed training.")
multip_g.add_arg("print_config", bool, True,
"print the config of multi-processing mode.")
multip_g.add_arg("current_node_ip", str, None,
"the ip of current node.")
multip_g.add_arg("split_log_path", str, "log",
"log path for each trainer.")
multip_g.add_arg("log_prefix", str, "",
"the prefix name of job log.")
multip_g.add_arg("nproc_per_node", int, 8,
"the number of process to use on each node.")
multip_g.add_arg("selected_gpus", str, "0,1,2,3,4,5,6,7",
"the gpus selected to use.")
multip_g.add_arg("training_script", str, None, "the program/script to be lauched "
"in parallel followed by all the arguments", positional_arg=True)
multip_g.add_arg("training_script_args", str, None,
"training script args", positional_arg=True, nargs=argparse.REMAINDER)
# yapf: enable
def start_procs(args):
"""start process"""
procs = []
log_fns = []
default_env = os.environ.copy()
node_id = args.node_id
node_ips = [x.strip() for x in args.node_ips.split(',')]
current_ip = args.current_node_ip
num_nodes = len(node_ips)
selected_gpus = [x.strip() for x in args.selected_gpus.split(',')]
selected_gpu_num = len(selected_gpus)
all_trainer_endpoints = ""
for ip in node_ips:
for i in range(args.nproc_per_node):
if all_trainer_endpoints != "":
all_trainer_endpoints += ","
all_trainer_endpoints += "%s:618%d" % (ip, i)
nranks = num_nodes * args.nproc_per_node
gpus_per_proc = args.nproc_per_node % selected_gpu_num
if gpus_per_proc == 0:
gpus_per_proc = selected_gpu_num // args.nproc_per_node
else:
gpus_per_proc = selected_gpu_num // args.nproc_per_node + 1
selected_gpus_per_proc = [selected_gpus[i:i + gpus_per_proc] for i in range(0, len(selected_gpus), gpus_per_proc)]
if args.print_config:
logging.info("all_trainer_endpoints: %s" % all_trainer_endpoints)
logging.info("node_id: %s" % node_id)
logging.info("current_ip: %s" % node_id)
logging.info("num_nodes: %s" % num_nodes)
logging.info("node_ips: %s" % node_ips)
logging.info("gpus_per_proc: %s" % gpus_per_proc)
logging.info("selected_gpus_per_proc: %s" % selected_gpus_per_proc)
logging.info("nranks: %s" % nranks)
current_env = copy.copy(default_env)
procs = []
cmds = []
log_fns = []
for i in range(0, args.nproc_per_node):
trainer_id = node_id * args.nproc_per_node + i
current_env.update({
"FLAGS_selected_gpus": "%s" % ",".join([str(s) for s in selected_gpus_per_proc[i]]),
"PADDLE_TRAINER_ID": "%d" % trainer_id,
"PADDLE_CURRENT_ENDPOINT": "%s:618%d" % (current_ip, i),
"PADDLE_TRAINERS_NUM": "%d" % nranks,
"PADDLE_TRAINER_ENDPOINTS": all_trainer_endpoints,
"PADDLE_NODES_NUM": "%d" % num_nodes
})
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%s.job.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
else:
process = subprocess.Popen(cmd, env=current_env)
procs.append(process)
for i in range(len(procs)):
proc = procs[i]
proc.wait()
if len(log_fns) > 0:
log_fns[i].close()
if proc.returncode != 0:
logging.info("proc %d run failed" % i)
raise subprocess.CalledProcessError(returncode=procs[i].returncode,
cmd=cmds[i])
else:
logging.info("proc %d run success" % i)
def main(args):
"""main"""
if args.print_config:
print_arguments(args)
start_procs(args)
if __name__ == "__main__":
"""start"""
lanch_args = parser.parse_args()
log.init_log(os.path.join(lanch_args.split_log_path, "lanch"), level=logging.DEBUG)
main(lanch_args)