-
Notifications
You must be signed in to change notification settings - Fork 2
/
console.py
79 lines (68 loc) · 2.67 KB
/
console.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 15 16:46:10 2016
@author: lanlin
"""
import os
import argparse
import yaml
from utils import client
def main(args):
with open(args.config, "r") as f:
config = yaml.load(f)
try:
k8s_config = config["kubernetes"]
except KeyError:
k8s_config = {}
try:
tf_config = config["tensorflow"]
except KeyError:
tf_config = {}
try:
data_config = config["data"]
except KeyError:
data_config = {}
groups = client.request(int(k8s_config["num_gpus"]))
k8s_config.pop("num_gpus")
k8s_config["num_groups"] = len(groups)
k8s_config["ps_replicas"] = len(groups) \
if len(groups) % 2 != 0 else len(groups) - 1
k8s_config["worker_replicas_each_group"] = groups
k8s_args = ["--{} \"{}\"".format(key, value)
for key, value in k8s_config.iteritems()]
k8s_args = " ".join(k8s_args) if k8s_args else ""
tf_args = [(key, value) for key, value in tf_config.iteritems()]
tf_args = "--parameters \"{}\"".format(tf_args) if tf_args else ""
render_args = "{} {}".format(k8s_args, tf_args)
os.system("python ./kubernetes/render.py {}".format(render_args))
copy_files = ["./tensorflow/distribute_tf.py"]
if "user_model_file" not in tf_config.keys() \
or tf_config["user_model_file"] == "example_user_model.py":
copy_files.append("./examples/example_user_model.py")
else:
copy_files.append(tf_config["user_model_file"])
if "data_dir" not in data_config.keys() \
or data_config["data_dir"] == "./examples/data":
copy_files.append("./examples/data")
else:
copy_files.append(data_config["data_dir"])
work_dir = ROOT_DIR + "../GPU_cluster_storage/{}/{}".format(
k8s_config["namespace"],
k8s_config["name"])
os.system("mkdir -p {}".format(work_dir))
os.system("echo {} > {}/.groups".format(groups, work_dir))
os.system("cp -r {} {}".format(" ".join(copy_files), work_dir))
os.system("chmod +x {}/start_cluster.sh".format(work_dir))
os.system("chmod +x {}/stop_cluster.sh".format(work_dir))
os.system("chmod +x {}/start_tensorboard.sh".format(work_dir))
os.system("chmod +x {}/stop_tensorboard.sh".format(work_dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="")
parser.add_argument(
"--config",
default="./examples/example_config.yaml")
args = parser.parse_args()
os.chdir(os.path.dirname(os.path.abspath(__file__)))
ROOT_DIR = "{}/".format(os.getcwd())
config=main(args)