diff --git a/Cargo.lock b/Cargo.lock index 715f87e..0eebaa9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -842,6 +842,7 @@ dependencies = [ "thousands", "tokio", "tokio-serde", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", @@ -1351,6 +1352,17 @@ dependencies = [ "serde", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.11" diff --git a/Cargo.toml b/Cargo.toml index f09d179..62dcb7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ interprocess = { version = "2.1", features = ["tokio"], optional = true } hex = "0.4.3" futures = { version = "0.3.30", optional = true } anyhow = "1.0.73" +tokio-stream = { version = "0.1.15", optional = true } [dependencies.educe] # This is an unused dependency, it is needed to make the minimal @@ -59,11 +60,12 @@ futures-buffered = "0.2.4" hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"] quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] flume-transport = ["dep:flume"] +tokio-mpsc-transport = ["dep:tokio-util", "dep:tokio-stream"] interprocess-transport = ["quinn-transport", "quinn-flume-socket", "dep:quinn-udp", "dep:interprocess", "dep:bytes", "dep:tokio-util", "dep:futures"] combined-transport = [] quinn-flume-socket = ["dep:flume", "dep:quinn", "dep:quinn-udp", "dep:bytes", "dep:tokio-util"] macros = [] -default = ["flume-transport"] +default = ["flume-transport", "tokio-mpsc-transport"] [package.metadata.docs.rs] all-features = true diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index 4315540..eacd11a 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -8,7 +8,7 @@ use std::{ use futures_lite::FutureExt; use futures_sink::Sink; -#[cfg(feature = "quinn-transport")] +#[cfg(any(feature = "quinn-transport", feature = "tokio-mpsc-transport"))] use futures_util::TryStreamExt; use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt}; use pin_project::pin_project; @@ -21,6 +21,8 @@ type BoxedFuture<'a, T> = Pin + Send + Sync + 'a>>; enum SendSinkInner { Direct(::flume::r#async::SendSink<'static, T>), + #[cfg(feature = "tokio-mpsc-transport")] + DirectTokio(tokio_util::sync::PollSender), Boxed(Pin + Send + Sync + 'static>>), } @@ -42,6 +44,11 @@ impl SendSink { pub(crate) fn direct(sink: ::flume::r#async::SendSink<'static, T>) -> Self { Self(SendSinkInner::Direct(sink)) } + + #[cfg(feature = "tokio-mpsc-transport")] + pub(crate) fn direct_tokio(sink: tokio_util::sync::PollSender) -> Self { + Self(SendSinkInner::DirectTokio(sink)) + } } impl Sink for SendSink { @@ -53,6 +60,10 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] + SendSinkInner::DirectTokio(sink) => { + sink.poll_ready_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), } } @@ -60,6 +71,10 @@ impl Sink for SendSink { fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { match self.project().0 { SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] + SendSinkInner::DirectTokio(sink) => { + sink.start_send_unpin(item).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), } } @@ -70,6 +85,10 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] + SendSinkInner::DirectTokio(sink) => { + sink.poll_flush_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), } } @@ -80,6 +99,10 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] + SendSinkInner::DirectTokio(sink) => { + sink.poll_close_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), } } @@ -87,6 +110,8 @@ impl Sink for SendSink { enum RecvStreamInner { Direct(::flume::r#async::RecvStream<'static, T>), + #[cfg(feature = "tokio-mpsc-transport")] + DirectTokio(tokio_stream::wrappers::ReceiverStream), Boxed(Pin> + Send + Sync + 'static>>), } @@ -109,6 +134,12 @@ impl RecvStream { pub(crate) fn direct(stream: ::flume::r#async::RecvStream<'static, T>) -> Self { Self(RecvStreamInner::Direct(stream)) } + + /// Create a new receive stream from a direct flume receive stream + #[cfg(feature = "tokio-mpsc-transport")] + pub(crate) fn direct_tokio(stream: tokio_stream::wrappers::ReceiverStream) -> Self { + Self(RecvStreamInner::DirectTokio(stream)) + } } impl Stream for RecvStream { @@ -121,6 +152,12 @@ impl Stream for RecvStream { Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, + #[cfg(feature = "tokio-mpsc-transport")] + RecvStreamInner::DirectTokio(stream) => match stream.poll_next_unpin(cx) { + Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, RecvStreamInner::Boxed(stream) => stream.poll_next_unpin(cx), } } @@ -129,6 +166,9 @@ impl Stream for RecvStream { enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future (todo) Direct(super::flume::OpenBiFuture), + /// A direct future (todo) + #[cfg(feature = "tokio-mpsc-transport")] + DirectTokio(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -141,6 +181,13 @@ impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> { fn direct(f: super::flume::OpenBiFuture) -> Self { Self(OpenFutureInner::Direct(f)) } + /// Create a new boxed future + #[cfg(feature = "tokio-mpsc-transport")] + pub fn direct_tokio( + f: impl Future, RecvStream)>> + Send + Sync + 'a, + ) -> Self { + Self(OpenFutureInner::DirectTokio(Box::pin(f))) + } /// Create a new boxed future pub fn boxed( @@ -159,6 +206,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + #[cfg(feature = "tokio-mpsc-transport")] + OpenFutureInner::DirectTokio(f) => f.poll(cx), OpenFutureInner::Boxed(f) => f.poll(cx), } } @@ -167,6 +216,9 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future Direct(super::flume::AcceptBiFuture), + /// A direct future + #[cfg(feature = "tokio-mpsc-transport")] + DirectTokio(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -180,6 +232,14 @@ impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> { Self(AcceptFutureInner::Direct(f)) } + /// bla + #[cfg(feature = "tokio-mpsc-transport")] + pub fn direct_tokio( + f: impl Future, RecvStream)>> + Send + Sync + 'a, + ) -> Self { + Self(AcceptFutureInner::DirectTokio(Box::pin(f))) + } + /// Create a new boxed future pub fn boxed( f: impl Future, RecvStream)>> + Send + Sync + 'a, @@ -197,6 +257,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + #[cfg(feature = "tokio-mpsc-transport")] + AcceptFutureInner::DirectTokio(f) => f.poll(cx), AcceptFutureInner::Boxed(f) => f.poll(cx), } } @@ -368,6 +430,46 @@ impl BoxableServerEndpoint for super::flume::FlumeSe } } +#[cfg(feature = "tokio-mpsc-transport")] +impl BoxableConnection for super::tokio_mpsc::Connection { + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn open_boxed(&self) -> OpenFuture { + let f = Box::pin(async move { + let (send, recv) = super::Connection::open(self).await?; + // return the boxed streams + anyhow::Ok(( + SendSink::direct_tokio(send.0), + RecvStream::direct_tokio(recv.0), + )) + }); + OpenFuture::direct_tokio(f) + } +} + +#[cfg(feature = "tokio-mpsc-transport")] +impl BoxableServerEndpoint for super::tokio_mpsc::ServerEndpoint { + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn accept_bi_boxed(&self) -> AcceptFuture { + let f = async move { + let (send, recv) = super::ServerEndpoint::accept(self).await?; + let send = send.sink_map_err(anyhow::Error::from); + let recv = recv.map_err(anyhow::Error::from); + anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv))) + }; + AcceptFuture::direct_tokio(f) + } + + fn local_addr(&self) -> &[super::LocalAddr] { + super::ServerEndpoint::local_addr(self) + } +} + #[cfg(test)] mod tests { use crate::Service; diff --git a/src/transport/mod.rs b/src/transport/mod.rs index ebbfe3f..dc76621 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -7,7 +7,7 @@ use std::{ fmt::{self, Debug, Display}, net::SocketAddr, }; -#[cfg(feature = "flume-transport")] +#[cfg(all(feature = "flume-transport", feature = "tokio-mpsc-transport"))] pub mod boxed; #[cfg(feature = "combined-transport")] pub mod combined; @@ -21,6 +21,8 @@ pub mod interprocess; pub mod quinn; #[cfg(feature = "quinn-flume-socket")] pub mod quinn_flume_socket; +#[cfg(feature = "tokio-mpsc-transport")] +pub mod tokio_mpsc; pub mod misc; diff --git a/src/transport/tokio_mpsc.rs b/src/transport/tokio_mpsc.rs new file mode 100644 index 0000000..a56f682 --- /dev/null +++ b/src/transport/tokio_mpsc.rs @@ -0,0 +1,340 @@ +//! Memory transport implementation using [tokio::sync::mpsc] + +use futures_lite::Stream; +use futures_sink::Sink; +use tokio_util::sync::PollSender; + +use crate::{ + transport::{self, ConnectionErrors, LocalAddr}, + RpcMessage, Service, +}; +use core::fmt; +use std::{error, fmt::Display, future::Future, pin::Pin, result, sync::Arc, task::Poll}; +use tokio::sync::{mpsc, Mutex}; + +use super::ConnectionCommon; + +/// Error when receiving from a channel +/// +/// This type has zero inhabitants, so it is always safe to unwrap a result with this error type. +#[derive(Debug)] +pub enum RecvError {} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +/// Sink for memory channels +pub struct SendSink(pub(crate) tokio_util::sync::PollSender); + +impl fmt::Debug for SendSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendSink").finish() + } +} + +impl Sink for SendSink { + type Error = self::SendError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_ready(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::new(&mut self.0) + .start_send(item) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_flush(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_close(cx) + .map_err(|_| SendError::ReceiverDropped) + } +} + +/// Stream for memory channels +pub struct RecvStream(pub(crate) tokio_stream::wrappers::ReceiverStream); + +impl fmt::Debug for RecvStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RecvStream").finish() + } +} + +impl Stream for RecvStream { + type Item = result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match Pin::new(&mut self.0).poll_next(cx) { + Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl error::Error for RecvError {} + +/// A `tokio::sync::mpsc` based server endpoint. +/// +/// Created using [connection]. +pub struct ServerEndpoint { + #[allow(clippy::type_complexity)] + stream: Arc, RecvStream)>>>, +} + +impl Clone for ServerEndpoint { + fn clone(&self) -> Self { + Self { + stream: self.stream.clone(), + } + } +} + +impl fmt::Debug for ServerEndpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerEndpoint") + .field("stream", &self.stream) + .finish() + } +} + +impl ConnectionErrors for ServerEndpoint { + type SendError = self::SendError; + + type RecvError = self::RecvError; + + type OpenError = self::AcceptBiError; +} + +type Socket = (self::SendSink, self::RecvStream); + +impl ConnectionCommon for ServerEndpoint { + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl transport::ServerEndpoint for ServerEndpoint { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { + let (send, recv) = self + .stream + .lock() + .await + .recv() + .await + .ok_or(AcceptBiError::RemoteDropped)?; + Ok((send, recv)) + } + + fn local_addr(&self) -> &[LocalAddr] { + &[LocalAddr::Mem] + } +} + +impl ConnectionErrors for Connection { + type SendError = self::SendError; + + type RecvError = self::RecvError; + + type OpenError = self::OpenBiError; +} + +impl ConnectionCommon for Connection { + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl transport::Connection for Connection { + #[allow(refining_impl_trait)] + fn open(&self) -> OpenBiFuture { + let (local_send, remote_recv) = mpsc::channel::(128); + let (remote_send, local_recv) = mpsc::channel::(128); + let remote_chan = ( + SendSink(tokio_util::sync::PollSender::new(remote_send)), + RecvStream(tokio_stream::wrappers::ReceiverStream::new(remote_recv)), + ); + let local_chan = ( + SendSink(tokio_util::sync::PollSender::new(local_send)), + RecvStream(tokio_stream::wrappers::ReceiverStream::new(local_recv)), + ); + let sender = PollSender::new(self.sink.clone()); + OpenBiFuture::new(sender, remote_chan, local_chan) + } +} + +/// Future returned by [FlumeConnection::open] +pub struct OpenBiFuture { + inner: PollSender>, + send: Option>, + res: Option>, +} + +impl fmt::Debug for OpenBiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpenBiFuture").finish() + } +} + +impl OpenBiFuture { + fn new( + inner: PollSender>, + send: Socket, + res: Socket, + ) -> Self { + Self { + inner, + send: Some(send), + res: Some(res), + } + } +} + +impl Future for OpenBiFuture { + type Output = result::Result, self::OpenBiError>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match Pin::new(&mut self.inner).poll_reserve(cx) { + Poll::Ready(Ok(())) => { + let Some(item) = self.send.take() else { + return Poll::Pending; + }; + let Ok(_) = self.inner.send_item(item) else { + return Poll::Ready(Err(self::OpenBiError::RemoteDropped)); + }; + self.res + .take() + .map(|x| Poll::Ready(Ok(x))) + .unwrap_or(Poll::Pending) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)), + Poll::Pending => Poll::Pending, + } + } +} + +/// A tokio::sync::mpsc based connection to a server endpoint. +/// +/// Created using [connection]. +pub struct Connection { + #[allow(clippy::type_complexity)] + sink: mpsc::Sender<(SendSink, RecvStream)>, +} + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + sink: self.sink.clone(), + } + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientChannel") + .field("sink", &self.sink) + .finish() + } +} + +/// AcceptBiError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum AcceptBiError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl fmt::Display for AcceptBiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for AcceptBiError {} + +/// SendError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum SendError { + /// Receiver was dropped + ReceiverDropped, +} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for SendError {} + +/// OpenBiError for mem channels. +#[derive(Debug)] +pub enum OpenBiError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl Display for OpenBiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for OpenBiError {} + +/// CreateChannelError for mem channels. +/// +/// You can always create a mem channel, so there is no possible error. +/// Nevertheless we need a type for it. +#[derive(Debug, Clone, Copy)] +pub enum CreateChannelError {} + +impl Display for CreateChannelError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for CreateChannelError {} + +/// Create a mpsc server endpoint and a connected mpsc client channel. +/// +/// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure +pub fn connection(buffer: usize) -> (ServerEndpoint, Connection) { + let (sink, stream) = mpsc::channel(buffer); + ( + ServerEndpoint { + stream: Arc::new(Mutex::new(stream)), + }, + Connection { sink }, + ) +} diff --git a/tests/tokio_mpsc.rs b/tests/tokio_mpsc.rs new file mode 100644 index 0000000..a2f5735 --- /dev/null +++ b/tests/tokio_mpsc.rs @@ -0,0 +1,113 @@ +#![cfg(feature = "tokio-mpsc-transport")] +#![allow(non_local_definitions)] +mod math; +use math::*; +use quic_rpc::{ + server::{RpcChannel, RpcServerError}, + transport::tokio_mpsc, + RpcClient, RpcServer, Service, +}; + +#[tokio::test] +async fn tokio_mpsc_channel_bench() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = tokio_mpsc::connection::(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(ComputeService::server(server)); + let client = RpcClient::::new(client); + bench(client, 1000000).await?; + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} + +#[tokio::test] +async fn tokio_mpsc_channel_mapped_bench() -> anyhow::Result<()> { + use derive_more::{From, TryInto}; + use serde::{Deserialize, Serialize}; + + tracing_subscriber::fmt::try_init().ok(); + + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterRequest { + Inner(InnerRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerRequest { + Compute(ComputeRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterResponse { + Inner(InnerResponse), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerResponse { + Compute(ComputeResponse), + } + #[derive(Debug, Clone)] + struct OuterService; + impl Service for OuterService { + type Req = OuterRequest; + type Res = OuterResponse; + } + #[derive(Debug, Clone)] + struct InnerService; + impl Service for InnerService { + type Req = InnerRequest; + type Res = InnerResponse; + } + let (server, client) = tokio_mpsc::connection::(1); + + let server = RpcServer::new(server); + let server_handle: tokio::task::JoinHandle>> = + tokio::task::spawn(async move { + let service = ComputeService; + loop { + let (req, chan) = server.accept().await?.read_first().await?; + let service = service.clone(); + tokio::spawn(async move { + let req: OuterRequest = req; + match req { + OuterRequest::Inner(InnerRequest::Compute(req)) => { + let chan: RpcChannel = chan.map(); + let chan: RpcChannel = chan.map(); + ComputeService::handle_rpc_request(service, req, chan).await + } + } + }); + } + }); + + let client = RpcClient::::new(client); + let client: RpcClient = client.map(); + let client: RpcClient = client.map(); + bench(client, 1000000).await?; + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} + +/// simple happy path test for all 4 patterns +#[tokio::test] +async fn tokio_mpsc_channel_smoke() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = tokio_mpsc::connection::(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(ComputeService::server(server)); + smoke_test(client).await?; + + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +}