forked from dmlc/mshadow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mshadow_ps.h
358 lines (350 loc) · 12.1 KB
/
mshadow_ps.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
/*!
* Copyright (c) 2014 by Contributors
* \file mshadow_ps.h
* \brief parameter server abstraction for mshadow tensor
* this is a plugin of mshadow that can be used to syncrhonize
* parameters across device and machines
*
* \author Tianqi Chen, Mu Li
*/
#ifndef MSHADOW_PS_H_ // NOLINT(*)
#define MSHADOW_PS_H_ // NOLINT(*)
#include <cstddef>
#include <cstring>
#include <vector>
// optionally support of lambda function in C++11, if available
#if __cplusplus >= 201103L
#include <functional>
#endif  // C++11
#include "../mshadow/tensor.h"
/*! \brief whether to adapt distributed PS from parameter-server */
#ifndef MSHADOW_DIST_PS
#define MSHADOW_DIST_PS 1
#endif
/*! \brief whether to support BSP rabit API of PS*/
#ifndef MSHADOW_RABIT_PS
#define MSHADOW_RABIT_PS 1
#endif
namespace mshadow {
/*! \brief namespace of mshadow-ps */
namespace ps {
/*!
 * \brief client-side interface to a parameter server
 *  keys are tensors identified by (key, devid); all Push/Pull calls are
 *  asynchronous and synchronized through PullWait
 * \tparam xpu the device on which the data lives
 * \tparam DType the element type of the tensors
 */
template<typename xpu,
         typename DType MSHADOW_DEFAULT_DTYPE>
class ISharedModel {
 public:
  /*!
   * \brief signature of the callback fired when a pull request completes;
   *  before it is invoked the executing thread has already been switched
   *  to the device that issued the pull request
   * \param stream stream of the callback thread; it is recommended to
   *  operate using this stream
   * \param arg user-supplied argument forwarded from PullReq
   */
  typedef void (CallbackFunction) (Stream<xpu> *stream, void *arg);
  /*! \brief virtual destructor */
  virtual ~ISharedModel(void) {}
  /*!
   * \brief configure the client from a string key/value pair
   * \param name parameter name
   * \param val configuration value
   */
  virtual void SetParam(const char *name, const char *val) {}
  /*!
   * \brief initialize the parameter-server client
   * \param devices device ids that may later be passed to Push and Pull
   */
  virtual void Init(const std::vector<int> &devices) {}
  /*!
   * \brief initialize the client without naming devices;
   *  only device 0 will be usable
   */
  inline void Init(void) {
    // default to the single device 0
    std::vector<int> default_devices(1, 0);
    this->Init(default_devices);
  }
  /*!
   * \brief register a key with a given shape;
   *  must precede any Push/PullReq/PullWait on that key
   * \param shape shape of the tensor stored under the key
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  template<int dim>
  inline void InitKey(Shape<dim> shape,
                      int key, int devid) {
    // the backend only deals in 2-D views
    this->InitKey_(shape.FlatTo2D(), key, devid);
  }
  /*!
   * \brief block until the outstanding pull for (key, devid) finishes;
   *  returns immediately if no pull request is pending
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  virtual void PullWait(int key, int devid) = 0;
  /*!
   * \brief debug helper: verify the weight on the current device
   *  against the server
   * \param data local copy of the weight
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  template<int dim>
  inline void CheckWeight(Tensor<xpu, dim, DType> data,
                          int key,
                          int devid) {
    this->CheckWeight_(data.FlatTo2D(), key, devid);
  }
  /*!
   * \brief asynchronously push a tensor to the parameter server;
   *  returns immediately
   * \param data the data to push
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   * \param priority scheduling priority; larger means more urgent
   */
  template<int dim>
  inline void Push(Tensor<xpu, dim, DType> data,
                   int key,
                   int devid,
                   int priority = 0) {
    this->Push_(data.FlatTo2D(), key, devid, priority);
  }
  /*!
   * \brief asynchronously request a pull of the parameter into data;
   *  returns immediately, use PullWait to wait for the copy to finish
   * \param data destination tensor
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   * \param priority scheduling priority; larger means more urgent
   * \param callback invoked when the request finishes
   * \param callback_arg argument forwarded to callback
   */
  template<int dim>
  inline void PullReq(Tensor<xpu, dim, DType> data,
                      int key,
                      int devid,
                      int priority = 0,
                      CallbackFunction callback = NULL,
                      void *callback_arg = NULL) {
    this->PullReq_(data.FlatTo2D(), key,
                   devid, priority, callback, callback_arg);
  }
#if __cplusplus >= 201103L
  /*!
   * \brief C++11 overload of PullReq that accepts a lambda / std::function
   *  as completion callback; otherwise identical to PullReq above
   * \param data destination tensor
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   * \param priority scheduling priority; larger means more urgent
   * \param callback invoked on the callback thread's stream when done
   */
  template<int dim>
  inline void PullReq(Tensor<xpu, dim, DType> data,
                      int key,
                      int devid,
                      int priority,
                      std::function<void(Stream<xpu> *stream)> callback) {
    typedef std::function<void(Stream<xpu> *stream)> Closure;
    // heap-allocate a copy: the callback may fire after this frame is gone;
    // InvokeLambda_ takes ownership and deletes it
    Closure *on_done = new Closure(callback);
    this->PullReq(data, key, devid, priority, InvokeLambda_, on_done);
  }
#endif  // C++11
  /*!
   * \brief debug hook: overwrite the weight for a key on the server side;
   *  servers are not required to implement this meaningfully
   * \param data the weight values to install
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  virtual void SetWeight_(Tensor<xpu, 2, DType> data,
                          int key,
                          int devid) = 0;
  /*!
   * \brief debug hook: compare data against the server-side weight;
   *  servers are not required to implement this meaningfully
   * \param data the weight values to compare
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  virtual void CheckWeight_(Tensor<xpu, 2, DType> data,
                            int key,
                            int devid) = 0;

 protected:
  /*!
   * \brief backend registration of a key with its 2-D shape
   * \param shape flattened shape of the tensor
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   */
  virtual void InitKey_(Shape<2> shape,
                        int key, int devid) = 0;
  /*!
   * \brief backend implementation of the asynchronous push
   * \param data the 2-D view of the data to push
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   * \param priority scheduling priority; larger means more urgent
   */
  virtual void Push_(Tensor<xpu, 2, DType> data,
                     int key,
                     int devid,
                     int priority = 0) = 0;
  /*!
   * \brief backend implementation of the asynchronous pull request
   * \param data the 2-D destination view
   * \param key unique id of the tensor (unique per device)
   * \param devid device the tensor lives on
   * \param priority scheduling priority; larger means more urgent
   * \param callback invoked when the request finishes
   * \param callback_arg argument forwarded to callback
   */
  virtual void PullReq_(Tensor<xpu, 2, DType> data,
                        int key,
                        int devid,
                        int priority,
                        CallbackFunction callback,
                        void *callback_arg) = 0;

 private:
#if __cplusplus >= 201103L
  /*!
   * \brief trampoline that adapts a heap-allocated std::function to the
   *  C-style CallbackFunction signature; frees the closure after the call
   */
  inline static void InvokeLambda_(Stream<xpu> *stream, void *fun) {
    typedef std::function<void(Stream<xpu> *stream)> Closure;
    Closure *closure = static_cast<Closure*>(fun);
    (*closure)(stream);
    delete closure;
  }
#endif  // C++11
};
/*! \brief interface for customized mshadow server */
template<typename DType>
class IModelUpdater {
 public:
  virtual ~IModelUpdater(void) {}
  /*!
   * \brief set parameters from outside
   * \param name name of parameter
   * \param val value of parameter
   */
  virtual void SetParam(const char *name, const char *val) {}
  /*!
   * \brief init the model updater
   * \param rank the rank of the node
   * \param argc number of arguments
   * \param argv arguments
   */
  virtual void InitUpdater(int rank, int argc, char *argv[]) {}
  /*!
   * \brief initialize the model
   * \param key the key of data we point to
   * \param dptr the data pointer
   * \param size size of the parameter key
   */
  virtual void InitModel(int key, DType *dptr, size_t size) {
    // wrap the raw buffer as a 1-D CPU tensor for the overridable hook
    this->InitModel_(key, Tensor<cpu, 1, DType>(dptr, Shape1(size)));
  }
  /*!
   * \brief update the model
   * \param key the key of data we point to
   * \param dptr the data pointer
   * \param size size of the parameter key
   */
  virtual void Update(int key, DType *dptr, size_t size) {
    this->Update_(key, Tensor<cpu, 1, DType>(dptr, Shape1(size)));
  }

 protected:
  /*!
   * \brief initialize the model, user can implement this one
   *  to take advantage of tensor operations
   * \param key the key of data we point to
   * \param data the tensor data corresponding to the data we want to initialize
   */
  virtual void InitModel_(int key, Tensor<cpu, 1, DType> data) {
    LOG(FATAL) << "InitModel: not implemented";
  }
  /*!
   * \brief update the model, user can implement this one
   *  to take advantage of tensor operations
   * \param key the key of data we point to
   * \param data the tensor data corresponding to the data we want to initialize
   */
  virtual void Update_(int key, Tensor<cpu, 1, DType> data) {
    // message previously said "InitModel" (copy-paste error); name the
    // actual unimplemented method so the fatal log points at the right hook
    LOG(FATAL) << "Update: not implemented";
  }
};
/*!
 * \brief factory for a user-defined model updater used by a customized
 *  server; declaration only — the user supplies the definition when
 *  building a customized server
 * \tparam DType element type handled by the updater
 * \return pointer to a newly created updater (caller takes ownership)
 */
template<typename DType>
IModelUpdater<DType> *CreateModelUpdater(void);
} // namespace ps
} // namespace mshadow
#include "./ps_local-inl.h"
#include "./ps_dist-inl.h"
#include "./ps_rabit-inl.h"
namespace mshadow {
namespace ps {
/*!
 * \brief factory for a parameter-server implementation
 * \param type which backend to create, either "local" or "dist"
 * \return an ISharedModel handle used to synchronize weights
 */
template<typename xpu, typename DType>
inline ISharedModel<xpu, DType> *CreateSharedModel(const char *type) {
  const bool want_local = (strcmp("local", type) == 0);
  if (want_local) {
#if MSHADOW_RABIT_PS
    // allreduce costs nothing on a single machine, so prefer rabit
    // whenever we are actually running distributed
    if (rabit::IsDistributed()) return new RabitModel<xpu, DType>();
#endif
    return new LocalModel<xpu, DType>();
  }
#if MSHADOW_DIST_PS
  if (strcmp("dist", type) == 0) return new DistModel<xpu, DType>();
#endif
  LOG(FATAL) << "unknown server type " << type;
  return NULL;
}
} // namespace ps
} // namespace mshadow
#endif // MSHADOW_PS_H_ NOLINT(*)