From db6189746c2acc5e6c3611df27f38b15029a7797 Mon Sep 17 00:00:00 2001 From: Anand J Date: Wed, 16 Oct 2019 00:55:59 -0400 Subject: [PATCH] Priority-based Push and Pull (#147) Messages sent based on a assigned priority Higher priority messages gets preference over lower priority ones Add environment variable for setting high water mark for zmq --- docs/env.md | 3 +- include/ps/internal/customer.h | 4 +- include/ps/internal/message.h | 2 + include/ps/internal/threadsafe_pqueue.h | 59 ++++++++++++++++++++++++ include/ps/kv_app.h | 30 +++++++++---- src/meta.proto | 2 + src/p3_van.h | 60 +++++++++++++++++++++++++ src/postoffice.cc | 2 +- src/van.cc | 8 +++- src/zmq_van.h | 6 ++- 10 files changed, 162 insertions(+), 14 deletions(-) create mode 100644 include/ps/internal/threadsafe_pqueue.h create mode 100644 src/p3_van.h diff --git a/docs/env.md b/docs/env.md index 18133172..4315ea67 100644 --- a/docs/env.md +++ b/docs/env.md @@ -13,4 +13,5 @@ additional variables: - `DMLC_INTERFACE` : the network interface a node should use. in default choose automatically - `DMLC_LOCAL` : runs in local machines, no network is needed -- `DMLC_PS_VAN_TYPE` : the type of the Van for transport, can be `ibverbs` for RDMA, `zmq` for TCP +- `DMLC_PS_WATER_MARK` : limit on the maximum number of outstanding messages +- `DMLC_PS_VAN_TYPE` : the type of the Van for transport, can be `ibverbs` for RDMA, `zmq` for TCP, `p3` for TCP with [priority based parameter propagation](https://anandj.in/wp-content/uploads/sysml.pdf). \ No newline at end of file diff --git a/include/ps/internal/customer.h b/include/ps/internal/customer.h index e4a04b50..e227b532 100644 --- a/include/ps/internal/customer.h +++ b/include/ps/internal/customer.h @@ -12,7 +12,7 @@ #include #include #include "ps/internal/message.h" -#include "ps/internal/threadsafe_queue.h" +#include "ps/internal/threadsafe_pqueue.h" namespace ps { /** @@ -100,7 +100,7 @@ class Customer { int customer_id_; RecvHandle recv_handle_; - ThreadsafeQueue recv_queue_; + ThreadsafePQueue recv_queue_; std::unique_ptr recv_thread_; std::mutex tracker_mu_; diff --git a/include/ps/internal/message.h b/include/ps/internal/message.h index 49203871..11501063 100644 --- a/include/ps/internal/message.h +++ b/include/ps/internal/message.h @@ -195,6 +195,8 @@ struct Meta { Control control; /** \brief the byte size */ int data_size = 0; + /** \brief message priority */ + int priority = 0; }; /** * \brief messages that communicated amaong nodes. diff --git a/include/ps/internal/threadsafe_pqueue.h b/include/ps/internal/threadsafe_pqueue.h new file mode 100644 index 00000000..1e68c6cc --- /dev/null +++ b/include/ps/internal/threadsafe_pqueue.h @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2015 by Contributors + */ +#ifndef PS_INTERNAL_THREADSAFE_PQUEUE_H_ +#define PS_INTERNAL_THREADSAFE_PQUEUE_H_ +#include +#include +#include +#include +#include +#include +#include "ps/base.h" +namespace ps { + +/** + * \brief thread-safe queue allowing push and waited pop + */ +class ThreadsafePQueue { + public: + ThreadsafePQueue() { } + ~ThreadsafePQueue() { } + + /** + * \brief push an value into the end. threadsafe. + * \param new_value the value + */ + void Push(Message new_value) { + mu_.lock(); + queue_.push(std::move(new_value)); + mu_.unlock(); + cond_.notify_all(); + } + + /** + * \brief wait until pop an element from the beginning, threadsafe + * \param value the poped value + */ + void WaitAndPop(Message* value) { + std::unique_lock lk(mu_); + cond_.wait(lk, [this]{return !queue_.empty();}); + *value = std::move(queue_.top()); + queue_.pop(); + } + + private: + class Compare { + public: + bool operator()(const Message &l, const Message &r) { + return l.meta.priority <= r.meta.priority; + } + }; + mutable std::mutex mu_; + std::priority_queue, Compare> queue_; + std::condition_variable cond_; +}; + +} // namespace ps + +#endif // PS_INTERNAL_THREADSAFE_PQUEUE_H_ diff --git a/include/ps/kv_app.h b/include/ps/kv_app.h index 0d1f34a9..50bdde6b 100644 --- a/include/ps/kv_app.h +++ b/include/ps/kv_app.h @@ -40,6 +40,8 @@ struct KVPairs { SArray vals; /** \brief the according value lengths (could be empty) */ SArray lens; + /** \brief priority */ + int priority = 0; }; /** @@ -113,9 +115,11 @@ class KVWorker : public SimpleApp { const std::vector& vals, const std::vector& lens = {}, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { return ZPush( - SArray(keys), SArray(vals), SArray(lens), cmd, cb); + SArray(keys), SArray(vals), SArray(lens), cmd, cb, + priority); } /** @@ -148,11 +152,13 @@ class KVWorker : public SimpleApp { std::vector* vals, std::vector* lens = nullptr, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { SArray skeys(keys); int ts = AddPullCB(skeys, vals, lens, cmd, cb); KVPairs kvs; kvs.keys = skeys; + kvs.priority = priority; Send(ts, false, true, cmd, kvs); return ts; } @@ -190,7 +196,8 @@ class KVWorker : public SimpleApp { std::vector* outs, std::vector* lens = nullptr, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { CHECK_NOTNULL(outs); if (outs->empty()) outs->resize(vals.size()); @@ -207,7 +214,7 @@ class KVWorker : public SimpleApp { delete souts; delete slens; if (cb) cb(); - }); + }, priority); return ts; } @@ -237,13 +244,15 @@ class KVWorker : public SimpleApp { const SArray& vals, const SArray& lens = {}, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { int ts = obj_->NewRequest(kServerGroup); AddCallback(ts, cb); KVPairs kvs; kvs.keys = keys; kvs.vals = vals; kvs.lens = lens; + kvs.priority = priority; Send(ts, true, false, cmd, kvs); return ts; } @@ -260,10 +269,12 @@ class KVWorker : public SimpleApp { SArray* vals, SArray* lens = nullptr, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { int ts = AddPullCB(keys, vals, lens, cmd, cb); KVPairs kvs; kvs.keys = keys; + kvs.priority = priority; Send(ts, false, true, cmd, kvs); return ts; } @@ -281,11 +292,13 @@ class KVWorker : public SimpleApp { SArray* outs, SArray* lens = nullptr, int cmd = 0, - const Callback& cb = nullptr) { + const Callback& cb = nullptr, + int priority = 0) { int ts = AddPullCB(keys, outs, lens, cmd, cb); KVPairs kvs; kvs.keys = keys; kvs.vals = vals; + kvs.priority = priority; if (lens) kvs.lens = *lens; Send(ts, true, true, cmd, kvs); @@ -586,6 +599,7 @@ void KVWorker::Send(int timestamp, bool push, bool pull, int cmd, const KVP msg.meta.head = cmd; msg.meta.timestamp = timestamp; msg.meta.recver = Postoffice::Get()->ServerRankToID(i); + msg.meta.priority = kvs.priority; const auto& kvs = s.second; if (kvs.keys.size()) { msg.AddData(kvs.keys); diff --git a/src/meta.proto b/src/meta.proto index c83d0c53..7522dd2f 100644 --- a/src/meta.proto +++ b/src/meta.proto @@ -55,4 +55,6 @@ message PBMeta { optional bool simple_app = 6 [default = false]; // message.data_size optional int32 data_size = 11; + // priority + optional int32 priority = 13 [default = 0]; } diff --git a/src/p3_van.h b/src/p3_van.h new file mode 100644 index 00000000..eaddde53 --- /dev/null +++ b/src/p3_van.h @@ -0,0 +1,60 @@ +/** + * Copyright (c) 2015 by Contributors + */ +#ifndef PS_P3_VAN_H_ +#define PS_P3_VAN_H_ +#include +namespace ps { + +/** + * \brief P3 based Van implementation + */ +class P3Van : public ZMQVan { + public: + P3Van() {} + virtual ~P3Van() {} + + protected: + void Start(int customer_id) override { + start_mu_.lock(); + if (init_stage == 0) { + // start sender + sender_thread_ = std::unique_ptr( + new std::thread(&P3Van::Sending, this)); + init_stage++; + } + start_mu_.unlock(); + ZMQVan::Start(customer_id); + } + + void Stop() override { + ZMQVan::Stop(); + sender_thread_->join(); + } + + int SendMsg(const Message& msg) override { + send_queue_.Push(msg); + return 0; + } + + void Sending() { + while (true) { + Message msg; + send_queue_.WaitAndPop(&msg); + ZMQVan::SendMsg(msg); + if (!msg.meta.control.empty() && + msg.meta.control.cmd == Control::TERMINATE) { + break; + } + } + } + + private: + /** the thread for sending messages */ + std::unique_ptr sender_thread_; + ThreadsafePQueue send_queue_; + int init_stage = 0; +}; +} // namespace ps + +#endif // PS_P3_VAN_H_ diff --git a/src/postoffice.cc b/src/postoffice.cc index 19c0cc1c..e9f4386d 100644 --- a/src/postoffice.cc +++ b/src/postoffice.cc @@ -160,7 +160,7 @@ void Postoffice::Barrier(int customer_id, int node_group) { req.meta.customer_id = customer_id; req.meta.control.barrier_group = node_group; req.meta.timestamp = van_->GetTimestamp(); - CHECK_GT(van_->Send(req), 0); + van_->Send(req); barrier_cond_.wait(ulk, [this, customer_id] { return barrier_done_[0][customer_id]; }); diff --git a/src/van.cc b/src/van.cc index 9abfea75..bfb25170 100644 --- a/src/van.cc +++ b/src/van.cc @@ -16,6 +16,7 @@ #include "./ibverbs_van.h" #include "./resender.h" #include "./zmq_van.h" +#include "./p3_van.h" namespace ps { @@ -28,6 +29,8 @@ static const int kDefaultHeartbeatInterval = 0; Van* Van::Create(const std::string& type) { if (type == "zmq") { return new ZMQVan(); + } else if (type == "p3") { + return new P3Van(); #ifdef DMLC_USE_IBVERBS } else if (type == "ibverbs") { return new IBVerbsVan(); @@ -208,7 +211,7 @@ void Van::ProcessBarrierCommand(Message* msg) { if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) { res.meta.recver = recver_id; res.meta.timestamp = timestamp_++; - CHECK_GT(Send(res), 0); + Send(res); } } } @@ -447,6 +450,7 @@ void Van::PackMetaPB(const Meta& meta, PBMeta* pb) { pb->set_push(meta.push); pb->set_request(meta.request); pb->set_simple_app(meta.simple_app); + pb->set_priority(meta.priority); pb->set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb->add_data_type(d); if (!meta.control.empty()) { @@ -481,6 +485,7 @@ void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) { pb.set_pull(meta.pull); pb.set_request(meta.request); pb.set_simple_app(meta.simple_app); + pb.set_priority(meta.priority); pb.set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb.add_data_type(d); if (!meta.control.empty()) { @@ -523,6 +528,7 @@ void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) { meta->push = pb.push(); meta->pull = pb.pull(); meta->simple_app = pb.simple_app(); + meta->priority = pb.priority(); meta->body = pb.body(); meta->customer_id = pb.customer_id(); meta->data_type.resize(pb.data_type_size()); diff --git a/src/zmq_van.h b/src/zmq_van.h index 050323f2..cc64bd8b 100644 --- a/src/zmq_van.h +++ b/src/zmq_van.h @@ -7,7 +7,6 @@ #include #include #include -#include #include #include "ps/internal/van.h" #if _MSC_VER @@ -112,6 +111,11 @@ class ZMQVan : public Van { if (my_node_.id != Node::kEmpty) { std::string my_id = "ps" + std::to_string(my_node_.id); zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size()); + const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK"); + if (watermark) { + const int hwm = atoi(watermark); + zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm)); + } } // connect std::string addr = "tcp://" + node.hostname + ":" + std::to_string(node.port);