Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new transport based on tokio::sync::mpsc #86

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interprocess = { version = "2.1", features = ["tokio"], optional = true }
hex = "0.4.3"
futures = { version = "0.3.30", optional = true }
anyhow = "1.0.73"
tokio-stream = { version = "0.1.15", optional = true }

[dependencies.educe]
# This is an unused dependency, it is needed to make the minimal
Expand All @@ -59,11 +60,12 @@ futures-buffered = "0.2.4"
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"]
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
flume-transport = ["dep:flume"]
tokio-mpsc-transport = ["dep:tokio-util", "dep:tokio-stream"]
interprocess-transport = ["quinn-transport", "quinn-flume-socket", "dep:quinn-udp", "dep:interprocess", "dep:bytes", "dep:tokio-util", "dep:futures"]
combined-transport = []
quinn-flume-socket = ["dep:flume", "dep:quinn", "dep:quinn-udp", "dep:bytes", "dep:tokio-util"]
macros = []
default = ["flume-transport"]
default = ["flume-transport", "tokio-mpsc-transport"]

[package.metadata.docs.rs]
all-features = true
Expand Down
104 changes: 103 additions & 1 deletion src/transport/boxed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{

use futures_lite::FutureExt;
use futures_sink::Sink;
#[cfg(feature = "quinn-transport")]
#[cfg(any(feature = "quinn-transport", feature = "tokio-mpsc-transport"))]
use futures_util::TryStreamExt;
use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt};
use pin_project::pin_project;
Expand All @@ -21,6 +21,8 @@ type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;

enum SendSinkInner<T: RpcMessage> {
Direct(::flume::r#async::SendSink<'static, T>),
#[cfg(feature = "tokio-mpsc-transport")]
DirectTokio(tokio_util::sync::PollSender<T>),
Boxed(Pin<Box<dyn Sink<T, Error = anyhow::Error> + Send + Sync + 'static>>),
}

Expand All @@ -42,6 +44,11 @@ impl<T: RpcMessage> SendSink<T> {
pub(crate) fn direct(sink: ::flume::r#async::SendSink<'static, T>) -> Self {
Self(SendSinkInner::Direct(sink))
}

#[cfg(feature = "tokio-mpsc-transport")]
pub(crate) fn direct_tokio(sink: tokio_util::sync::PollSender<T>) -> Self {
Self(SendSinkInner::DirectTokio(sink))
}
}

impl<T: RpcMessage> Sink<T> for SendSink<T> {
Expand All @@ -53,13 +60,21 @@ impl<T: RpcMessage> Sink<T> for SendSink<T> {
) -> Poll<Result<(), Self::Error>> {
match self.project().0 {
SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from),
#[cfg(feature = "tokio-mpsc-transport")]
SendSinkInner::DirectTokio(sink) => {
sink.poll_ready_unpin(cx).map_err(anyhow::Error::from)
}
SendSinkInner::Boxed(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from),
}
}

fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
match self.project().0 {
SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from),
#[cfg(feature = "tokio-mpsc-transport")]
SendSinkInner::DirectTokio(sink) => {
sink.start_send_unpin(item).map_err(anyhow::Error::from)
}
SendSinkInner::Boxed(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from),
}
}
Expand All @@ -70,6 +85,10 @@ impl<T: RpcMessage> Sink<T> for SendSink<T> {
) -> Poll<Result<(), Self::Error>> {
match self.project().0 {
SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from),
#[cfg(feature = "tokio-mpsc-transport")]
SendSinkInner::DirectTokio(sink) => {
sink.poll_flush_unpin(cx).map_err(anyhow::Error::from)
}
SendSinkInner::Boxed(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from),
}
}
Expand All @@ -80,13 +99,19 @@ impl<T: RpcMessage> Sink<T> for SendSink<T> {
) -> Poll<Result<(), Self::Error>> {
match self.project().0 {
SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from),
#[cfg(feature = "tokio-mpsc-transport")]
SendSinkInner::DirectTokio(sink) => {
sink.poll_close_unpin(cx).map_err(anyhow::Error::from)
}
SendSinkInner::Boxed(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from),
}
}
}

enum RecvStreamInner<T: RpcMessage> {
Direct(::flume::r#async::RecvStream<'static, T>),
#[cfg(feature = "tokio-mpsc-transport")]
DirectTokio(tokio_stream::wrappers::ReceiverStream<T>),
Boxed(Pin<Box<dyn Stream<Item = Result<T, anyhow::Error>> + Send + Sync + 'static>>),
}

Expand All @@ -109,6 +134,12 @@ impl<T: RpcMessage> RecvStream<T> {
pub(crate) fn direct(stream: ::flume::r#async::RecvStream<'static, T>) -> Self {
Self(RecvStreamInner::Direct(stream))
}

/// Create a new receive stream from a direct flume receive stream
#[cfg(feature = "tokio-mpsc-transport")]
pub(crate) fn direct_tokio(stream: tokio_stream::wrappers::ReceiverStream<T>) -> Self {
Self(RecvStreamInner::DirectTokio(stream))
}
}

impl<T: RpcMessage> Stream for RecvStream<T> {
Expand All @@ -121,6 +152,12 @@ impl<T: RpcMessage> Stream for RecvStream<T> {
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
},
#[cfg(feature = "tokio-mpsc-transport")]
RecvStreamInner::DirectTokio(stream) => match stream.poll_next_unpin(cx) {
Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
},
RecvStreamInner::Boxed(stream) => stream.poll_next_unpin(cx),
}
}
Expand All @@ -129,6 +166,9 @@ impl<T: RpcMessage> Stream for RecvStream<T> {
enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> {
/// A direct future (todo)
Direct(super::flume::OpenBiFuture<In, Out>),
/// A direct future (todo)
#[cfg(feature = "tokio-mpsc-transport")]
DirectTokio(BoxFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
/// A boxed future
Boxed(BoxFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
}
Expand All @@ -141,6 +181,13 @@ impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> {
fn direct(f: super::flume::OpenBiFuture<In, Out>) -> Self {
Self(OpenFutureInner::Direct(f))
}
/// Create a new boxed future
#[cfg(feature = "tokio-mpsc-transport")]
pub fn direct_tokio(
f: impl Future<Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>> + Send + Sync + 'a,
) -> Self {
Self(OpenFutureInner::DirectTokio(Box::pin(f)))
}

/// Create a new boxed future
pub fn boxed(
Expand All @@ -159,6 +206,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> {
.poll(cx)
.map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0)))
.map_err(|e| e.into()),
#[cfg(feature = "tokio-mpsc-transport")]
OpenFutureInner::DirectTokio(f) => f.poll(cx),
OpenFutureInner::Boxed(f) => f.poll(cx),
}
}
Expand All @@ -167,6 +216,9 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> {
enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> {
/// A direct future
Direct(super::flume::AcceptBiFuture<In, Out>),
/// A direct future
#[cfg(feature = "tokio-mpsc-transport")]
DirectTokio(BoxedFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
/// A boxed future
Boxed(BoxedFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
}
Expand All @@ -180,6 +232,14 @@ impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> {
Self(AcceptFutureInner::Direct(f))
}

/// bla
#[cfg(feature = "tokio-mpsc-transport")]
pub fn direct_tokio(
f: impl Future<Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>> + Send + Sync + 'a,
) -> Self {
Self(AcceptFutureInner::DirectTokio(Box::pin(f)))
}

/// Create a new boxed future
pub fn boxed(
f: impl Future<Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>> + Send + Sync + 'a,
Expand All @@ -197,6 +257,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> {
.poll(cx)
.map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0)))
.map_err(|e| e.into()),
#[cfg(feature = "tokio-mpsc-transport")]
AcceptFutureInner::DirectTokio(f) => f.poll(cx),
AcceptFutureInner::Boxed(f) => f.poll(cx),
}
}
Expand Down Expand Up @@ -368,6 +430,46 @@ impl<S: Service> BoxableServerEndpoint<S::Req, S::Res> for super::flume::FlumeSe
}
}

#[cfg(feature = "tokio-mpsc-transport")]
impl<S: Service> BoxableConnection<S::Res, S::Req> for super::tokio_mpsc::Connection<S> {
fn clone_box(&self) -> Box<dyn BoxableConnection<S::Res, S::Req>> {
Box::new(self.clone())
}

fn open_boxed(&self) -> OpenFuture<S::Res, S::Req> {
let f = Box::pin(async move {
let (send, recv) = super::Connection::open(self).await?;
// return the boxed streams
anyhow::Ok((
SendSink::direct_tokio(send.0),
RecvStream::direct_tokio(recv.0),
))
});
OpenFuture::direct_tokio(f)
}
}

#[cfg(feature = "tokio-mpsc-transport")]
impl<S: Service> BoxableServerEndpoint<S::Req, S::Res> for super::tokio_mpsc::ServerEndpoint<S> {
fn clone_box(&self) -> Box<dyn BoxableServerEndpoint<S::Req, S::Res>> {
Box::new(self.clone())
}

fn accept_bi_boxed(&self) -> AcceptFuture<S::Req, S::Res> {
let f = async move {
let (send, recv) = super::ServerEndpoint::accept(self).await?;
let send = send.sink_map_err(anyhow::Error::from);
let recv = recv.map_err(anyhow::Error::from);
anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv)))
};
AcceptFuture::direct_tokio(f)
}

fn local_addr(&self) -> &[super::LocalAddr] {
super::ServerEndpoint::local_addr(self)
}
}

#[cfg(test)]
mod tests {
use crate::Service;
Expand Down
4 changes: 3 additions & 1 deletion src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
fmt::{self, Debug, Display},
net::SocketAddr,
};
#[cfg(feature = "flume-transport")]
#[cfg(all(feature = "flume-transport", feature = "tokio-mpsc-transport"))]
pub mod boxed;
#[cfg(feature = "combined-transport")]
pub mod combined;
Expand All @@ -21,6 +21,8 @@ pub mod interprocess;
pub mod quinn;
#[cfg(feature = "quinn-flume-socket")]
pub mod quinn_flume_socket;
#[cfg(feature = "tokio-mpsc-transport")]
pub mod tokio_mpsc;

pub mod misc;

Expand Down
Loading
Loading