-
Notifications
You must be signed in to change notification settings - Fork 7
/
function_runner.py
130 lines (101 loc) · 3.81 KB
/
function_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
import threading
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIMESTEPS_TOTAL
logger = logging.getLogger(__name__)
class StatusReporter(object):
    """Object passed into your main() that you can report status through.

    The training function calls the reporter with keyword arguments; the
    supervising trainable polls ``_get_and_clear_status`` to collect them.

    Example:
        >>> reporter = StatusReporter()
        >>> reporter(timesteps_total=1)
    """

    def __init__(self):
        # Newest report that has not been handed out yet (None if nothing
        # is pending).
        self._latest_result = None
        # Newest report ever received; the final "done" result is built
        # from it.
        self._last_result = None
        self._lock = threading.Lock()
        # Set by the runner thread when the entrypoint raises, or by _stop().
        self._error = None
        # Set by the runner thread once the entrypoint returned or raised.
        self._done = False

    def __call__(self, **kwargs):
        """Report updated training status.

        Args:
            kwargs: Latest training result status.
        """
        with self._lock:
            # Two independent copies: the dict handed out by
            # _get_and_clear_status may be mutated by its consumer and must
            # not alias the one kept for the final "done" result.
            self._latest_result = kwargs.copy()
            self._last_result = kwargs.copy()

    def _get_and_clear_status(self):
        """Return the newest unread result, or None if nothing new arrived.

        Once the run is flagged as done and every report has been consumed,
        a fresh copy of the last report with ``done=True`` added is returned
        instead.

        Returns:
            dict or None: The pending result, the final "done" result, or
                None when there is nothing to hand out yet.

        Raises:
            TuneError: If an error was recorded on the reporter, or if the
                run finished without ever reporting a result.
        """
        # Compare against None so that an exception object that happens to
        # be falsy is still reported.
        if self._error is not None:
            raise TuneError("Error running trial: " + str(self._error))
        with self._lock:
            if self._done and not self._latest_result:
                if not self._last_result:
                    raise TuneError("Trial finished without reporting result!")
                # Build a new dict instead of mutating the stored one, so a
                # result that was already handed out is never changed
                # retroactively.
                return dict(self._last_result, done=True)
            pending = self._latest_result
            self._latest_result = None
            return pending

    def _stop(self):
        """Flag the reporter as stopped; the next status poll raises."""
        self._error = "Agent stopped"
# Settings consumed by the function runner itself rather than by the user
# function.
DEFAULT_CONFIG = dict(
    # Batch results to at least this granularity (seconds slept before each
    # status poll).
    script_min_iter_time_s=1,
)
class _RunnerThread(threading.Thread):
    """Supervisor thread that runs your script.

    The entrypoint is invoked as ``entrypoint(config, status_reporter)``.
    Its outcome is published on the status reporter: ``_error`` receives the
    raised exception (if any) and ``_done`` is set once the entrypoint has
    returned or raised.
    """

    def __init__(self, entrypoint, config, status_reporter):
        """Prepare the thread without starting it.

        Args:
            entrypoint: Callable taking ``(config, status_reporter)``.
            config: Configuration passed through to the entrypoint.
            status_reporter: Object that receives ``_error`` / ``_done``
                and is handed to the entrypoint for reporting.
        """
        # Initialise the base Thread first so the object is fully set up
        # before our own attributes are attached.
        threading.Thread.__init__(self)
        self._entrypoint = entrypoint
        self._entrypoint_args = [config, status_reporter]
        self._status_reporter = status_reporter
        self.daemon = True

    def run(self):
        """Run the entrypoint and record its outcome on the reporter.

        Raises:
            Exception: Whatever the entrypoint raised, after it has been
                stored on the reporter and logged.
        """
        try:
            self._entrypoint(*self._entrypoint_args)
        except Exception as e:
            self._status_reporter._error = e
            logger.exception("Runner Thread raised error.")
            # Bare raise re-raises the active exception with its original
            # traceback intact (``raise e`` loses it on Python 2).
            raise
        finally:
            # Always mark completion, on success and on failure.
            self._status_reporter._done = True
class FunctionRunner(Trainable):
    """Trainable wrapper around a user-supplied training function.

    The function runs on a background runner thread and publishes its
    results through a StatusReporter. Checkpoint/restore is not supported
    in this mode of execution.
    """

    _name = "func"
    _default_config = DEFAULT_CONFIG

    def _setup(self, config):
        """Create the reporter and launch the user function on its thread.

        Args:
            config: Trial configuration. Keys that belong to the runner's
                own default config are removed before the config is passed
                to the user function.
        """
        entrypoint = self._trainable_func()
        self._status_reporter = StatusReporter()

        # Hand the user function a copy without the runner-only settings.
        user_config = config.copy()
        for reserved_key in self._default_config:
            user_config.pop(reserved_key, None)

        self._runner = _RunnerThread(
            entrypoint, user_config, self._status_reporter)
        self._start_time = time.time()
        self._last_reported_timestep = 0
        self._runner.start()

    def _trainable_func(self):
        """Subclasses can override this to set the trainable func."""
        raise NotImplementedError

    def _train(self):
        """Wait for the next reported result and return it.

        Sleeps for the configured minimum iteration time, then polls the
        reporter once per second until a result is available. When the
        result carries a total timestep count, the delta since the previous
        result is added as ``timesteps_this_iter``.
        """
        min_iter_key = "script_min_iter_time_s"
        min_iter_time = self.config.get(
            min_iter_key, self._default_config[min_iter_key])
        time.sleep(min_iter_time)

        reporter = self._status_reporter
        while True:
            latest = reporter._get_and_clear_status()
            if latest is not None:
                break
            time.sleep(1)

        total_timesteps = latest.get(TIMESTEPS_TOTAL)
        if total_timesteps is None:
            return latest

        delta = total_timesteps - self._last_reported_timestep
        latest.update(timesteps_this_iter=delta)
        self._last_reported_timestep = total_timesteps
        return latest

    def _stop(self):
        """Signal the status reporter that the run has been stopped."""
        self._status_reporter._stop()