diff --git a/Cargo.toml b/Cargo.toml index 8762e37..1abd709 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,13 +19,16 @@ secrecy = "0.8.0" serde = { version = "1.0.214", optional = true, features = ["derive"] } sync_docs = { path = "sync_docs" } thiserror = "1.0.67" +tokio = { version = "1.41.1", features = ["time"] } tonic = { version = "0.12.3", features = ["tls", "tls-webpki-roots"] } [build-dependencies] tonic-build = { version = "0.12.3", features = ["prost"] } [dev-dependencies] -tokio = { version = "*", features = ["full"] } +rstest = "0.23.0" +tokio = { version = "1.41.1", features = ["full", "test-util"] } +tokio-stream = "0.1.16" [features] serde = ["dep:serde"] diff --git a/src/streams.rs b/src/streams.rs index 2656491..0358e80 100644 --- a/src/streams.rs +++ b/src/streams.rs @@ -1,10 +1,12 @@ use std::{ pin::Pin, task::{Context, Poll}, + time::Duration, }; use bytesize::ByteSize; -use futures::{Stream, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; +use tokio::time::Sleep; use crate::types::{self, MeteredSize as _}; @@ -21,6 +23,8 @@ pub struct AppendRecordStreamOpts { pub match_seq_num: Option<u64>, /// Enforce a fencing token. pub fencing_token: Option<Vec<u8>>, + /// Linger duration for ready records to send together as a batch. + pub linger: Option<Duration>, } impl Default for AppendRecordStreamOpts { @@ -30,6 +34,7 @@ impl Default for AppendRecordStreamOpts { max_batch_size: ByteSize::mib(1), match_seq_num: None, fencing_token: None, + linger: None, } } } @@ -41,9 +46,9 @@ impl AppendRecordStreamOpts { } /// Construct from existing options with the new maximum batch records. - pub fn with_max_batch_records(self, max_batch_records: impl Into<usize>) -> Self { + pub fn with_max_batch_records(self, max_batch_records: usize) -> Self { Self { - max_batch_records: max_batch_records.into(), + max_batch_records, ..self } } @@ -71,6 +76,19 @@ impl AppendRecordStreamOpts { ..self } } + + /// Construct from existing options with the linger time.
+ pub fn with_linger(self, linger_duration: impl Into<Duration>) -> Self { + Self { + linger: Some(linger_duration.into()), + ..self + } + } + + fn linger_sleep_fut(&self) -> Option<Pin<Box<Sleep>>> { + self.linger + .map(|duration| Box::pin(tokio::time::sleep(duration))) + } } #[derive(Debug, thiserror::Error)] @@ -89,6 +107,7 @@ where stream: S, peeked_record: Option<types::AppendRecord>, terminated: bool, + linger_sleep: Option<Pin<Box<Sleep>>>, opts: AppendRecordStreamOpts, } @@ -107,6 +126,7 @@ where stream, peeked_record: None, terminated: false, + linger_sleep: opts.linger_sleep_fut(), opts, }) } @@ -140,6 +160,14 @@ where return Poll::Ready(None); } + if self + .linger_sleep + .as_mut() + .is_some_and(|fut| fut.poll_unpin(cx).is_pending()) + { + return Poll::Pending; + } + let mut batch = Vec::with_capacity(self.opts.max_batch_records); let mut batch_size = ByteSize::b(0); @@ -169,6 +197,9 @@ where if self.terminated { Poll::Ready(None) } else { + // Since we don't have any batches to send, we want to ignore the linger + // interval for the next poll. + self.linger_sleep = None; Poll::Pending } } else { @@ -182,6 +213,9 @@ where *m += batch.len() as u64 } + // Reset the linger sleep future since the old one is polled ready.
+ self.linger_sleep = self.opts.linger_sleep_fut(); + Poll::Ready(Some(types::AppendInput { records: batch, match_seq_num, @@ -190,3 +224,120 @@ where } } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use bytesize::ByteSize; + use futures::StreamExt as _; + use rstest::rstest; + use tokio::sync::mpsc; + use tokio_stream::wrappers::UnboundedReceiverStream; + + use crate::{ + streams::{AppendRecordStream, AppendRecordStreamOpts}, + types, + }; + + #[rstest] + #[case(Some(2), None)] + #[case(None, Some(ByteSize::b(30)))] + #[case(Some(2), Some(ByteSize::b(100)))] + #[case(Some(10), Some(ByteSize::b(30)))] + #[tokio::test] + async fn test_append_record_stream_batching( + #[case] max_batch_records: Option<usize>, + #[case] max_batch_size: Option<ByteSize>, + ) { + let stream_iter = (0..100).map(|i| types::AppendRecord::new(format!("r_{i}"))); + let stream = futures::stream::iter(stream_iter); + + let mut opts = AppendRecordStreamOpts::new(); + if let Some(max_batch_records) = max_batch_records { + opts = opts.with_max_batch_records(max_batch_records); + } + if let Some(max_batch_size) = max_batch_size { + opts = opts.with_max_batch_size(max_batch_size); + } + + let batch_stream = AppendRecordStream::new(stream, opts).unwrap(); + + let batches = batch_stream + .map(|batch| batch.records) + .collect::<Vec<_>>() + .await; + + let mut i = 0; + for batch in batches { + assert_eq!(batch.len(), 2); + for record in batch { + assert_eq!(record.body, format!("r_{i}").into_bytes()); + i += 1; + } + } + } + + #[tokio::test(start_paused = true)] + async fn test_append_record_stream_linger() { + let (stream_tx, stream_rx) = mpsc::unbounded_channel::<types::AppendRecord>(); + let mut i = 0; + + let collect_batches_handle = tokio::spawn(async move { + let batch_stream = AppendRecordStream::new( + UnboundedReceiverStream::new(stream_rx), + AppendRecordStreamOpts::new().with_linger(Duration::from_secs(2)), + ) + .unwrap(); + + batch_stream + .map(|batch| { + batch + .records + .into_iter() + .map(|rec| rec.body)
.collect::<Vec<_>>() + }) + .collect::<Vec<_>>() + .await + }); + + let mut send_next = || { + stream_tx + .send(types::AppendRecord::new(format!("r_{i}"))) + .unwrap(); + i += 1; + }; + + async fn sleep_secs(secs: u64) { + let dur = Duration::from_secs(secs) + Duration::from_millis(10); + tokio::time::sleep(dur).await; + } + + send_next(); + send_next(); + + sleep_secs(2).await; + + send_next(); + + sleep_secs(1).await; + + send_next(); + + sleep_secs(1).await; + + send_next(); + std::mem::drop(stream_tx); // Should close the stream + + let batches = collect_batches_handle.await.unwrap(); + + let expected_batches = vec![ + vec![b"r_0".to_owned(), b"r_1".to_owned()], + vec![b"r_2".to_owned(), b"r_3".to_owned()], + vec![b"r_4".to_owned()], + ]; + + assert_eq!(batches, expected_batches); + } +}