Skip to content

Commit

Permalink
Priority-based Push and Pull (#147)
Browse files Browse the repository at this point in the history
Messages sent based on a assigned priority
Higher priority messages gets preference over lower priority ones
Add environment variable for setting high water mark for zmq
  • Loading branch information
anandj91 authored and eric-haibin-lin committed Oct 16, 2019
1 parent 411c9bb commit db61897
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 14 deletions.
3 changes: 2 additions & 1 deletion docs/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ additional variables:
- `DMLC_INTERFACE` : the network interface a node should use. in default choose
automatically
- `DMLC_LOCAL` : runs in local machines, no network is needed
- `DMLC_PS_VAN_TYPE` : the type of the Van for transport, can be `ibverbs` for RDMA, `zmq` for TCP
- `DMLC_PS_WATER_MARK` : limit on the maximum number of outstanding messages
- `DMLC_PS_VAN_TYPE` : the type of the Van for transport, can be `ibverbs` for RDMA, `zmq` for TCP, `p3` for TCP with [priority based parameter propagation](https://anandj.in/wp-content/uploads/sysml.pdf).
4 changes: 2 additions & 2 deletions include/ps/internal/customer.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <thread>
#include <memory>
#include "ps/internal/message.h"
#include "ps/internal/threadsafe_queue.h"
#include "ps/internal/threadsafe_pqueue.h"
namespace ps {

/**
Expand Down Expand Up @@ -100,7 +100,7 @@ class Customer {
int customer_id_;

RecvHandle recv_handle_;
ThreadsafeQueue<Message> recv_queue_;
ThreadsafePQueue recv_queue_;
std::unique_ptr<std::thread> recv_thread_;

std::mutex tracker_mu_;
Expand Down
2 changes: 2 additions & 0 deletions include/ps/internal/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ struct Meta {
Control control;
/** \brief the byte size */
int data_size = 0;
/** \brief message priority */
int priority = 0;
};
/**
* \brief messages that communicated amaong nodes.
Expand Down
59 changes: 59 additions & 0 deletions include/ps/internal/threadsafe_pqueue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/**
* Copyright (c) 2015 by Contributors
*/
#ifndef PS_INTERNAL_THREADSAFE_PQUEUE_H_
#define PS_INTERNAL_THREADSAFE_PQUEUE_H_
#include <queue>
#include <mutex>
#include <condition_variable>
#include <memory>
#include <utility>
#include <vector>
#include "ps/base.h"
namespace ps {

/**
* \brief thread-safe queue allowing push and waited pop
*/
class ThreadsafePQueue {
public:
ThreadsafePQueue() { }
~ThreadsafePQueue() { }

/**
* \brief push an value into the end. threadsafe.
* \param new_value the value
*/
void Push(Message new_value) {
mu_.lock();
queue_.push(std::move(new_value));
mu_.unlock();
cond_.notify_all();
}

/**
* \brief wait until pop an element from the beginning, threadsafe
* \param value the poped value
*/
void WaitAndPop(Message* value) {
std::unique_lock<std::mutex> lk(mu_);
cond_.wait(lk, [this]{return !queue_.empty();});
*value = std::move(queue_.top());
queue_.pop();
}

private:
class Compare {
public:
bool operator()(const Message &l, const Message &r) {
return l.meta.priority <= r.meta.priority;
}
};
mutable std::mutex mu_;
std::priority_queue<Message, std::vector<Message>, Compare> queue_;
std::condition_variable cond_;
};

} // namespace ps

#endif // PS_INTERNAL_THREADSAFE_PQUEUE_H_
30 changes: 22 additions & 8 deletions include/ps/kv_app.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct KVPairs {
SArray<Val> vals;
/** \brief the according value lengths (could be empty) */
SArray<int> lens;
/** \brief priority */
int priority = 0;
};

/**
Expand Down Expand Up @@ -113,9 +115,11 @@ class KVWorker : public SimpleApp {
const std::vector<Val>& vals,
const std::vector<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
return ZPush(
SArray<Key>(keys), SArray<Val>(vals), SArray<int>(lens), cmd, cb);
SArray<Key>(keys), SArray<Val>(vals), SArray<int>(lens), cmd, cb,
priority);
}

/**
Expand Down Expand Up @@ -148,11 +152,13 @@ class KVWorker : public SimpleApp {
std::vector<Val>* vals,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
SArray<Key> skeys(keys);
int ts = AddPullCB(skeys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = skeys;
kvs.priority = priority;
Send(ts, false, true, cmd, kvs);
return ts;
}
Expand Down Expand Up @@ -190,7 +196,8 @@ class KVWorker : public SimpleApp {
std::vector<Val>* outs,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
CHECK_NOTNULL(outs);
if (outs->empty())
outs->resize(vals.size());
Expand All @@ -207,7 +214,7 @@ class KVWorker : public SimpleApp {
delete souts;
delete slens;
if (cb) cb();
});
}, priority);
return ts;
}

Expand Down Expand Up @@ -237,13 +244,15 @@ class KVWorker : public SimpleApp {
const SArray<Val>& vals,
const SArray<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
Send(ts, true, false, cmd, kvs);
return ts;
}
Expand All @@ -260,10 +269,12 @@ class KVWorker : public SimpleApp {
SArray<Val>* vals,
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
int ts = AddPullCB(keys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.priority = priority;
Send(ts, false, true, cmd, kvs);
return ts;
}
Expand All @@ -281,11 +292,13 @@ class KVWorker : public SimpleApp {
SArray<Val>* outs,
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr) {
const Callback& cb = nullptr,
int priority = 0) {
int ts = AddPullCB(keys, outs, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.priority = priority;
if (lens)
kvs.lens = *lens;
Send(ts, true, true, cmd, kvs);
Expand Down Expand Up @@ -586,6 +599,7 @@ void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVP
msg.meta.head = cmd;
msg.meta.timestamp = timestamp;
msg.meta.recver = Postoffice::Get()->ServerRankToID(i);
msg.meta.priority = kvs.priority;
const auto& kvs = s.second;
if (kvs.keys.size()) {
msg.AddData(kvs.keys);
Expand Down
2 changes: 2 additions & 0 deletions src/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ message PBMeta {
optional bool simple_app = 6 [default = false];
// message.data_size
optional int32 data_size = 11;
// priority
optional int32 priority = 13 [default = 0];
}
60 changes: 60 additions & 0 deletions src/p3_van.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright (c) 2015 by Contributors
*/
#ifndef PS_P3_VAN_H_
#define PS_P3_VAN_H_
#include <memory>
namespace ps {

/**
* \brief P3 based Van implementation
*/
class P3Van : public ZMQVan {
public:
P3Van() {}
virtual ~P3Van() {}

protected:
void Start(int customer_id) override {
start_mu_.lock();
if (init_stage == 0) {
// start sender
sender_thread_ = std::unique_ptr<std::thread>(
new std::thread(&P3Van::Sending, this));
init_stage++;
}
start_mu_.unlock();
ZMQVan::Start(customer_id);
}

void Stop() override {
ZMQVan::Stop();
sender_thread_->join();
}

int SendMsg(const Message& msg) override {
send_queue_.Push(msg);
return 0;
}

void Sending() {
while (true) {
Message msg;
send_queue_.WaitAndPop(&msg);
ZMQVan::SendMsg(msg);
if (!msg.meta.control.empty() &&
msg.meta.control.cmd == Control::TERMINATE) {
break;
}
}
}

private:
/** the thread for sending messages */
std::unique_ptr<std::thread> sender_thread_;
ThreadsafePQueue send_queue_;
int init_stage = 0;
};
} // namespace ps

#endif // PS_P3_VAN_H_
2 changes: 1 addition & 1 deletion src/postoffice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void Postoffice::Barrier(int customer_id, int node_group) {
req.meta.customer_id = customer_id;
req.meta.control.barrier_group = node_group;
req.meta.timestamp = van_->GetTimestamp();
CHECK_GT(van_->Send(req), 0);
van_->Send(req);
barrier_cond_.wait(ulk, [this, customer_id] {
return barrier_done_[0][customer_id];
});
Expand Down
8 changes: 7 additions & 1 deletion src/van.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "./ibverbs_van.h"
#include "./resender.h"
#include "./zmq_van.h"
#include "./p3_van.h"

namespace ps {

Expand All @@ -28,6 +29,8 @@ static const int kDefaultHeartbeatInterval = 0;
Van* Van::Create(const std::string& type) {
if (type == "zmq") {
return new ZMQVan();
} else if (type == "p3") {
return new P3Van();
#ifdef DMLC_USE_IBVERBS
} else if (type == "ibverbs") {
return new IBVerbsVan();
Expand Down Expand Up @@ -208,7 +211,7 @@ void Van::ProcessBarrierCommand(Message* msg) {
if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
res.meta.recver = recver_id;
res.meta.timestamp = timestamp_++;
CHECK_GT(Send(res), 0);
Send(res);
}
}
}
Expand Down Expand Up @@ -447,6 +450,7 @@ void Van::PackMetaPB(const Meta& meta, PBMeta* pb) {
pb->set_push(meta.push);
pb->set_request(meta.request);
pb->set_simple_app(meta.simple_app);
pb->set_priority(meta.priority);
pb->set_customer_id(meta.customer_id);
for (auto d : meta.data_type) pb->add_data_type(d);
if (!meta.control.empty()) {
Expand Down Expand Up @@ -481,6 +485,7 @@ void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) {
pb.set_pull(meta.pull);
pb.set_request(meta.request);
pb.set_simple_app(meta.simple_app);
pb.set_priority(meta.priority);
pb.set_customer_id(meta.customer_id);
for (auto d : meta.data_type) pb.add_data_type(d);
if (!meta.control.empty()) {
Expand Down Expand Up @@ -523,6 +528,7 @@ void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) {
meta->push = pb.push();
meta->pull = pb.pull();
meta->simple_app = pb.simple_app();
meta->priority = pb.priority();
meta->body = pb.body();
meta->customer_id = pb.customer_id();
meta->data_type.resize(pb.data_type_size());
Expand Down
6 changes: 5 additions & 1 deletion src/zmq_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <stdlib.h>
#include <zmq.h>
#include <string>
#include <thread>
#include <unordered_map>
#include "ps/internal/van.h"
#if _MSC_VER
Expand Down Expand Up @@ -112,6 +111,11 @@ class ZMQVan : public Van {
if (my_node_.id != Node::kEmpty) {
std::string my_id = "ps" + std::to_string(my_node_.id);
zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());
const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK");
if (watermark) {
const int hwm = atoi(watermark);
zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm));
}
}
// connect
std::string addr = "tcp://" + node.hostname + ":" + std::to_string(node.port);
Expand Down

0 comments on commit db61897

Please sign in to comment.