diff --git a/azmq/detail/receive_op.hpp b/azmq/detail/receive_op.hpp index b65c33c..7bd1182 100644 --- a/azmq/detail/receive_op.hpp +++ b/azmq/detail/receive_op.hpp @@ -14,11 +14,19 @@ #include "socket_ops.hpp" #include "reactor_op.hpp" +#include #include +#include +#include +#if BOOST_VERSION >= 107900 +#include +#include +#endif #include #include +#include namespace azmq { namespace detail { @@ -31,6 +39,23 @@ class receive_buffer_op_base : public reactor_op { { } virtual bool do_perform(socket_type& socket) override { + return do_perform_impl(socket); + } + +private: + template + typename std::enable_if::value, bool>::type do_perform_impl(socket_type& socket) + { + ec_ = boost::system::error_code(); + bytes_transferred_ += socket_ops::receive(const_cast(buffers_), socket, flags_ | ZMQ_DONTWAIT, ec_); + if (ec_) + return !try_again(); + return true; + } + + template + typename std::enable_if::value, bool>::type do_perform_impl(socket_type& socket) + { ec_ = boost::system::error_code(); bytes_transferred_ += socket_ops::receive(buffers_, socket, flags_ | ZMQ_DONTWAIT, ec_); if (ec_) @@ -44,7 +69,7 @@ class receive_buffer_op_base : public reactor_op { } private: - MutableBufferSequence buffers_; + typename std::conditional::value, MutableBufferSequence const&, MutableBufferSequence>::type buffers_; flags_type flags_; }; @@ -57,14 +82,30 @@ class receive_buffer_op : public receive_buffer_op_base { socket_ops::flags_type flags) : receive_buffer_op_base(buffers, flags) , handler_(std::move(handler)) + , work_guard(boost::asio::make_work_guard(handler_)) { } virtual void do_complete() override { - handler_(this->ec_, this->bytes_transferred_); +#if BOOST_VERSION >= 107900 + auto alloc = boost::asio::get_associated_allocator( + handler_, boost::asio::recycling_allocator()); +#endif + boost::asio::dispatch(work_guard.get_executor(), +#if BOOST_VERSION >= 107900 + boost::asio::bind_allocator(alloc, +#endif + [ec_ = this->ec_, handler_ = std::move(handler_), bytes_transferred_ = this->bytes_transferred_]() mutable { + handler_(ec_, bytes_transferred_); + }) +#if BOOST_VERSION >= 107900 + ) +#endif + ; } private: Handler handler_; + boost::asio::executor_work_guard::type> work_guard; }; template(buffers, flags) , handler_(std::move(handler)) + , work_guard(boost::asio::make_work_guard(handler_)) { } virtual void do_complete() override { - handler_(this->ec_, std::make_pair(this->bytes_transferred_, this->more())); +#if BOOST_VERSION >= 107900 + auto alloc = boost::asio::get_associated_allocator( + handler_, boost::asio::recycling_allocator()); +#endif + boost::asio::dispatch(work_guard.get_executor(), +#if BOOST_VERSION >= 107900 + boost::asio::bind_allocator(alloc, +#endif + [ec_ = this->ec_, handler_ = std::move(handler_), bytes_transferred_ = this->bytes_transferred_, more = this->more()]() mutable { + handler_(ec_, std::make_pair(bytes_transferred_, more)); + }) +#if BOOST_VERSION >= 107900 + ) +#endif + ; } private: Handler handler_; + boost::asio::executor_work_guard::type> work_guard; }; class receive_op_base : public reactor_op { @@ -112,14 +169,30 @@ class receive_op : public receive_op_base { socket_ops::flags_type flags) : receive_op_base(flags) , handler_(std::move(handler)) + , work_guard(boost::asio::make_work_guard(handler_)) { } virtual void do_complete() override { - handler_(ec_, msg_, bytes_transferred_); +#if BOOST_VERSION >= 107900 + auto alloc = boost::asio::get_associated_allocator( + handler_, boost::asio::recycling_allocator()); +#endif + boost::asio::dispatch(work_guard.get_executor(), +#if BOOST_VERSION >= 107900 + boost::asio::bind_allocator(alloc, +#endif + [ec_ = this->ec_, handler_ = std::move(handler_), msg_ = std::move(msg_), bytes_transferred_ = this->bytes_transferred_]() mutable { + handler_(ec_, msg_, bytes_transferred_); + }) +#if BOOST_VERSION >= 107900 + ) +#endif + ; } private: Handler handler_; + boost::asio::executor_work_guard::type> work_guard; }; } // namespace detail } // namespace azmq diff --git a/azmq/detail/send_op.hpp b/azmq/detail/send_op.hpp index 4823b55..67f868c 100644 --- a/azmq/detail/send_op.hpp +++ b/azmq/detail/send_op.hpp @@ -13,7 +13,14 @@ #include "socket_ops.hpp" #include "reactor_op.hpp" +#include #include +#include +#include +#if BOOST_VERSION >= 107900 +#include +#include +#endif #include #include @@ -52,14 +59,30 @@ class send_buffer_op : public send_buffer_op_base { reactor_op::flags_type flags) : send_buffer_op_base(buffers, flags) , handler_(std::move(handler)) + , work_guard(boost::asio::make_work_guard(handler_)) { } virtual void do_complete() override { - handler_(this->ec_, this->bytes_transferred_); +#if BOOST_VERSION >= 107900 + auto alloc = boost::asio::get_associated_allocator( + handler_, boost::asio::recycling_allocator()); +#endif + boost::asio::dispatch(work_guard.get_executor(), +#if BOOST_VERSION >= 107900 + boost::asio::bind_allocator(alloc, +#endif + [ec_ = this->ec_, handler_ = std::move(handler_), bytes_transferred_ = this->bytes_transferred_]() mutable { + handler_(ec_, bytes_transferred_); + }) +#if BOOST_VERSION >= 107900 + ) +#endif + ; } private: Handler handler_; + boost::asio::executor_work_guard::type> work_guard; }; class send_op_base : public reactor_op { @@ -90,14 +113,31 @@ class send_op : public send_op_base { flags_type flags) : send_op_base(std::move(msg), flags) , handler_(std::move(handler)) + , work_guard(boost::asio::make_work_guard(handler_)) { } virtual void do_complete() override { - handler_(ec_, bytes_transferred_); +#if BOOST_VERSION >= 107900 + auto alloc = boost::asio::get_associated_allocator( + handler_, boost::asio::recycling_allocator()); +#endif + boost::asio::dispatch(work_guard.get_executor(), +#if BOOST_VERSION >= 107900 + boost::asio::bind_allocator(alloc, +#endif + [ec_ = this->ec_, handler_ = std::move(handler_), bytes_transferred_ = this->bytes_transferred_]() mutable { + handler_(ec_, bytes_transferred_); + }) +#if BOOST_VERSION >= 107900 + ) +#endif + ; + } private: Handler handler_; + boost::asio::executor_work_guard::type> work_guard; }; } // namespace detail diff --git a/test/socket/main.cpp b/test/socket/main.cpp index 44f5064..e56e3d8 100644 --- a/test/socket/main.cpp +++ b/test/socket/main.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -18,6 +19,9 @@ #include #include #include +#if BOOST_VERSION >= 107400 +#include +#endif #endif #include @@ -879,19 +883,74 @@ TEST_CASE("Async Operation Send/Receive single message, stackful coroutine, one auto frame1 = azmq::message{}; auto const btb1 = azmq::async_receive(sb, frame1, yield); REQUIRE(btb1 == 5); + REQUIRE(frame1.more()); auto frame2 = azmq::message{}; auto const btb2 = azmq::async_receive(sb, frame2, yield); REQUIRE(btb2 == 2); + REQUIRE(frame2.more()); + REQUIRE(message_ref(snd_bufs.at(0)) == message_ref(frame2)); auto frame3 = azmq::message{}; auto const btb3 = azmq::async_receive(sb, frame3, yield); REQUIRE(btb3 == 2); + REQUIRE(!frame3.more()); + REQUIRE(message_ref(snd_bufs.at(1)) == message_ref(frame3)); }); ios.run(); } + +TEST_CASE("Async Operation Send/Receive single message, check thread safety", "[socket_ops]") { + boost::asio::io_service ios; +#if BOOST_VERSION >= 107400 + boost::asio::strand strand{ios.get_executor()}; +#else + boost::asio::strand strand{ios.get_executor()}; +#endif + + azmq::socket sb(ios, ZMQ_ROUTER); + sb.bind(subj(BOOST_CURRENT_FUNCTION)); + + azmq::socket sc(ios, ZMQ_DEALER); + sc.connect(subj(BOOST_CURRENT_FUNCTION)); + + //send coroutine task + boost::asio::spawn(strand, [&](boost::asio::yield_context yield) { + REQUIRE(strand.running_in_this_thread()); + boost::system::error_code ecc; + auto const btc = azmq::async_send(sc, snd_bufs, yield[ecc]); + REQUIRE(strand.running_in_this_thread()); + REQUIRE(!ecc); + REQUIRE(btc == 4); + }); + + //receive coroutine task + boost::asio::spawn(strand, [&](boost::asio::yield_context yield) { + std::array ident; + std::array a; + std::array b; + + std::array rcv_bufs = { {boost::asio::buffer(ident), + boost::asio::buffer(a), + boost::asio::buffer(b)}}; + + boost::system::error_code ecc; + + REQUIRE(strand.running_in_this_thread()); + auto const btb = azmq::async_receive(sb, rcv_bufs, yield[ecc]); + REQUIRE(strand.running_in_this_thread()); + REQUIRE(!ecc); + REQUIRE(btb == 9); + + REQUIRE(message_ref(snd_bufs.at(0)) == boost::string_ref(a.data(), 2)); + REQUIRE(message_ref(snd_bufs.at(1)) == boost::string_ref(b.data(), 2)); + }); + + ios.run(); +} + #endif // BOOST_VERSION >= 107000 diff --git a/test/socket_ops/main.cpp b/test/socket_ops/main.cpp index b4ae8dd..633f20b 100644 --- a/test/socket_ops/main.cpp +++ b/test/socket_ops/main.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include