From eaff8b99b0f977aa1d4a63d9e977a80c9d2074f6 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 9 Feb 2024 20:53:59 +0100 Subject: [PATCH] Adding ZSTD compression to chitchat's Deltas. --- chitchat/Cargo.toml | 10 +- chitchat/src/delta.rs | 623 +++++++++++++++++++----------- chitchat/src/digest.rs | 21 +- chitchat/src/lib.rs | 27 +- chitchat/src/message.rs | 40 +- chitchat/src/serialize.rs | 477 ++++++++++++++++++++--- chitchat/src/server.rs | 2 +- chitchat/src/state.rs | 53 ++- chitchat/src/transport/channel.rs | 3 +- chitchat/src/transport/udp.rs | 2 +- chitchat/tests/cluster_test.rs | 4 +- 11 files changed, 920 insertions(+), 342 deletions(-) diff --git a/chitchat/Cargo.toml b/chitchat/Cargo.toml index a3a79a8..f486241 100644 --- a/chitchat/Cargo.toml +++ b/chitchat/Cargo.toml @@ -16,14 +16,22 @@ bytes = "1" itertools = "0.12" rand = { version = "0.8", features = ["small_rng"] } serde = { version = "1", features = ["derive"] } -tokio = { version = "1.28.0", features = ["net", "sync", "rt-multi-thread", "macros", "time"] } +tokio = { version = "1.28.0", features = [ + "net", + "sync", + "rt-multi-thread", + "macros", + "time", +] } tokio-stream = { version = "0.1", features = ["sync"] } tracing = "0.1" +zstd = "0.13" [dev-dependencies] assert-json-diff = "2" mock_instant = "0.3" tracing-subscriber = "0.3" +proptest = "1.4" [features] testsuite = [] diff --git a/chitchat/src/delta.rs b/chitchat/src/delta.rs index 34e72a6..8260068 100644 --- a/chitchat/src/delta.rs +++ b/chitchat/src/delta.rs @@ -1,59 +1,243 @@ -use std::collections::{BTreeMap, HashSet}; -use std::mem; +use std::collections::HashSet; use crate::serialize::*; use crate::{ChitchatId, Heartbeat, VersionedValue}; +/// A delta is the message we send to another node to update it. +/// +/// Its serialization is done by transforming it into a sequence of operations, +/// encoded one after the other in a compressed stream. #[derive(Debug, Default, Eq, PartialEq)] pub struct Delta { - pub(crate) node_deltas: BTreeMap, - pub(crate) nodes_to_reset: HashSet, + pub(crate) nodes_to_reset: Vec, + pub(crate) node_deltas: Vec, } -impl Serializable for Delta { - fn serialize(&self, buf: &mut Vec) { - (self.node_deltas.len() as u16).serialize(buf); - for (chitchat_id, node_delta) in &self.node_deltas { - chitchat_id.serialize(buf); - node_delta.serialize(buf); - } - (self.nodes_to_reset.len() as u16).serialize(buf); - for chitchat_id in &self.nodes_to_reset { - chitchat_id.serialize(buf); +impl Delta { + fn get_operations<'a>(&'a self) -> impl Iterator> { + let nodes_to_reset_ops = self + .nodes_to_reset + .iter() + .map(|node_to_reset| DeltaOpRef::NodeToReset(node_to_reset)); + let node_deltas = self.node_deltas.iter().flat_map(|node_delta| { + std::iter::once(DeltaOpRef::Node { + chitchat_id: &node_delta.chitchat_id, + heartbeat: node_delta.heartbeat, + }) + .chain(node_delta.key_values.iter().map(|(key, versioned_value)| { + DeltaOpRef::KeyValue { + key, + versioned_value, + } + })) + }); + nodes_to_reset_ops.chain(node_deltas) + } +} + +pub(crate) enum DeltaOp { + NodeToReset(ChitchatId), + Node { + chitchat_id: ChitchatId, + heartbeat: Heartbeat, + }, + KeyValue { + key: String, + versioned_value: VersionedValue, + }, +} + +pub(crate) enum DeltaOpRef<'a> { + NodeToReset(&'a ChitchatId), + Node { + chitchat_id: &'a ChitchatId, + heartbeat: Heartbeat, + }, + KeyValue { + key: &'a str, + versioned_value: &'a VersionedValue, + }, +} + +#[repr(u8)] +enum DeltaOpTag { + Node = 0u8, + KeyValue = 1u8, + NodeToReset = 2u8, +} + +impl TryFrom for DeltaOpTag { + type Error = anyhow::Error; + + fn try_from(tag_byte: u8) -> anyhow::Result { + match tag_byte { + 0u8 => Ok(DeltaOpTag::Node), + 1u8 => Ok(DeltaOpTag::KeyValue), + 2u8 => Ok(DeltaOpTag::NodeToReset), + _ => { + anyhow::bail!("Unknown tag: {tag_byte}") + } } } +} +impl From for u8 { + fn from(tag: DeltaOpTag) -> u8 { + tag as u8 + } +} + +impl Deserializable for DeltaOp { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let mut node_deltas: BTreeMap = Default::default(); - let num_nodes = u16::deserialize(buf)?; - for _ in 0..num_nodes { - let chitchat_id = ChitchatId::deserialize(buf)?; - let node_delta = NodeDelta::deserialize(buf)?; - node_deltas.insert(chitchat_id, node_delta); + let tag_bytes: [u8; 1] = Deserializable::deserialize(buf)?; + let tag = DeltaOpTag::try_from(tag_bytes[0])?; + match tag { + DeltaOpTag::NodeToReset => { + let chitchat_id = ChitchatId::deserialize(buf)?; + Ok(DeltaOp::NodeToReset(chitchat_id)) + } + DeltaOpTag::Node => { + let chitchat_id = ChitchatId::deserialize(buf)?; + let heartbeat = Heartbeat::deserialize(buf)?; + Ok(DeltaOp::Node { + chitchat_id, + heartbeat, + }) + } + DeltaOpTag::KeyValue => { + let key = String::deserialize(buf)?; + let value = String::deserialize(buf)?; + let version = u64::deserialize(buf)?; + let tombstone = Option::::deserialize(buf)?; + let versioned_value: VersionedValue = VersionedValue { + value, + version, + tombstone, + }; + Ok(DeltaOp::KeyValue { + key, + versioned_value, + }) + } } - let num_nodes_to_reset = u16::deserialize(buf)?; - let mut nodes_to_reset = HashSet::with_capacity(num_nodes_to_reset as usize); - for _ in 0..num_nodes_to_reset { - let chitchat_id = ChitchatId::deserialize(buf)?; - nodes_to_reset.insert(chitchat_id); + } +} + +impl DeltaOp { + pub(crate) fn as_ref(&self) -> DeltaOpRef { + match self { + DeltaOp::Node { + chitchat_id, + heartbeat, + } => DeltaOpRef::Node { + chitchat_id, + heartbeat: *heartbeat, + }, + DeltaOp::KeyValue { + key, + versioned_value, + } => DeltaOpRef::KeyValue { + key, + versioned_value, + }, + DeltaOp::NodeToReset(node_to_reset) => DeltaOpRef::NodeToReset(node_to_reset), } - Ok(Delta { - node_deltas, - nodes_to_reset, - }) + } +} + +impl Serializable for DeltaOp { + fn serialize(&self, buf: &mut Vec) { + self.as_ref().serialize(buf) } fn serialized_len(&self) -> usize { - let mut len = 2; - for (chitchat_id, node_delta) in &self.node_deltas { - len += chitchat_id.serialized_len(); - len += node_delta.serialized_len(); + self.as_ref().serialized_len() + } +} + +impl<'a> Serializable for DeltaOpRef<'a> { + fn serialize(&self, buf: &mut Vec) { + match self { + Self::Node { + chitchat_id, + heartbeat, + } => { + buf.push(DeltaOpTag::Node.into()); + chitchat_id.serialize(buf); + heartbeat.serialize(buf); + } + Self::KeyValue { + key, + versioned_value, + } => { + buf.push(DeltaOpTag::KeyValue.into()); + key.serialize(buf); + versioned_value.value.serialize(buf); + versioned_value.version.serialize(buf); + versioned_value.tombstone.serialize(buf); + } + Self::NodeToReset(chitchat_id) => { + buf.push(DeltaOpTag::NodeToReset.into()); + chitchat_id.serialize(buf); + } } - len += 2; - for chitchat_id in &self.nodes_to_reset { - len += chitchat_id.serialized_len(); + } + + fn serialized_len(&self) -> usize { + 1 + match self { + Self::Node { + chitchat_id, + heartbeat, + } => chitchat_id.serialized_len() + heartbeat.serialized_len(), + Self::KeyValue { + key, + versioned_value, + } => { + key.serialized_len() + + versioned_value.value.serialized_len() + + versioned_value.version.serialized_len() + + versioned_value.tombstone.serialized_len() + } + Self::NodeToReset(chitchat_id) => chitchat_id.serialized_len(), + } + } +} + +/// Slow serializable implementation but it is only here for tests. +impl Serializable for Delta { + fn serialize(&self, buf: &mut Vec) { + let mut compressed_stream_writer = CompressedStreamWriter::with_block_threshold(16_384); + for op in self.get_operations() { + compressed_stream_writer.append(&op); } - len + let payload = compressed_stream_writer.finish(); + buf.extend(&payload); + } + + fn serialized_len(&self) -> usize { + // Slow, but never called in practise + let mut buf: Vec = Vec::new(); + self.serialize(&mut buf); + buf.len() + } +} + +impl Deserializable for Delta { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let ops: Vec = crate::serialize::deserialize_stream(buf)?; + ops.try_into() + } +} + +impl TryFrom> for Delta { + type Error = anyhow::Error; + + fn try_from(delta_ops: Vec) -> anyhow::Result { + let mut delta_builder = DeltaBuilder::default(); + for op in delta_ops { + delta_builder.apply_op(op)?; + } + Ok(delta_builder.finish()) } } @@ -61,18 +245,22 @@ impl Serializable for Delta { impl Delta { pub fn num_tuples(&self) -> usize { self.node_deltas - .values() + .iter() .map(|node_delta| node_delta.num_tuples()) .sum() } pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) { - self.node_deltas - .entry(chitchat_id) - .or_insert_with(|| NodeDelta { - heartbeat, - ..Default::default() - }); + assert!(self + .node_deltas + .iter() + .find(|node_delta| { &node_delta.chitchat_id == &chitchat_id }) + .is_none()); + self.node_deltas.push(NodeDelta { + chitchat_id, + heartbeat, + key_values: Vec::new(), + }); } pub fn add_kv( @@ -83,26 +271,37 @@ impl Delta { version: crate::Version, tombstone: Option, ) { - let node_delta = self.node_deltas.get_mut(chitchat_id).unwrap(); - node_delta.key_values.insert( + let node_delta = self + .node_deltas + .iter_mut() + .find(|node_delta| &node_delta.chitchat_id == chitchat_id) + .unwrap(); + node_delta.key_values.push(( key.to_string(), VersionedValue { value: value.to_string(), version, tombstone, }, - ); + )); + } + + pub(crate) fn get(&self, chitchat_id: &ChitchatId) -> Option<&NodeDelta> { + self.node_deltas + .iter() + .find(|node_delta| &node_delta.chitchat_id == chitchat_id) } pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) { - self.nodes_to_reset.insert(chitchat_id); + self.nodes_to_reset.push(chitchat_id); } } -#[derive(Debug, Default, Eq, PartialEq, serde::Serialize)] +#[derive(Debug, Eq, PartialEq, serde::Serialize)] pub(crate) struct NodeDelta { + pub chitchat_id: ChitchatId, pub heartbeat: Heartbeat, - pub key_values: BTreeMap, + pub key_values: Vec<(String, VersionedValue)>, } #[cfg(test)] @@ -112,164 +311,123 @@ impl NodeDelta { } } -pub struct DeltaWriter { +#[derive(Default)] +pub(crate) struct DeltaBuilder { + existing_nodes: HashSet, delta: Delta, - mtu: usize, - num_bytes: usize, - current_chitchat_id: Option, - current_node_delta: NodeDelta, - reached_capacity: bool, + current_node_delta: Option, } -impl DeltaWriter { - pub fn with_mtu(mtu: usize) -> Self { - DeltaWriter { - delta: Delta::default(), - mtu, - num_bytes: 2 + 2, /* 2 bytes for `nodes_to_reset.len()` + 2 bytes for - * `node_deltas.len()` */ - current_chitchat_id: None, - current_node_delta: NodeDelta::default(), - reached_capacity: false, - } - } - fn flush(&mut self) { - let chitchat_id_opt = mem::take(&mut self.current_chitchat_id); - let node_delta = mem::take(&mut self.current_node_delta); - if let Some(chitchat_id) = chitchat_id_opt { - self.delta.node_deltas.insert(chitchat_id, node_delta); - } +impl DeltaBuilder { + pub fn finish(mut self) -> Delta { + self.flush(); + self.delta } - pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) -> bool { - assert!(!self.delta.nodes_to_reset.contains(&chitchat_id)); - if !self.attempt_add_bytes(chitchat_id.serialized_len()) { - return false; + pub(crate) fn apply_op(&mut self, op: DeltaOp) -> anyhow::Result<()> { + match op { + DeltaOp::Node { + chitchat_id, + heartbeat, + } => { + self.flush(); + anyhow::ensure!(!self.existing_nodes.contains(&chitchat_id)); + self.existing_nodes.insert(chitchat_id.clone()); + self.current_node_delta = Some(NodeDelta { + chitchat_id, + heartbeat, + key_values: Vec::new(), + }); + } + DeltaOp::KeyValue { + key, + versioned_value, + } => { + let Some(current_node_delta) = self.current_node_delta.as_mut() else { + anyhow::bail!("received a key-value op without a node op before."); + }; + current_node_delta + .key_values + .push((key.to_string(), versioned_value)); + } + DeltaOp::NodeToReset(chitchat_id) => { + anyhow::ensure!(!self.delta.nodes_to_reset.contains(&chitchat_id)); + self.delta.nodes_to_reset.push(chitchat_id); + } } - self.delta.nodes_to_reset.insert(chitchat_id); - true + Ok(()) } - pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) -> bool { - assert!(self.current_chitchat_id.as_ref() != Some(&chitchat_id)); - assert!(!self.delta.node_deltas.contains_key(&chitchat_id)); - self.flush(); - // Reserve bytes for [`ChitchatId`], [`Hearbeat`], and for an empty [`NodeDelta`] which has - // a size of 2 bytes. - if !self.attempt_add_bytes(chitchat_id.serialized_len() + heartbeat.serialized_len() + 2) { - return false; - } - self.current_chitchat_id = Some(chitchat_id); - self.current_node_delta.heartbeat = heartbeat; - true + fn flush(&mut self) { + let Some(node_delta) = self.current_node_delta.take() else { + // There are no nodes in the builder. + // (this happens when the delta builder is freshly created and no ops have been received + // yet.) + return; + }; + self.delta.node_deltas.push(node_delta); } +} - fn attempt_add_bytes(&mut self, num_bytes: usize) -> bool { - assert!(!self.reached_capacity); - let new_num_bytes = self.num_bytes + num_bytes; - if new_num_bytes > self.mtu { - self.reached_capacity = true; - return false; +pub struct DeltaWriter { + mtu: usize, + delta_builder: DeltaBuilder, + compressed_stream_writer: CompressedStreamWriter, +} + +const BLOCK_THRESHOLD: u16 = 16_384u16; + +impl DeltaWriter { + pub fn with_mtu(mtu: usize) -> Self { + assert!(mtu >= 100); + let block_threshold = u16::try_from((BLOCK_THRESHOLD as usize).min(mtu)).unwrap(); + DeltaWriter { + mtu, + delta_builder: DeltaBuilder::default(), + compressed_stream_writer: CompressedStreamWriter::with_block_threshold(block_threshold), } - self.num_bytes = new_num_bytes; - true } - /// Returns false if the KV could not be added because mtu was reached. - pub fn add_kv(&mut self, key: &str, versioned_value: VersionedValue) -> bool { - assert!(!self.current_node_delta.key_values.contains_key(key)); - // Reserve bytes for the key (2 bytes are used to store the key length) and versioned value. - if !self.attempt_add_bytes( - 2 + key.len() - + versioned_value.value.serialized_len() - + versioned_value.version.serialized_len() - + versioned_value.tombstone.serialized_len(), - ) { + fn add_op(&mut self, delta_op: DeltaOp) -> bool { + if self + .compressed_stream_writer + .serialized_len_upperbound_after(&delta_op) + > self.mtu + { return false; } - self.current_node_delta - .key_values - .insert(key.to_string(), versioned_value); + self.compressed_stream_writer.append(&delta_op); + assert!(self.delta_builder.apply_op(delta_op).is_ok()); true } -} -impl From for Delta { - fn from(mut delta_writer: DeltaWriter) -> Delta { - delta_writer.flush(); - if cfg!(debug_assertions) { - let mut buf = Vec::new(); - assert_eq!(delta_writer.num_bytes, delta_writer.delta.serialized_len()); - delta_writer.delta.serialize(&mut buf); - assert_eq!(buf.len(), delta_writer.num_bytes); - } - delta_writer.delta + /// Returns false if the node could not be added because the payload would exceed the mtu. + pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) -> bool { + let new_node_op = DeltaOp::Node { + chitchat_id, + heartbeat, + }; + self.add_op(new_node_op) } -} -impl Serializable for NodeDelta { - fn serialize(&self, buf: &mut Vec) { - self.heartbeat.serialize(buf); - (self.key_values.len() as u16).serialize(buf); - for ( - key, - VersionedValue { - value, - version, - tombstone, - }, - ) in &self.key_values - { - key.serialize(buf); - value.serialize(buf); - version.serialize(buf); - tombstone.serialize(buf); - } + /// Returns false if the KV could not be added because the payload would exceed the mtu. + pub fn add_kv(&mut self, key: &str, versioned_value: VersionedValue) -> bool { + let key_value_op = DeltaOp::KeyValue { + key: key.to_string(), + versioned_value, + }; + self.add_op(key_value_op) } - fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let heartbeat = Heartbeat::deserialize(buf)?; - let mut key_values: BTreeMap = Default::default(); - let num_key_values = u16::deserialize(buf)?; - for _ in 0..num_key_values { - let key = String::deserialize(buf)?; - let value = String::deserialize(buf)?; - let version = u64::deserialize(buf)?; - let tombstone = >::deserialize(buf)?; - key_values.insert( - key, - VersionedValue { - value, - version, - tombstone, - }, - ); - } - Ok(Self { - heartbeat, - key_values, - }) + /// Returns false if the node to reset could not be added because the payload would exceed the + /// mtu. + pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) -> bool { + let delta_op = DeltaOp::NodeToReset(chitchat_id); + self.add_op(delta_op) } - fn serialized_len(&self) -> usize { - let mut len = 2; - len += self.heartbeat.serialized_len(); - - for ( - key, - VersionedValue { - value, - version, - tombstone, - }, - ) in &self.key_values - { - len += key.serialized_len(); - len += value.serialized_len(); - len += version.serialized_len(); - len += tombstone.serialized_len(); - } - len + pub fn finish(self) -> Delta { + self.delta_builder.finish() } } @@ -279,7 +437,7 @@ mod tests { #[test] fn test_delta_serialization_default() { - test_serdeser_aux(&Delta::default(), 4); + test_serdeser_aux(&Delta::default(), 1); } #[test] @@ -337,22 +495,21 @@ mod tests { tombstone: None, }, )); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 173); + test_aux_delta_writer(delta_writer, 99); } #[test] fn test_delta_serialization_simple_node() { - // 4 bytes - let mut delta_writer = DeltaWriter::with_mtu(124); + // 1 bytes (End tag) + let mut delta_writer = DeltaWriter::with_mtu(128); // ChitchatId takes 27 bytes = 15 bytes + 2 bytes for node length + "node-10001".len(). let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (node). + // +37 bytes = 8 bytes (heartbeat) + 27 bytes (node) + 2bytes (block length) assert!(delta_writer.add_node(node1, heartbeat)); - // +23 bytes. + // +24 bytes (kv + op tag) assert!(delta_writer.add_kv( "key11", VersionedValue { @@ -361,7 +518,8 @@ mod tests { tombstone: None, } )); - // +23 bytes. + + // +24 bytes. (kv + op tag) assert!(delta_writer.add_kv( "key12", VersionedValue { @@ -375,24 +533,33 @@ mod tests { let heartbeat = Heartbeat(0); // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (node). assert!(delta_writer.add_node(node2, heartbeat)); + test_aux_delta_writer(delta_writer, 80); + } - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 124); + #[track_caller] + fn test_aux_delta_writer(delta_writer: DeltaWriter, expected_len: usize) { + let delta: Delta = delta_writer.finish(); + test_serdeser_aux(&delta, expected_len) } #[test] fn test_delta_serialization_simple_with_nodes_to_reset() { - // 4 bytes - let mut delta_writer = DeltaWriter::with_mtu(151); - // +27 bytes (ChitchatId). + // 1 bytes (end tag) + let mut delta_writer = DeltaWriter::with_mtu(155); + + // +27 bytes (ChitchatId) + 1 (op tag) + 3 bytes (block len) + // = 32 bytes assert!(delta_writer.add_node_to_reset(ChitchatId::for_local_test(10_000))); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). + + // +8 bytes (heartbeat) + 27 bytes (ChitchatId) + (1 op tag) + 3 bytes (pessimistic new + // block) = 71 assert!(delta_writer.add_node(node1, heartbeat)); - // +23 bytes. + // +23 bytes (kv) + 1 (op tag) + // = 95 assert!(delta_writer.add_kv( "key11", VersionedValue { @@ -401,7 +568,8 @@ mod tests { tombstone: None, } )); - // +23 bytes. + // +23 bytes (kv) + 1 (op tag) + // = 119 assert!(delta_writer.add_kv( "key12", VersionedValue { @@ -413,17 +581,17 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). + // +8 bytes (heartbeat) + 27 bytes (ChitchatId) + 1 byte (op tag) + // = 155 assert!(delta_writer.add_node(node2, heartbeat)); - - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 151); + // The block got compressed. + test_aux_delta_writer(delta_writer, 85); } #[test] fn test_delta_serialization_exceed_mtu_on_add_node() { // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(87); + let mut delta_writer = DeltaWriter::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); @@ -454,14 +622,14 @@ mod tests { // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). assert!(!delta_writer.add_node(node2, heartbeat)); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 87); + // The block got compressed. + test_aux_delta_writer(delta_writer, 72); } #[test] fn test_delta_serialization_exceed_mtu_on_add_node_to_reset() { // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(90); + let mut delta_writer = DeltaWriter::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); @@ -490,20 +658,24 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); assert!(!delta_writer.add_node_to_reset(node2)); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 87); + // The block got compressed. + test_aux_delta_writer(delta_writer, 72); } #[test] fn test_delta_serialization_exceed_mtu_on_add_kv() { - // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(86); + // 1 bytes. + let mut delta_writer = DeltaWriter::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes. + + // + 3 bytes (block tag) + 35 bytes (node) + 1 byte (op tag) + // = 40 assert!(delta_writer.add_node(node1, heartbeat)); - // +23 bytes. + + // +23 bytes (kv) + 1 (op tag) + 3 bytes (pessimistic block tag) + // = 67 assert!(delta_writer.add_kv( "key11", VersionedValue { @@ -512,18 +684,18 @@ mod tests { tombstone: None, } )); - // +23 bytes. + + // +33 bytes (kv) + 1 (op tag) + // = 101 (exceeding mtu!) assert!(!delta_writer.add_kv( "key12", VersionedValue { - value: "val12".to_string(), + value: "val12aaaaaaaaaa".to_string(), version: 2, tombstone: None, } )); - - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 64); + test_aux_delta_writer(delta_writer, 64); } #[test] @@ -560,4 +732,17 @@ mod tests { }, ); } + + #[test] + fn test_delta_op_tag() { + let mut num_valid_tags = 0; + for b in 0..=u8::MAX { + if let Ok(tag) = DeltaOpTag::try_from(b) { + let tag_byte: u8 = tag.into(); + assert_eq!(b, tag_byte); + num_valid_tags += 1; + } + } + assert_eq!(num_valid_tags, 3); + } } diff --git a/chitchat/src/digest.rs b/chitchat/src/digest.rs index 9e0252b..6cb50e9 100644 --- a/chitchat/src/digest.rs +++ b/chitchat/src/digest.rs @@ -45,7 +45,18 @@ impl Serializable for Digest { node_digest.max_version.serialize(buf); } } + fn serialized_len(&self) -> usize { + let mut len = (self.node_digests.len() as u16).serialized_len(); + for (chitchat_id, node_digest) in &self.node_digests { + len += chitchat_id.serialized_len(); + len += node_digest.heartbeat.serialized_len(); + len += node_digest.max_version.serialized_len(); + } + len + } +} +impl Deserializable for Digest { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let num_nodes = u16::deserialize(buf)?; let mut node_digests: BTreeMap = Default::default(); @@ -59,14 +70,4 @@ impl Serializable for Digest { } Ok(Digest { node_digests }) } - - fn serialized_len(&self) -> usize { - let mut len = (self.node_digests.len() as u16).serialized_len(); - for (chitchat_id, node_digest) in &self.node_digests { - len += chitchat_id.serialized_len(); - len += node_digest.heartbeat.serialized_len(); - len += node_digest.max_version.serialized_len(); - } - len - } } diff --git a/chitchat/src/lib.rs b/chitchat/src/lib.rs index 7ffc261..aeeebd2 100644 --- a/chitchat/src/lib.rs +++ b/chitchat/src/lib.rs @@ -29,7 +29,6 @@ use tracing::{error, warn}; pub use self::configuration::ChitchatConfig; pub use self::state::{ClusterStateSnapshot, NodeState}; use crate::digest::Digest; -use crate::message::syn_ack_serialized_len; pub use crate::message::ChitchatMessage; pub use crate::server::{spawn_chitchat, ChitchatHandle}; use crate::state::ClusterState; @@ -94,6 +93,11 @@ impl Chitchat { } } + fn process_delta(&mut self, delta: Delta) { + self.report_heartbeats(&delta); + self.cluster_state.apply_delta(delta); + } + pub(crate) fn process_message(&mut self, msg: ChitchatMessage) -> Option { match msg { ChitchatMessage::Syn { cluster_id, digest } => { @@ -105,13 +109,10 @@ impl Chitchat { ); return Some(ChitchatMessage::BadCluster); } - // Ensure for every reply from this node, at least the heartbeat is changed. let dead_nodes: HashSet<_> = self.dead_nodes().collect(); let self_digest = self.compute_digest(); - let empty_delta = Delta::default(); - let delta_mtu = MAX_UDP_DATAGRAM_PAYLOAD_SIZE - - syn_ack_serialized_len(&self_digest, &empty_delta); - let delta = self.cluster_state.compute_delta( + let delta_mtu = MAX_UDP_DATAGRAM_PAYLOAD_SIZE - 1 - digest.serialized_len(); + let delta = self.cluster_state.compute_partial_delta_respecting_mtu( &digest, delta_mtu, &dead_nodes, @@ -123,10 +124,9 @@ impl Chitchat { }) } ChitchatMessage::SynAck { digest, delta } => { - self.report_heartbeats(&delta); - self.cluster_state.apply_delta(delta); + self.process_delta(delta); let dead_nodes = self.dead_nodes().collect::>(); - let delta = self.cluster_state.compute_delta( + let delta = self.cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE - 1, &dead_nodes, @@ -135,8 +135,7 @@ impl Chitchat { Some(ChitchatMessage::Ack { delta }) } ChitchatMessage::Ack { delta } => { - self.report_heartbeats(&delta); - self.cluster_state.apply_delta(delta); + self.process_delta(delta); None } ChitchatMessage::BadCluster => { @@ -163,8 +162,10 @@ impl Chitchat { self.failure_detector.report_heartbeat(chitchat_id); } } else { - self.failure_detector.report_unknown(chitchat_id); - self.failure_detector.update_node_liveness(chitchat_id); + self.failure_detector + .report_unknown(&node_delta.chitchat_id); + self.failure_detector + .update_node_liveness(&node_delta.chitchat_id); } } } diff --git a/chitchat/src/message.rs b/chitchat/src/message.rs index 863b040..95e9bf3 100644 --- a/chitchat/src/message.rs +++ b/chitchat/src/message.rs @@ -4,7 +4,7 @@ use anyhow::Context; use crate::delta::Delta; use crate::digest::Digest; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; /// Chitchat message. /// @@ -61,7 +61,7 @@ impl Serializable for ChitchatMessage { ChitchatMessage::SynAck { digest, delta } => { buf.push(MessageType::SynAck.to_code()); digest.serialize(buf); - delta.serialize(buf); + Serializable::serialize(delta, buf); } ChitchatMessage::Ack { delta } => { buf.push(MessageType::Ack.to_code()); @@ -73,6 +73,21 @@ impl Serializable for ChitchatMessage { } } + fn serialized_len(&self) -> usize { + match self { + ChitchatMessage::Syn { cluster_id, digest } => { + 1 + cluster_id.serialized_len() + digest.serialized_len() + } + ChitchatMessage::SynAck { digest, delta } => { + 1 + digest.serialized_len() + delta.serialized_len() + } + ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), + ChitchatMessage::BadCluster => 1, + } + } +} + +impl Deserializable for ChitchatMessage { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let code = buf .first() @@ -98,21 +113,6 @@ impl Serializable for ChitchatMessage { MessageType::BadCluster => Ok(Self::BadCluster), } } - - fn serialized_len(&self) -> usize { - match self { - ChitchatMessage::Syn { cluster_id, digest } => { - 1 + cluster_id.serialized_len() + digest.serialized_len() - } - ChitchatMessage::SynAck { digest, delta } => syn_ack_serialized_len(digest, delta), - ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), - ChitchatMessage::BadCluster => 1, - } - } -} - -pub(crate) fn syn_ack_serialized_len(digest: &Digest, delta: &Delta) -> usize { - 1 + digest.serialized_len() + delta.serialized_len() } #[cfg(test)] @@ -149,7 +149,8 @@ mod tests { digest: Digest::default(), delta: Delta::default(), }; - test_serdeser_aux(&syn_ack, 7); + // 1 (message tag) + 2 (digest len) + 1 (delta end op) + test_serdeser_aux(&syn_ack, 4); } { // 2 bytes. @@ -177,7 +178,7 @@ mod tests { { let delta = Delta::default(); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 5); + test_serdeser_aux(&ack, 2); } { // 4 bytes. @@ -187,7 +188,6 @@ mod tests { delta.add_node(node.clone(), Heartbeat(0)); // +29 bytes. delta.add_kv(&node, "key", "value", 0, Some(5)); - let ack = ChitchatMessage::Ack { delta }; test_serdeser_aux(&ack, 71); } diff --git a/chitchat/src/serialize.rs b/chitchat/src/serialize.rs index 5536383..d225a15 100644 --- a/chitchat/src/serialize.rs +++ b/chitchat/src/serialize.rs @@ -1,7 +1,9 @@ use std::io::BufRead; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use anyhow::bail; +use anyhow::{bail, Context}; +use bytes::Buf; +use zstd; use crate::{ChitchatId, Heartbeat}; @@ -10,7 +12,7 @@ use crate::{ChitchatId, Heartbeat}; /// Chitchat uses a custom binary serialization format. /// The point of this format is to make it possible /// to truncate the delta payload to a given mtu. -pub trait Serializable: Sized { +pub trait Serializable { fn serialize(&self, buf: &mut Vec); fn serialize_to_vec(&self) -> Vec { @@ -19,19 +21,33 @@ pub trait Serializable: Sized { buf } - fn deserialize(buf: &mut &[u8]) -> anyhow::Result; - fn serialized_len(&self) -> usize; } -impl Serializable for u16 { +pub trait Deserializable: Sized { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result; +} + +impl Serializable for u8 { fn serialize(&self, buf: &mut Vec) { - self.to_le_bytes().serialize(buf); + buf.push(*self) } + fn serialized_len(&self) -> usize { + 1 + } +} + +impl Deserializable for u8 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u16_bytes: [u8; 2] = Serializable::deserialize(buf)?; - Ok(Self::from_le_bytes(u16_bytes)) + let byte: [u8; 1] = Deserializable::deserialize(buf)?; + Ok(byte[0]) + } +} + +impl Serializable for u16 { + fn serialize(&self, buf: &mut Vec) { + self.to_le_bytes().serialize(buf); } fn serialized_len(&self) -> usize { @@ -39,20 +55,44 @@ impl Serializable for u16 { } } -impl Serializable for u64 { +impl Deserializable for u16 { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let u16_bytes: [u8; 2] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u16_bytes)) + } +} + +impl Serializable for u32 { fn serialize(&self, buf: &mut Vec) { self.to_le_bytes().serialize(buf); } + fn serialized_len(&self) -> usize { + 4 + } +} + +impl Deserializable for u32 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u64_bytes: [u8; 8] = Serializable::deserialize(buf)?; - Ok(Self::from_le_bytes(u64_bytes)) + let u32_bytes: [u8; 4] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u32_bytes)) } +} +impl Serializable for u64 { + fn serialize(&self, buf: &mut Vec) { + self.to_le_bytes().serialize(buf); + } fn serialized_len(&self) -> usize { 8 } } +impl Deserializable for u64 { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let u64_bytes: [u8; 8] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u64_bytes)) + } +} impl Serializable for Option { fn serialize(&self, buf: &mut Vec) { @@ -61,16 +101,6 @@ impl Serializable for Option { tombstone.serialize(buf); } } - - fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let is_some: bool = Serializable::deserialize(buf)?; - if is_some { - let u64_value = Serializable::deserialize(buf)?; - return Ok(Some(u64_value)); - } - Ok(None) - } - fn serialized_len(&self) -> usize { if self.is_some() { 9 @@ -80,21 +110,53 @@ impl Serializable for Option { } } -impl Serializable for bool { +impl Serializable for Vec { fn serialize(&self, buf: &mut Vec) { - buf.push(*self as u8); + (self.len() as u32).serialize(buf); + buf.extend(self.as_slice()) + } + + fn serialized_len(&self) -> usize { + 4 + self.len() } +} +impl Deserializable for Vec { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let bool_byte: [u8; 1] = Serializable::deserialize(buf)?; - Ok(bool_byte[0] != 0) + let len = u32::deserialize(buf)? as usize; + let payload = buf[..len].to_vec(); + buf.advance(len); + Ok(payload) } +} + +impl Deserializable for Option { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let is_some: bool = Deserializable::deserialize(buf)?; + if is_some { + let u64_value = Deserializable::deserialize(buf)?; + return Ok(Some(u64_value)); + } + Ok(None) + } +} +impl Serializable for bool { + fn serialize(&self, buf: &mut Vec) { + buf.push(*self as u8); + } fn serialized_len(&self) -> usize { 1 } } +impl Deserializable for bool { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let bool_byte: [u8; 1] = Deserializable::deserialize(buf)?; + Ok(bool_byte[0] != 0) + } +} + #[repr(u8)] enum IpVersion { V4 = 4u8, @@ -129,42 +191,55 @@ impl Serializable for IpAddr { } } + fn serialized_len(&self) -> usize { + 1 + match self { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } +} + +impl Deserializable for IpAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let ip_version_byte: [u8; 1] = Serializable::deserialize(buf)?; + let ip_version_byte: [u8; 1] = Deserializable::deserialize(buf)?; let ip_version = IpVersion::try_from(ip_version_byte[0])?; - match ip_version { IpVersion::V4 => { - let bytes: [u8; 4] = Serializable::deserialize(buf)?; + let bytes: [u8; 4] = Deserializable::deserialize(buf)?; Ok(Ipv4Addr::from(bytes).into()) } IpVersion::V6 => { - let bytes: [u8; 16] = Serializable::deserialize(buf)?; + let bytes: [u8; 16] = Deserializable::deserialize(buf)?; Ok(Ipv6Addr::from(bytes).into()) } } } - - fn serialized_len(&self) -> usize { - 1 + match self { - IpAddr::V4(_) => 4, - IpAddr::V6(_) => 16, - } - } } impl Serializable for String { fn serialize(&self, buf: &mut Vec) { - (self.len() as u16).serialize(buf); - buf.extend(self.as_bytes()) + self.as_str().serialize(buf) + } + + fn serialized_len(&self) -> usize { + self.as_str().serialized_len() } +} +impl Deserializable for String { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let len: usize = u16::deserialize(buf)? as usize; let s = std::str::from_utf8(&buf[..len])?.to_string(); buf.consume(len); Ok(s) } +} + +impl Serializable for str { + fn serialize(&self, buf: &mut Vec) { + (self.len() as u16).serialize(buf); + buf.extend(self.as_bytes()) + } fn serialized_len(&self) -> usize { 2 + self.len() @@ -175,7 +250,12 @@ impl Serializable for [u8; N] { fn serialize(&self, buf: &mut Vec) { buf.extend_from_slice(&self[..]); } + fn serialized_len(&self) -> usize { + N + } +} +impl Deserializable for [u8; N] { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { if buf.len() < N { bail!("Buffer too short"); @@ -184,10 +264,6 @@ impl Serializable for [u8; N] { buf.consume(N); Ok(val_bytes) } - - fn serialized_len(&self) -> usize { - N - } } impl Serializable for SocketAddr { @@ -196,15 +272,17 @@ impl Serializable for SocketAddr { self.port().serialize(buf); } + fn serialized_len(&self) -> usize { + self.ip().serialized_len() + self.port().serialized_len() + } +} + +impl Deserializable for SocketAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let ip_addr = IpAddr::deserialize(buf)?; let port = u16::deserialize(buf)?; Ok(SocketAddr::new(ip_addr, port)) } - - fn serialized_len(&self) -> usize { - self.ip().serialized_len() + self.port().serialized_len() - } } impl Serializable for ChitchatId { @@ -214,6 +292,14 @@ impl Serializable for ChitchatId { self.gossip_advertise_addr.serialize(buf) } + fn serialized_len(&self) -> usize { + self.node_id.serialized_len() + + self.generation_id.serialized_len() + + self.gossip_advertise_addr.serialized_len() + } +} + +impl Deserializable for ChitchatId { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let node_id = String::deserialize(buf)?; let generation_id = u64::deserialize(buf)?; @@ -224,12 +310,6 @@ impl Serializable for ChitchatId { gossip_advertise_addr, }) } - - fn serialized_len(&self) -> usize { - self.node_id.serialized_len() - + self.generation_id.serialized_len() - + self.gossip_advertise_addr.serialized_len() - } } impl Serializable for Heartbeat { @@ -237,19 +317,213 @@ impl Serializable for Heartbeat { self.0.serialize(buf); } + fn serialized_len(&self) -> usize { + self.0.serialized_len() + } +} + +impl Deserializable for Heartbeat { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let heartbeat = u64::deserialize(buf)?; Ok(Self(heartbeat)) } +} + +/// A compressed stream writer receives a sequence of `Serializable` and +/// serialize/compresses into blocks of a configurable size. +/// +/// Block are tagged, so that blocks with a high entropy are stored kept "uncompressed". +/// +/// The stream gives the client an upperbound of what the overall payload length would +/// be if another item was appended. +/// This makes it possible to enforce a `mtu`. +pub struct CompressedStreamWriter { + output: Vec, + // temporary buffer used for block compression. + uncompressed_block: Vec, + // ongoing block being serialized. + compressed_block: Vec, + block_threshold: usize, +} + +impl CompressedStreamWriter { + pub fn with_block_threshold(block_threshold: u16) -> CompressedStreamWriter { + let block_threshold = block_threshold as usize; + let output = Vec::with_capacity(block_threshold); + CompressedStreamWriter { + output, + uncompressed_block: Vec::with_capacity(block_threshold * 2), + compressed_block: Vec::with_capacity(block_threshold), + block_threshold, + } + } + + /// Returns an upperbound of the serialized len after appending `s` + pub fn serialized_len_upperbound_after(&self, item: &S) -> usize { + self.output.len() + // already serialized block + if self.uncompressed_block.is_empty() { 0 } else { 3 } + // current block len + self.uncompressed_block.len() + + 3 + // possibly another block that will be created. (this is unlikely but possible and we want an upperbound) + item.serialized_len() + // the new item. This assume no compression will be possible. + 1 // End of stream flag + } + + /// Appends a new item to the stream. Items must be at most `u16::MAX` bytes long. + pub fn append(&mut self, item: &S) { + let item_len = item.serialized_len(); + assert!(item_len <= u16::MAX as usize); + if self.uncompressed_block.len() + item_len >= self.block_threshold { + // time to flush our current block. + self.flush_block(); + } + item.serialize(&mut self.uncompressed_block); + if self.uncompressed_block.len() >= self.block_threshold {} + } + + /// Flush the ongoing block as compressed or an uncompressed block (whichever is the smallest). + /// If the ongoing block is empty, this function is no op. + fn flush_block(&mut self) { + if self.uncompressed_block.is_empty() { + return; + } + let uncompressed_len = self.uncompressed_block.len(); + let uncompressed_len_u16 = + u16::try_from(uncompressed_len).expect("uncompressed block too big"); + self.compressed_block.resize(uncompressed_len, 0u8); + match zstd::bulk::compress_to_buffer( + &self.uncompressed_block, + &mut self.compressed_block[..], + 0, + ) { + Ok(compressed_len) => { + let compressed_len_u16 = u16::try_from(compressed_len).unwrap(); + let block_meta = BlockMeta::CompressedBlock { + len: compressed_len_u16, + }; + block_meta.serialize(&mut self.output); + self.output.extend(&self.compressed_block[..compressed_len]); + } + // The compressed version was actually longer than the decompressed one. + // Let's keep the block uncomopressed + Err(_) => { + let block_meta = BlockMeta::UncompressedBlock { + len: uncompressed_len_u16, + }; + block_meta.serialize(&mut self.output); + self.output.extend(&self.uncompressed_block); + } + } + self.uncompressed_block.clear(); + self.compressed_block.clear(); + } + + pub fn finish(mut self) -> Vec { + self.flush_block(); + BlockMeta::NoMoreBlocks.serialize(&mut self.output); + self.output + } +} + +pub fn deserialize_stream(buf: &mut &[u8]) -> anyhow::Result> { + let mut items: Vec = Vec::new(); + let mut decompression_buffer = vec![0; u16::MAX as usize]; + while !buf.is_empty() { + let block_meta = BlockMeta::deserialize(buf)?; + match block_meta { + BlockMeta::CompressedBlock { len } => { + let len = len as usize; + let compressed_block_bytes = &buf[..len]; + let uncompressed_len = zstd::bulk::decompress_to_buffer( + compressed_block_bytes, + &mut decompression_buffer[..u16::MAX as usize], + ) + .context("failed to decompress block")?; + buf.advance(len as usize); + let mut block_bytes = &decompression_buffer[..uncompressed_len]; + while !block_bytes.is_empty() { + let item = D::deserialize(&mut block_bytes)?; + items.push(item); + } + } + BlockMeta::UncompressedBlock { len } => { + let len = len as usize; + let mut block_bytes = &buf[..len]; + buf.advance(len as usize); + while !block_bytes.is_empty() { + let item = D::deserialize(&mut block_bytes)?; + items.push(item); + } + } + BlockMeta::NoMoreBlocks => { + return Ok(items); + } + }; + } + anyhow::bail!("compressed stream error: reached end of buffer without NoMoreBlock tag"); +} + +#[derive(Eq, PartialEq, Debug)] +enum BlockMeta { + CompressedBlock { len: u16 }, + UncompressedBlock { len: u16 }, + NoMoreBlocks, +} + +const NO_MORE_BLOCKS_TAG: u8 = 0u8; +const COMPRESSED_BLOCK_TAG: u8 = 1u8; +const UNCOMPRESSED_BLOCK_TAG: u8 = 2u8; + +impl Serializable for BlockMeta { + fn serialize(&self, buf: &mut Vec) { + match self { + BlockMeta::CompressedBlock { len } => { + COMPRESSED_BLOCK_TAG.serialize(buf); + len.serialize(buf); + } + BlockMeta::UncompressedBlock { len } => { + UNCOMPRESSED_BLOCK_TAG.serialize(buf); + len.serialize(buf); + } + BlockMeta::NoMoreBlocks => { + NO_MORE_BLOCKS_TAG.serialize(buf); + } + } + } fn serialized_len(&self) -> usize { - self.0.serialized_len() + match self { + BlockMeta::CompressedBlock { .. } | BlockMeta::UncompressedBlock { .. } => 3, + BlockMeta::NoMoreBlocks => 1, + } + } +} + +impl Deserializable for BlockMeta { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let tag = u8::deserialize(buf)?; + match tag { + UNCOMPRESSED_BLOCK_TAG => { + let len = u16::deserialize(buf)?; + Ok(BlockMeta::UncompressedBlock { len }) + } + COMPRESSED_BLOCK_TAG => { + let len = u16::deserialize(buf)?; + Ok(BlockMeta::CompressedBlock { len }) + } + NO_MORE_BLOCKS_TAG => Ok(BlockMeta::NoMoreBlocks), + _ => { + anyhow::bail!("Unknown block meta tag: {tag}") + } + } } } #[cfg(test)] #[track_caller] -pub fn test_serdeser_aux(obj: &T, num_bytes: usize) { +pub fn test_serdeser_aux( + obj: &T, + num_bytes: usize, +) { let mut buf = Vec::new(); obj.serialize(&mut buf); assert_eq!(buf.len(), obj.serialized_len()); @@ -260,6 +534,10 @@ pub fn test_serdeser_aux(obj: &T, #[cfg(test)] mod tests { + use proptest::proptest; + use rand::distributions::Alphanumeric; + use rand::Rng; + use super::*; #[test] @@ -296,4 +574,95 @@ mod tests { test_serdeser_aux(&Some(1), 9); test_serdeser_aux(&None, 1); } + + #[test] + fn test_serialize_block_meta() { + test_serdeser_aux(&BlockMeta::CompressedBlock { len: 10u16 }, 3); + test_serdeser_aux(&BlockMeta::UncompressedBlock { len: 18u16 }, 3); + test_serdeser_aux(&BlockMeta::NoMoreBlocks, 1); + } + + // An array of 10 small sentences for tests. + const TEXT_SAMPLES: [&str; 10] = [ + "I'm happy.", + "She exercises every morning.", + "His dog barks loudly.", + "My school starts at 8:00.", + "We always eat dinner together.", + "They take the bus to work.", + "He doesn't like vegetables.", + "I don't want anything to drink.", + "hello Happy tax payer", + "do you like tea?", + ]; + + #[test] + fn test_compressed_serialized_stream() { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(1_000); + let mut uncompressed_len = 0; + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + compressed_stream_writer.append(sentence); + uncompressed_len += sentence.len(); + } + let buf = compressed_stream_writer.finish(); + let mut cursor = &buf[..]; + assert!(buf.len() * 3 < uncompressed_len); + let vals: Vec = super::deserialize_stream(&mut cursor).unwrap(); + assert_eq!(vals.len(), 100); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + assert_eq!(&vals[i], sentence); + } + assert!(cursor.is_empty()); + } + + #[test] + fn test_compressed_serialized_stream_with_random_data() { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(200); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + compressed_stream_writer.append(sentence); + } + for _ in 0..30 { + let rng = rand::thread_rng(); + let random_sentence: String = rng + .sample_iter(&Alphanumeric) + .take(30) + .map(char::from) + .collect(); + compressed_stream_writer.append(random_sentence.as_str()); + } + let buf = compressed_stream_writer.finish(); + let mut cursor = &buf[..]; + let vals: Vec = deserialize_stream(&mut cursor).unwrap(); + assert_eq!(vals.len(), 130); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + assert_eq!(&vals[i], sentence); + } + assert!(cursor.is_empty()); + } + + proptest! { + #[test] + fn test_proptest_compressed_stream(payload in proptest::collection::vec(".{0,1000}", 1..100)) { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(1_000); + for s in &payload[..payload.len() - 1] { + compressed_stream_writer.append(s); + } + let len_upper_bound = compressed_stream_writer.serialized_len_upperbound_after(&payload[payload.len() - 1]); + compressed_stream_writer.append(&payload[payload.len() - 1]); + let buf = compressed_stream_writer.finish(); + let vals: Vec = deserialize_stream(&mut &buf[..]).unwrap(); + assert!(buf.len() <= len_upper_bound); + assert_eq!(vals.len(), payload.len()); + for (left, right) in vals.iter().zip(payload.iter()) { + assert_eq!(left, right); + } + } + } } diff --git a/chitchat/src/server.rs b/chitchat/src/server.rs index 51676db..4124607 100644 --- a/chitchat/src/server.rs +++ b/chitchat/src/server.rs @@ -601,7 +601,7 @@ mod tests { }; }; - let node_delta = &delta.node_deltas.get(&server_id).unwrap(); + let node_delta = delta.get(&server_id).unwrap(); let heartbeat = node_delta.heartbeat; assert_eq!(heartbeat, Heartbeat(3)); diff --git a/chitchat/src/state.rs b/chitchat/src/state.rs index 0b2be33..e1c05f0 100644 --- a/chitchat/src/state.rs +++ b/chitchat/src/state.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use tokio::sync::watch; use tracing::warn; -use crate::delta::{Delta, DeltaWriter}; +use crate::delta::{Delta, DeltaWriter, NodeDelta}; use crate::digest::{Digest, NodeDigest}; use crate::listener::Listeners; use crate::{ChitchatId, Heartbeat, KeyChangeEvent, Version, VersionedValue}; @@ -296,16 +296,21 @@ impl ClusterState { .retain(|chitchat_id, _| !delta.nodes_to_reset.contains(chitchat_id)); // Apply delta. - for (chitchat_id, node_delta) in delta.node_deltas { + for node_delta in delta.node_deltas { + let NodeDelta { + chitchat_id, + heartbeat, + key_values, + } = node_delta; let node_state = self .node_states .entry(chitchat_id.clone()) .or_insert_with(|| NodeState::new(chitchat_id, self.listeners.clone())); - if node_state.heartbeat < node_delta.heartbeat { - node_state.heartbeat = node_delta.heartbeat; + if node_state.heartbeat < heartbeat { + node_state.heartbeat = heartbeat; node_state.last_heartbeat = Instant::now(); } - for (key, versioned_value) in node_delta.key_values { + for (key, versioned_value) in key_values { node_state.max_version = node_state.max_version.max(versioned_value.version); node_state.set_versioned_value(key, versioned_value); } @@ -336,7 +341,7 @@ impl ClusterState { } /// Implements the Scuttlebutt reconciliation with the scuttle-depth ordering. - pub fn compute_delta( + pub fn compute_partial_delta_respecting_mtu( &self, digest: &Digest, mtu: usize, @@ -377,8 +382,7 @@ impl ClusterState { for (key, versioned_value) in stale_node.stale_key_values() { added_something = true; if !delta_writer.add_kv(key, versioned_value.clone()) { - let delta: Delta = delta_writer.into(); - return delta; + return delta_writer.finish(); } } if !added_something && nodes_to_reset.contains(&stale_node.chitchat_id) { @@ -393,12 +397,11 @@ impl ClusterState { tombstone: Some(0), }, ) { - let delta: Delta = delta_writer.into(); - return delta; + return delta_writer.finish(); } } } - delta_writer.into() + delta_writer.finish() } } @@ -945,12 +948,18 @@ mod tests { dead_nodes: &HashSet<&ChitchatId>, expected_delta_atoms: &[(&ChitchatId, &str, &str, Version, Option)], ) { - let max_delta = cluster_state.compute_delta(digest, usize::MAX, dead_nodes, 10_000); + let max_delta = cluster_state.compute_partial_delta_respecting_mtu( + digest, + usize::MAX, + dead_nodes, + 10_000, + ); let mut buf = Vec::new(); - max_delta.serialize(&mut buf); + Serializable::serialize(&max_delta, &mut buf); let mut mtu_per_num_entries = Vec::new(); - for mtu in 2..buf.len() { - let delta = cluster_state.compute_delta(digest, mtu, dead_nodes, 10_000); + for mtu in 100..buf.len() { + let delta = + cluster_state.compute_partial_delta_respecting_mtu(digest, mtu, dead_nodes, 10_000); let num_tuples = delta.num_tuples(); if mtu_per_num_entries.len() == num_tuples + 1 { continue; @@ -966,11 +975,17 @@ mod tests { expected_delta.add_kv(node, key, val, version, tombstone); } { - let delta = cluster_state.compute_delta(digest, mtu, dead_nodes, 10_000); + let delta = cluster_state + .compute_partial_delta_respecting_mtu(digest, mtu, dead_nodes, 10_000); assert_eq!(&delta, &expected_delta); } { - let delta = cluster_state.compute_delta(digest, mtu + 1, dead_nodes, 10_000); + let delta = cluster_state.compute_partial_delta_respecting_mtu( + digest, + mtu + 1, + dead_nodes, + 10_000, + ); assert_eq!(&delta, &expected_delta); } } @@ -1099,7 +1114,7 @@ mod tests { let node1 = ChitchatId::for_local_test(10_001); digest.add_node(node1.clone(), Heartbeat(0), 1); { - let delta = cluster_state.compute_delta( + let delta = cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE, &HashSet::new(), @@ -1117,7 +1132,7 @@ mod tests { // Node 1 heartbeat in digest + grace period (9_999) is inferior to the // node1's hearbeat in the cluster state. Thus we expect the cluster to compute a // delta that will reset node 1. - let delta = cluster_state.compute_delta( + let delta = cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE, &HashSet::new(), diff --git a/chitchat/src/transport/channel.rs b/chitchat/src/transport/channel.rs index 11ba515..137076c 100644 --- a/chitchat/src/transport/channel.rs +++ b/chitchat/src/transport/channel.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context}; use async_trait::async_trait; use tokio::sync::mpsc::{Receiver, Sender}; -use tracing::{debug, info}; +use tracing::info; use crate::serialize::Serializable; use crate::transport::{Socket, Transport}; @@ -98,7 +98,6 @@ impl ChannelTransport { bail!("Serialized message size exceeds MTU."); } } - debug!(num_bytes = num_bytes, "send"); let mut inner_lock = self.inner.lock().unwrap(); inner_lock.statistics.record_message_len(num_bytes); if let Some(to_addrs) = inner_lock.removed_links.get(&from_addr) { diff --git a/chitchat/src/transport/udp.rs b/chitchat/src/transport/udp.rs index 59d5e89..d1d8fb4 100644 --- a/chitchat/src/transport/udp.rs +++ b/chitchat/src/transport/udp.rs @@ -4,7 +4,7 @@ use anyhow::Context; use async_trait::async_trait; use tracing::warn; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; use crate::transport::{Socket, Transport}; use crate::{ChitchatMessage, MAX_UDP_DATAGRAM_PAYLOAD_SIZE}; diff --git a/chitchat/tests/cluster_test.rs b/chitchat/tests/cluster_test.rs index ccaf816..f493a8a 100644 --- a/chitchat/tests/cluster_test.rs +++ b/chitchat/tests/cluster_test.rs @@ -9,7 +9,7 @@ use chitchat::{ }; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; -use tracing::{debug, info}; +use tracing::{debug, error, info}; #[derive(Debug)] enum Operation { @@ -158,7 +158,7 @@ impl Simulator { } assert!(predicate_value); } else { - info!(node_id=%chitchat_id.node_id, state_snapshot=?chitchat_guard.state_snapshot(), "Node state missing."); + error!(node_id=%chitchat_id.node_id, state_snapshot=?chitchat_guard.state_snapshot(), "Node state missing."); panic!("Node state missing"); } }