-
Notifications
You must be signed in to change notification settings - Fork 0
/
unbounded.rs
69 lines (59 loc) · 2.1 KB
/
unbounded.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use std::time::Duration;
use tokio::sync::mpsc::{unbounded_channel as _req_channel, UnboundedSender as _ReqTx, UnboundedReceiver as _ReqRx};
use crate::{Req, ReqPayload, ResTx, ResRx, ReqError, ReqSendError, ResRecvError, _res_channel};
/// Sending half of an unbounded request channel.
///
/// Each request is wrapped in a [`ReqPayload`] that also carries a fresh
/// response sender, so the receiving side can reply to that specific request.
pub struct ReqTx<R>
where
    R: Req,
{
    /// Underlying unbounded tokio mpsc sender carrying request payloads.
    req_tx: _ReqTx<ReqPayload<R>>,
    /// Timeout copied into every `ResRx` produced by `send`.
    /// NOTE(review): presumably `None` means "wait indefinitely" — confirm in `ResRx::recv`.
    timeout: Option<Duration>,
}
/// Receiving half of an unbounded request channel.
///
/// Yields [`ReqPayload`]s, each bundling a request with the sender used to reply.
pub struct ReqRx<R>
where
    R: Req,
{
    /// Underlying unbounded tokio mpsc receiver of request payloads.
    req_rx: _ReqRx<ReqPayload<R>>,
}
impl<R: Req> ReqTx<R> {
pub fn send(&self, req: R) -> Result<ResRx<R::Res>, ReqSendError<R>> {
let (res_tx, res_rx) = _res_channel::<R::Res>();
self.req_tx
.send(ReqPayload { req, res_tx: ResTx { res_tx } })
.map_err(|payload| ReqSendError(payload.0.req))?;
let res_rx = ResRx { res_rx: Some(res_rx), timeout: self.timeout };
Ok(res_rx)
}
pub async fn send_recv(&self, request: R) -> Result<R::Res, ReqError<R>> {
let mut res_rx = self.send(request)
.map_err(|err| ReqError::SendError(err.0))?;
res_rx.recv().await
.map_err(|err| match err {
ResRecvError::RecvError => ReqError::RecvError,
ResRecvError::RecvTimeout => ReqError::RecvTimeout,
})
}
pub fn is_closed(&self) -> bool {
self.req_tx.is_closed()
}
}
// Written by hand rather than derived: `#[derive(Clone)]` would require `R: Clone`,
// but only the inner sender handle and the timeout need to be duplicated.
impl<R> Clone for ReqTx<R>
where
    R: Req,
{
    /// Produces another sender feeding the same request queue, with the same timeout.
    fn clone(&self) -> Self {
        let req_tx = self.req_tx.clone();
        let timeout = self.timeout;
        Self { req_tx, timeout }
    }
}
impl<R> ReqRx<R>
where
    R: Req,
{
    /// Awaits the next queued request payload.
    ///
    /// # Errors
    ///
    /// Returns [`ReqError::RecvError`] when the underlying channel yields no more
    /// payloads, i.e. all senders are gone (or the channel was closed) and the
    /// queue has been drained.
    pub async fn recv(&mut self) -> Result<ReqPayload<R>, ReqError<R>> {
        self.req_rx.recv().await.ok_or(ReqError::RecvError)
    }

    /// Closes the receiving half so further sends are rejected.
    ///
    /// Requests already queued can still be obtained through [`ReqRx::recv`].
    pub fn close(&mut self) {
        self.req_rx.close();
    }
}
/// Creates an unbounded request channel whose response handles carry no timeout.
pub fn channel<R>() -> (ReqTx<R>, ReqRx<R>)
where
    R: Req,
{
    let (tx, rx) = _req_channel::<ReqPayload<R>>();
    let sender = ReqTx {
        req_tx: tx,
        timeout: None,
    };
    let receiver = ReqRx { req_rx: rx };
    (sender, receiver)
}
/// Creates an unbounded request channel whose response handles use `timeout`.
///
/// The duration is stored on the sender and copied into every `ResRx`
/// returned by [`ReqTx::send`].
pub fn channel_with_timeout<R>(timeout: Duration) -> (ReqTx<R>, ReqRx<R>)
where
    R: Req,
{
    let (tx, rx) = _req_channel::<ReqPayload<R>>();
    let sender = ReqTx {
        req_tx: tx,
        timeout: Some(timeout),
    };
    let receiver = ReqRx { req_rx: rx };
    (sender, receiver)
}