Skip to content

Commit

Permalink
Merge pull request #206 from tkrs/enhance-error-structure
Browse files Browse the repository at this point in the history
feat!: enhance error structures using thiserror
  • Loading branch information
tkrs authored Dec 17, 2024
2 parents d64ef72 + 34b4776 commit 887ccee
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 103 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
fail-fast: false
matrix:
version:
- 1.70.0
- 1.80.0
- stable
- nightly
services:
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rmp = "0.8"
rmp-serde = "1.1"
serde = "1.0"
serde_derive = "1.0"
thiserror = "2.0"
uuid = { version = "1.11", features = ["v4"] }

[dev-dependencies]
Expand Down
1 change: 0 additions & 1 deletion examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ extern crate serde_derive;
use log::info;
use once_cell::sync::Lazy;
use poston::{Client, Settings, WorkerPool};
use pretty_env_logger;
use rand::prelude::*;
use rand::{self, distributions::Alphanumeric};
use std::time::{Duration, Instant, SystemTime};
Expand Down
73 changes: 63 additions & 10 deletions src/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
use crate::error::Error;
use crate::rmps::decode as rdecode;
use crate::rmps::encode as rencode;
use crate::rmps::Deserializer;
use crate::time_pack::TimePack;
use rmp::encode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::VecDeque;
use std::time::SystemTime;
use thiserror::Error;

pub trait Buffer<T> {
fn pack(&self) -> Result<Vec<u8>, BufferError>;
}

impl<T: Serialize> Buffer<T> for T {
fn pack(&self) -> Result<Vec<u8>, BufferError> {
let mut buf = Vec::new();
rencode::write_named(&mut buf, self).map_err(BufferError::Pack)?;
Ok(buf)
}
}

pub trait Take<T> {
fn take(&mut self, buf: &mut Vec<T>);
Expand Down Expand Up @@ -34,31 +48,70 @@ pub struct AckReply {
pub ack: String,
}

pub fn pack_record<'a>(
impl TryFrom<&[u8]> for AckReply {
type Error = BufferError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
unpack_response(value, value.len())
}
}

pub struct Record<'a> {
tag: &'a str,
entries: &'a [(SystemTime, Vec<u8>)],
chunk: &'a str,
}

impl<'a> Record<'a> {
pub fn new(tag: &'a str, entries: &'a [(SystemTime, Vec<u8>)], chunk: &'a str) -> Self {
Self {
tag,
entries,
chunk,
}
}

pub fn pack(&self) -> Result<Vec<u8>, BufferError> {
let mut buf = Vec::new();
pack_record(&mut buf, self.tag, self.entries, self.chunk)?;
Ok(buf)
}
}

fn pack_record<'a>(
buf: &mut Vec<u8>,
tag: &'a str,
entries: &'a [(SystemTime, Vec<u8>)],
chunk: &'a str,
) -> Result<(), Error> {
) -> Result<(), BufferError> {
buf.push(0x93);
encode::write_str(buf, tag).map_err(|e| Error::Derive(e.to_string()))?;
encode::write_array_len(buf, entries.len() as u32).map_err(|e| Error::Derive(e.to_string()))?;
encode::write_str(buf, tag)
.map_err(|e| BufferError::Pack(rencode::Error::InvalidValueWrite(e)))?;
encode::write_array_len(buf, entries.len() as u32)
.map_err(|e| BufferError::Pack(rencode::Error::InvalidValueWrite(e)))?;
for (t, entry) in entries {
buf.push(0x92);
t.write_time(buf)
.map_err(|e| Error::Derive(e.to_string()))?;
.map_err(|e| BufferError::Pack(rencode::Error::InvalidValueWrite(e)))?;
buf.extend(entry.iter());
}
let options = Options {
chunk: chunk.to_string(),
};

rencode::write_named(buf, &options).map_err(|e| Error::Derive(e.to_string()))
rencode::write_named(buf, &options).map_err(BufferError::Pack)
}

pub fn unpack_response(resp_buf: &[u8], size: usize) -> Result<AckReply, Error> {
fn unpack_response(resp_buf: &[u8], size: usize) -> Result<AckReply, BufferError> {
let mut de = Deserializer::new(&resp_buf[0..size]);
Deserialize::deserialize(&mut de).map_err(|e| Error::Derive(e.to_string()))
Deserialize::deserialize(&mut de).map_err(BufferError::Unpack)
}

#[derive(Error, Debug)]
pub enum BufferError {
#[error("pack failed")]
Pack(#[from] rencode::Error),
#[error("unpack failed")]
Unpack(#[from] rdecode::Error),
}

#[cfg(test)]
Expand Down Expand Up @@ -149,7 +202,7 @@ mod unpack_response {
#[test]
fn it_should_unpack_as_ack_reply() {
let mut resp_buf = [0u8; 64];
for (i, e) in vec![0x81u8, 0xa3, 0x61, 0x63, 0x6b, 0xa4, 0x61, 0x62, 0x63, 0x3d]
for (i, e) in [0x81u8, 0xa3, 0x61, 0x63, 0x6b, 0xa4, 0x61, 0x62, 0x63, 0x3d]
.iter()
.enumerate()
{
Expand Down
22 changes: 10 additions & 12 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::buffer::Buffer;
use crate::connect;
use crate::error::Error;
use crate::rmps::encode as rencode;
use crate::error::ClientError;
use crate::worker::{Message, Worker};
use crossbeam_channel::{bounded, unbounded, Sender};
use serde::Serialize;
Expand All @@ -11,10 +11,10 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, SystemTime};

pub trait Client {
fn send<A>(&self, tag: String, a: &A, timestamp: SystemTime) -> Result<(), Error>
fn send<A>(&self, tag: String, a: &A, timestamp: SystemTime) -> Result<(), ClientError>
where
A: Serialize;
fn terminate(&self) -> Result<(), Error>;
fn terminate(&self) -> Result<(), ClientError>;
}

pub struct WorkerPool {
Expand Down Expand Up @@ -68,25 +68,25 @@ impl WorkerPool {
}

impl Client for WorkerPool {
fn send<A>(&self, tag: String, a: &A, timestamp: SystemTime) -> Result<(), Error>
fn send<A>(&self, tag: String, a: &A, timestamp: SystemTime) -> Result<(), ClientError>
where
A: Serialize,
A: Buffer<A>,
{
if self.terminated.load(Ordering::Acquire) {
debug!("Worker does already closed.");
return Ok(());
}

let mut buf = Vec::new();
rencode::write_named(&mut buf, a).map_err(|e| Error::Derive(e.to_string()))?;
let buf = a.pack().map_err(ClientError::Buffer)?;

self.sender
.send(Message::Queuing(tag, timestamp, buf))
.map_err(|e| Error::Send(e.to_string()))?;
.map_err(ClientError::SendChannel)?;
Ok(())
}

fn terminate(&self) -> Result<(), Error> {
fn terminate(&self) -> Result<(), ClientError> {
if self.terminated.fetch_or(true, Ordering::SeqCst) {
info!("Worker does already terminated.");
return Ok(());
Expand All @@ -96,9 +96,7 @@ impl Client for WorkerPool {

let (sender, receiver) = bounded::<()>(0);
self.sender.send(Message::Terminating(sender)).unwrap();
receiver
.recv()
.map_err(|e| Error::Terminate(e.to_string()))?;
receiver.recv().map_err(ClientError::RecieveChannel)?;

Ok(())
}
Expand Down
40 changes: 24 additions & 16 deletions src/connect.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::buffer;
use crate::error::Error;
use crate::buffer::{AckReply, BufferError};
use backoff::{Error as RetryError, ExponentialBackoff};
use std::cell::RefCell;
use std::fmt::Debug;
Expand All @@ -20,7 +19,7 @@ pub trait Reconnect {
}

pub trait WriteRead {
fn write_and_read(&mut self, buf: &[u8], chunk: &str) -> Result<(), Error>;
fn write_and_read(&mut self, buf: &[u8], chunk: &str) -> Result<(), StreamError>;
}

#[derive(Debug)]
Expand Down Expand Up @@ -137,7 +136,7 @@ where
A: ToSocketAddrs + Clone + Debug,
S: Connect<S> + Read + Write,
{
fn write_and_read(&mut self, buf: &[u8], chunk: &str) -> Result<(), Error> {
fn write_and_read(&mut self, buf: &[u8], chunk: &str) -> Result<(), StreamError> {
let backoff = ExponentialBackoff {
current_interval: self.write_retry_initial_delay(),
initial_interval: self.write_retry_initial_delay(),
Expand All @@ -149,7 +148,7 @@ where
let op = || {
if self.should_reconnect() {
self.reconnect()
.map_err(|e| RetryError::transient(Error::Network(e.to_string())))?;
.map_err(|e| RetryError::transient(StreamError::Network(e)))?;
}
self.write_all(buf)
.and_then(|_| self.flush())
Expand All @@ -158,7 +157,7 @@ where
if let Err(err) = self.close() {
debug!("Failed to close the stream, cause: {:?}", err);
}
RetryError::transient(Error::Network(e.to_string()))
RetryError::transient(StreamError::Network(e))
})?;

let read_backoff = ExponentialBackoff {
Expand All @@ -176,17 +175,15 @@ where
debug!("Failed to read response, chunk: {}, cause: {:?}", chunk, e);
use io::ErrorKind::*;
match e.kind() {
WouldBlock | TimedOut => {
RetryError::transient(Error::Network(e.to_string()))
}
WouldBlock | TimedOut => RetryError::transient(StreamError::Network(e)),
UnexpectedEof | BrokenPipe | ConnectionAborted | ConnectionRefused
| ConnectionReset => {
if let Err(err) = self.close() {
debug!("Failed to close the stream, cause: {:?}", err);
}
RetryError::permanent(Error::Network(e.to_string()))
RetryError::permanent(StreamError::Network(e))
}
_ => RetryError::Permanent(Error::Network(e.to_string())),
_ => RetryError::Permanent(StreamError::Network(e)),
}
})
};
Expand All @@ -205,7 +202,8 @@ where
}
})?;

let reply = buffer::unpack_response(&resp_buf, resp_buf.len())
let reply = AckReply::try_from(resp_buf.as_ref())
.map_err(StreamError::Buffer)
.map_err(RetryError::transient)?;
if reply.ack == chunk {
Ok(())
Expand All @@ -215,7 +213,7 @@ where
reply.ack, chunk
);

Err(RetryError::transient(Error::AckUmatched(
Err(RetryError::transient(StreamError::AckUmatched(
reply.ack,
chunk.to_string(),
)))
Expand All @@ -233,15 +231,14 @@ impl Connect<TcpStream> for TcpStream {
where
A: ToSocketAddrs + Clone + Debug,
{
TcpStream::connect(addr).map(|s| {
TcpStream::connect(addr).inspect(|s| {
s.set_nodelay(true).unwrap();
if !settings.read_timeout.is_zero() {
s.set_read_timeout(Some(settings.read_timeout)).unwrap();
}
if !settings.write_timeout.is_zero() {
s.set_write_timeout(Some(settings.write_timeout)).unwrap();
}
s
})
}

Expand Down Expand Up @@ -280,6 +277,16 @@ where
})
}

#[derive(thiserror::Error, Debug)]
pub enum StreamError {
#[error("network error")]
Network(#[from] std::io::Error),
#[error("buffer error")]
Buffer(#[from] BufferError),
#[error("request chunk and response ack-id did not match, {0} /= {1}")]
AckUmatched(String, String),
}

#[cfg(test)]
mod tests {
use super::{io, Duration, ToSocketAddrs};
Expand All @@ -304,7 +311,7 @@ mod tests {
where
A: ToSocketAddrs + Clone,
{
if let Ok(_) = addr.to_socket_addrs() {
if addr.to_socket_addrs().is_ok() {
let count = CONN_COUNT.fetch_add(1, Ordering::SeqCst);
if count % 20 == 0 {
Ok(TestStream(AtomicUsize::new(1)))
Expand Down Expand Up @@ -383,6 +390,7 @@ mod tests {

#[derive(Debug)]
struct TS;
#[allow(clippy::type_complexity)]
static QUEUE: Lazy<Mutex<RefCell<VecDeque<Result<usize, io::Error>>>>> = Lazy::new(|| {
let mut q = VecDeque::new();

Expand Down
Loading

0 comments on commit 887ccee

Please sign in to comment.