-
Notifications
You must be signed in to change notification settings - Fork 7
/
registry.py
106 lines (82 loc) · 3.38 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from types import FunctionType
import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import _internal_kv_initialized, \
_internal_kv_get, _internal_kv_put
# Category names under which values can be stored in the registry.
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"

# Every category accepted by _Registry.register(); any other category is
# rejected there with a TuneError.
KNOWN_CATEGORIES = [
    TRAINABLE_CLASS,
    ENV_CREATOR,
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
]
def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.

    Raises:
        TypeError: If ``trainable`` is neither a plain function nor a
            subclass of ``Trainable``.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import -- confirm.
    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, FunctionType):
        trainable = wrap_function(trainable)

    # issubclass() raises its own, unhelpful TypeError when handed something
    # that is not a class (an instance, a string, None, ...). Check for a
    # class first so that every invalid argument produces the message below.
    is_trainable_class = (isinstance(trainable, type)
                          and issubclass(trainable, Trainable))
    if not is_trainable_class:
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)

    _global_registry.register(TRAINABLE_CLASS, name, trainable)
def register_env(name, env_creator):
    """Register a custom environment for use with RLlib.

    Args:
        name (str): Name to register.
        env_creator (obj): Callable that creates an env. Plain functions,
            lambdas, bound methods, ``functools.partial`` objects, classes
            and other callable objects are all accepted.

    Raises:
        TypeError: If ``env_creator`` is not callable.
    """
    # Checking for FunctionType alone would reject perfectly usable creators
    # such as bound methods, partials and callable instances, so validate
    # callability instead.
    if not callable(env_creator):
        raise TypeError("Second argument must be callable.", env_creator)
    _global_registry.register(ENV_CREATOR, name, env_creator)
def _make_key(category, key):
    """Build the binary storage key for a registry entry.

    Args:
        category (str): The category of the item.
        key (str): The unique identifier for the item.

    Returns:
        The bytes ``b"TuneRegistry:<category>/<key>"``, where both parts
        are ASCII-encoded, to use for storing the value.
    """
    encoded_parts = (category.encode("ascii"), key.encode("ascii"))
    return b"TuneRegistry:" + b"/".join(encoded_parts)
class _Registry(object):
    """Store of pickled values addressed by ``(category, key)``.

    Until the internal KV store is initialized, registered values are kept
    in a local buffer and lookups are served from that buffer. Once the
    store is initialized, buffered values are written to it and lookups go
    to the store.
    """

    def __init__(self):
        # (category, key) -> pickled value still waiting to be written to
        # the internal KV store.
        self._to_flush = dict()

    def register(self, category, key, value):
        """Pickle ``value`` and record it under ``(category, key)``.

        Raises:
            TuneError: If ``category`` is not one of KNOWN_CATEGORIES.
        """
        if category not in KNOWN_CATEGORIES:
            from ray.tune import TuneError
            message = "Unknown category {} not among {}".format(
                category, KNOWN_CATEGORIES)
            raise TuneError(message)

        serialized = pickle.dumps(value)
        self._to_flush[(category, key)] = serialized

        # Write through immediately when the KV store is already usable.
        if _internal_kv_initialized():
            self.flush_values()

    def contains(self, category, key):
        """Return True if a value is registered under ``(category, key)``."""
        if not _internal_kv_initialized():
            return (category, key) in self._to_flush
        return _internal_kv_get(_make_key(category, key)) is not None

    def get(self, category, key):
        """Return the unpickled value registered under ``(category, key)``.

        Raises:
            ValueError: If the KV store is initialized and holds no value
                for the entry.
            KeyError: If the KV store is not initialized and the entry is
                not in the local buffer.
        """
        if not _internal_kv_initialized():
            return pickle.loads(self._to_flush[(category, key)])

        stored = _internal_kv_get(_make_key(category, key))
        if stored is not None:
            return pickle.loads(stored)
        raise ValueError(
            "Registry value for {}/{} doesn't exist.".format(category, key))

    def flush_values(self):
        """Write every buffered value to the KV store, then empty the buffer."""
        for entry, blob in self._to_flush.items():
            # entry is the (category, key) pair the value was registered under.
            _internal_kv_put(_make_key(*entry), blob, overwrite=True)
        self._to_flush.clear()
# Module-level registry instance used by the register_* helpers above.
_global_registry = _Registry()
# Add flush_values to Ray's post-init hooks so that entries buffered before
# the internal KV store was available get written to it.
# NOTE(review): assumes these hooks run once ray.init() completes -- confirm
# against ray.worker.
ray.worker._post_init_hooks.append(_global_registry.flush_values)