-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
142 lines (126 loc) · 4.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import yaml
import ray
from ray.test.cluster_utils import Cluster
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments
# Usage examples for the command-line help; create_parser() passes this string
# as the `epilog` of the parser it builds.
EXAMPLE_USAGE = """
Training example via RLlib CLI:
rllib train --run DQN --env CartPole-v0
Grid search example via RLlib CLI:
rllib train -f tuned_examples/cartpole-grid-search-example.yaml
Grid search example via executable:
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
Note that -f overrides all other trial-specific command-line options.
"""
def create_parser(parser_creator=None):
    """Build the command-line parser for the training script.

    Starts from the base parser returned by ``make_parser`` (defined in
    ray/tune/config_parser.py) and adds the Ray cluster/resource flags and
    the experiment-level flags declared below.

    Args:
        parser_creator: Forwarded unchanged to ``make_parser``. Presumably a
            callable used to construct the underlying parser -- confirm
            against ``make_parser``.

    Returns:
        The parser produced by ``make_parser`` with the extra arguments added.
    """
    parser = make_parser(
        parser_creator=parser_creator,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Train a reinforcement learning agent.",
        epilog=EXAMPLE_USAGE)

    # See also the base parser definition in ray/tune/config_parser.py
    parser.add_argument(
        "--redis-address",
        default=None,
        type=str,
        help="The Redis address of the cluster.")
    parser.add_argument(
        "--ray-num-cpus",
        default=None,
        type=int,
        help="--num-cpus to pass to Ray."
        " This only has an effect in local mode.")
    parser.add_argument(
        "--ray-num-gpus",
        default=None,
        type=int,
        help="--num-gpus to pass to Ray."
        " This only has an effect in local mode.")
    parser.add_argument(
        "--ray-num-local-schedulers",
        default=None,
        type=int,
        help="Emulate multiple cluster nodes for debugging.")
    parser.add_argument(
        "--ray-object-store-memory",
        default=None,
        type=int,
        help="--object-store-memory to pass to Ray."
        " This only has an effect in local mode.")
    parser.add_argument(
        "--experiment-name",
        default="default",
        type=str,
        help="Name of the subdirectory under `local_dir` to put results in.")
    parser.add_argument(
        "--env", default=None, type=str, help="The gym environment to use.")
    parser.add_argument(
        "--queue-trials",
        action='store_true',
        help=(
            "Whether to queue trials when the cluster does not currently have "
            "enough resources to launch one. This should be set to True when "
            "running on an autoscaling cluster to enable automatic scale-up."))
    parser.add_argument(
        "-f",
        "--config-file",
        default=None,
        type=str,
        help="If specified, use config options from this file. Note that this "
        "overrides any trial-specific options set via flags above.")
    return parser
def run(args, parser):
    """Run the experiments described by the parsed command-line arguments.

    The experiment specification is read from ``args.config_file`` when that
    is set; otherwise a single experiment named ``args.experiment_name`` is
    assembled from the individual flags. Every experiment must name a ``run``
    value and an ``env`` (either at the top level or inside ``config``).
    Ray is then initialised -- on an emulated multi-node cluster when
    ``args.ray_num_local_schedulers`` is set, otherwise via ``ray.init`` with
    the address/resource flags -- and the experiments are handed to
    ``run_experiments``.

    Args:
        args: Parsed arguments, as produced by the parser from
            ``create_parser``.
        parser: The parser that produced ``args``; only used to report
            missing required options through ``parser.error``.
    """
    if args.config_file:
        with open(args.config_file) as f:
            # safe_load only builds plain YAML types. A bare yaml.load(f)
            # can instantiate arbitrary Python objects from the file and is
            # rejected outright (missing Loader) by PyYAML >= 6.
            experiments = yaml.safe_load(f)
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                # A falsy trial_resources value is passed through unchanged;
                # only a truthy one is converted.
                "trial_resources": (
                    args.trial_resources and
                    resources_to_json(args.trial_resources)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "num_samples": args.num_samples,
                "upload_dir": args.upload_dir,
            }
        }

    # Validate every experiment, whether it came from the file or the flags.
    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not exp.get("config", {}).get("env"):
            parser.error("the following arguments are required: --env")

    if args.ray_num_local_schedulers:
        # Emulate a multi-node cluster: one node per requested scheduler.
        cluster = Cluster()
        for _ in range(args.ray_num_local_schedulers):
            cluster.add_node(
                resources={
                    "num_cpus": args.ray_num_cpus or 1,
                    "num_gpus": args.ray_num_gpus or 0,
                },
                object_store_memory=args.ray_object_store_memory)
        ray.init(redis_address=cluster.redis_address)
    else:
        ray.init(
            redis_address=args.redis_address,
            object_store_memory=args.ray_object_store_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus)

    run_experiments(
        experiments,
        scheduler=_make_scheduler(args),
        queue_trials=args.queue_trials)
if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse the process arguments,
    # and start training. The parser is passed along so run() can report
    # missing options through it.
    cli_parser = create_parser()
    run(cli_parser.parse_args(), cli_parser)