From e9a1ea1b25a6838c38126171895abb50f0d2114b Mon Sep 17 00:00:00 2001
From: Paul Masurel
Date: Thu, 15 Feb 2024 12:25:38 +0900
Subject: [PATCH] Adding ZSTD compression to chitchat's Deltas. (#112)

The difficulty in this PR comes from the fact that we need to observe
the maximum transmission unit (MTU), while keeping the simplicity of
using the same `MessageDigest` object when sending and emitting
messages.

The solution I went for requires compressing things twice.
When building our message object, we rely on a stream compressor that
makes it possible to get an upper bound on the resulting payload size
after we insert an arbitrary serializable object. We then decompose the
delta into as many mutations as the MTU allows.

I had to refactor the delta object so that we have a proper bijection
between a delta object and the list of mutations that were used to
build it. The delta object hence encodes the order in which nodes and
values were inserted.

The point is to ensure that the actual serialization yields the same
size as the "simulated" one, and hence respects the MTU.
The code is a bit simpler that way too.

Closes #72
---
 chitchat/Cargo.toml               |  10 +-
 chitchat/src/delta.rs             | 715 +++++++++++++++++++-----------
 chitchat/src/digest.rs            |  21 +-
 chitchat/src/lib.rs               |  33 +-
 chitchat/src/message.rs           |  40 +-
 chitchat/src/serialize.rs         | 484 +++++++++++++++++---
 chitchat/src/server.rs            |   2 +-
 chitchat/src/state.rs             |  65 ++-
 chitchat/src/transport/channel.rs |   3 +-
 chitchat/src/transport/udp.rs     |   2 +-
 chitchat/tests/cluster_test.rs    |   4 +-
 11 files changed, 997 insertions(+), 382 deletions(-)

diff --git a/chitchat/Cargo.toml b/chitchat/Cargo.toml
index a3a79a8..f486241 100644
--- a/chitchat/Cargo.toml
+++ b/chitchat/Cargo.toml
@@ -16,14 +16,22 @@ bytes = "1"
 itertools = "0.12"
 rand = { version = "0.8", features = ["small_rng"] }
 serde = { version = "1", features = ["derive"] }
-tokio = { version = "1.28.0", features = ["net", "sync", "rt-multi-thread", "macros", "time"] }
+tokio = { version = "1.28.0", features = [
+    "net",
+    "sync",
+    "rt-multi-thread",
+    "macros",
+    "time",
+] }
 tokio-stream = { version = "0.1", features = ["sync"] }
 tracing = "0.1"
+zstd = "0.13"
 
 [dev-dependencies]
 assert-json-diff = "2"
 mock_instant = "0.3"
 tracing-subscriber = "0.3"
+proptest = "1.4"
 
 [features]
 testsuite = []
diff --git a/chitchat/src/delta.rs b/chitchat/src/delta.rs
index 34e72a6..e0f3234 100644
--- a/chitchat/src/delta.rs
+++ b/chitchat/src/delta.rs
@@ -1,81 +1,267 @@
-use std::collections::{BTreeMap, HashSet};
-use std::mem;
+use std::collections::HashSet;
 
 use crate::serialize::*;
 use crate::{ChitchatId, Heartbeat, VersionedValue};
 
-#[derive(Debug, Default, Eq, PartialEq)]
+/// A delta is the message we send to another node to update it.
+///
+/// Its serialization is done by transforming it into a sequence of operations,
+/// encoded one after the other in a compressed stream.
+#[derive(Debug, Eq, PartialEq)] pub struct Delta { - pub(crate) node_deltas: BTreeMap, - pub(crate) nodes_to_reset: HashSet, + pub(crate) nodes_to_reset: Vec, + pub(crate) node_deltas: Vec, + serialized_len: usize, } -impl Serializable for Delta { - fn serialize(&self, buf: &mut Vec) { - (self.node_deltas.len() as u16).serialize(buf); - for (chitchat_id, node_delta) in &self.node_deltas { - chitchat_id.serialize(buf); - node_delta.serialize(buf); +impl Default for Delta { + fn default() -> Self { + Delta { + nodes_to_reset: Vec::new(), + node_deltas: Vec::new(), + serialized_len: 1, } - (self.nodes_to_reset.len() as u16).serialize(buf); - for chitchat_id in &self.nodes_to_reset { - chitchat_id.serialize(buf); + } +} + +impl Delta { + fn get_operations(&self) -> impl Iterator> { + let nodes_to_reset_ops = self.nodes_to_reset.iter().map(DeltaOpRef::NodeToReset); + let node_deltas = self.node_deltas.iter().flat_map(|node_delta| { + std::iter::once(DeltaOpRef::Node { + chitchat_id: &node_delta.chitchat_id, + heartbeat: node_delta.heartbeat, + }) + .chain(node_delta.key_values.iter().map(|(key, versioned_value)| { + DeltaOpRef::KeyValue { + key, + versioned_value, + } + })) + }); + nodes_to_reset_ops.chain(node_deltas) + } +} + +enum DeltaOp { + NodeToReset(ChitchatId), + Node { + chitchat_id: ChitchatId, + heartbeat: Heartbeat, + }, + KeyValue { + key: String, + versioned_value: VersionedValue, + }, +} + +enum DeltaOpRef<'a> { + NodeToReset(&'a ChitchatId), + Node { + chitchat_id: &'a ChitchatId, + heartbeat: Heartbeat, + }, + KeyValue { + key: &'a str, + versioned_value: &'a VersionedValue, + }, +} + +#[repr(u8)] +enum DeltaOpTag { + Node = 0u8, + KeyValue = 1u8, + NodeToReset = 2u8, +} + +impl TryFrom for DeltaOpTag { + type Error = anyhow::Error; + + fn try_from(tag_byte: u8) -> anyhow::Result { + match tag_byte { + 0u8 => Ok(DeltaOpTag::Node), + 1u8 => Ok(DeltaOpTag::KeyValue), + 2u8 => Ok(DeltaOpTag::NodeToReset), + _ => { + anyhow::bail!("Unknown tag: {tag_byte}") + } } } +} + +impl From for u8 { + fn from(tag: DeltaOpTag) -> u8 { + tag as u8 + } +} +impl Deserializable for DeltaOp { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let mut node_deltas: BTreeMap = Default::default(); - let num_nodes = u16::deserialize(buf)?; - for _ in 0..num_nodes { - let chitchat_id = ChitchatId::deserialize(buf)?; - let node_delta = NodeDelta::deserialize(buf)?; - node_deltas.insert(chitchat_id, node_delta); + let tag_bytes: [u8; 1] = Deserializable::deserialize(buf)?; + let tag = DeltaOpTag::try_from(tag_bytes[0])?; + match tag { + DeltaOpTag::NodeToReset => { + let chitchat_id = ChitchatId::deserialize(buf)?; + Ok(DeltaOp::NodeToReset(chitchat_id)) + } + DeltaOpTag::Node => { + let chitchat_id = ChitchatId::deserialize(buf)?; + let heartbeat = Heartbeat::deserialize(buf)?; + Ok(DeltaOp::Node { + chitchat_id, + heartbeat, + }) + } + DeltaOpTag::KeyValue => { + let key = String::deserialize(buf)?; + let value = String::deserialize(buf)?; + let version = u64::deserialize(buf)?; + let tombstone = Option::::deserialize(buf)?; + let versioned_value: VersionedValue = VersionedValue { + value, + version, + tombstone, + }; + Ok(DeltaOp::KeyValue { + key, + versioned_value, + }) + } + } + } +} + +impl DeltaOp { + fn as_ref(&self) -> DeltaOpRef { + match self { + DeltaOp::Node { + chitchat_id, + heartbeat, + } => DeltaOpRef::Node { + chitchat_id, + heartbeat: *heartbeat, + }, + DeltaOp::KeyValue { + key, + versioned_value, + } => DeltaOpRef::KeyValue { + key, + versioned_value, + }, + 
DeltaOp::NodeToReset(node_to_reset) => DeltaOpRef::NodeToReset(node_to_reset), } - let num_nodes_to_reset = u16::deserialize(buf)?; - let mut nodes_to_reset = HashSet::with_capacity(num_nodes_to_reset as usize); - for _ in 0..num_nodes_to_reset { - let chitchat_id = ChitchatId::deserialize(buf)?; - nodes_to_reset.insert(chitchat_id); + } +} + +impl Serializable for DeltaOp { + fn serialize(&self, buf: &mut Vec) { + self.as_ref().serialize(buf) + } + + fn serialized_len(&self) -> usize { + self.as_ref().serialized_len() + } +} + +impl<'a> Serializable for DeltaOpRef<'a> { + fn serialize(&self, buf: &mut Vec) { + match self { + Self::Node { + chitchat_id, + heartbeat, + } => { + buf.push(DeltaOpTag::Node.into()); + chitchat_id.serialize(buf); + heartbeat.serialize(buf); + } + Self::KeyValue { + key, + versioned_value, + } => { + buf.push(DeltaOpTag::KeyValue.into()); + key.serialize(buf); + versioned_value.value.serialize(buf); + versioned_value.version.serialize(buf); + versioned_value.tombstone.serialize(buf); + } + Self::NodeToReset(chitchat_id) => { + buf.push(DeltaOpTag::NodeToReset.into()); + chitchat_id.serialize(buf); + } } - Ok(Delta { - node_deltas, - nodes_to_reset, - }) } fn serialized_len(&self) -> usize { - let mut len = 2; - for (chitchat_id, node_delta) in &self.node_deltas { - len += chitchat_id.serialized_len(); - len += node_delta.serialized_len(); + 1 + match self { + Self::Node { + chitchat_id, + heartbeat, + } => chitchat_id.serialized_len() + heartbeat.serialized_len(), + Self::KeyValue { + key, + versioned_value, + } => { + key.serialized_len() + + versioned_value.value.serialized_len() + + versioned_value.version.serialized_len() + + versioned_value.tombstone.serialized_len() + } + Self::NodeToReset(chitchat_id) => chitchat_id.serialized_len(), } - len += 2; - for chitchat_id in &self.nodes_to_reset { - len += chitchat_id.serialized_len(); + } +} + +impl Serializable for Delta { + fn serialize(&self, buf: &mut Vec) { + let mut compressed_stream_writer = CompressedStreamWriter::with_block_threshold(16_384); + for op in self.get_operations() { + compressed_stream_writer.append(&op); } - len + let payload = compressed_stream_writer.finish(); + assert_eq!(payload.len(), self.serialized_len); + buf.extend(&payload); + } + + fn serialized_len(&self) -> usize { + self.serialized_len + } +} + +impl Deserializable for Delta { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let original_len = buf.len(); + let ops: Vec = crate::serialize::deserialize_stream(buf)?; + let consumed_len = original_len - buf.len(); + let mut delta_builder = DeltaBuilder::default(); + for op in ops { + delta_builder.apply_op(op)?; + } + Ok(delta_builder.finish(consumed_len)) } } #[cfg(test)] impl Delta { - pub fn num_tuples(&self) -> usize { + pub(crate) fn num_tuples(&self) -> usize { self.node_deltas - .values() + .iter() .map(|node_delta| node_delta.num_tuples()) .sum() } - pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) { - self.node_deltas - .entry(chitchat_id) - .or_insert_with(|| NodeDelta { - heartbeat, - ..Default::default() - }); + pub(crate) fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) { + assert!(!self + .node_deltas + .iter() + .any(|node_delta| { node_delta.chitchat_id == chitchat_id })); + self.node_deltas.push(NodeDelta { + chitchat_id, + heartbeat, + key_values: Vec::new(), + }); } - pub fn add_kv( + pub(crate) fn add_kv( &mut self, chitchat_id: &ChitchatId, key: &str, @@ -83,26 +269,41 @@ impl Delta { version: 
crate::Version, tombstone: Option, ) { - let node_delta = self.node_deltas.get_mut(chitchat_id).unwrap(); - node_delta.key_values.insert( + let node_delta = self + .node_deltas + .iter_mut() + .find(|node_delta| &node_delta.chitchat_id == chitchat_id) + .unwrap(); + node_delta.key_values.push(( key.to_string(), VersionedValue { value: value.to_string(), version, tombstone, }, - ); + )); + } + + pub(crate) fn set_serialized_len(&mut self, serialized_len: usize) { + self.serialized_len = serialized_len; + } + + pub(crate) fn get(&self, chitchat_id: &ChitchatId) -> Option<&NodeDelta> { + self.node_deltas + .iter() + .find(|node_delta| &node_delta.chitchat_id == chitchat_id) } - pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) { - self.nodes_to_reset.insert(chitchat_id); + pub(crate) fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) { + self.nodes_to_reset.push(chitchat_id); } } -#[derive(Debug, Default, Eq, PartialEq, serde::Serialize)] +#[derive(Debug, Eq, PartialEq, serde::Serialize)] pub(crate) struct NodeDelta { + pub chitchat_id: ChitchatId, pub heartbeat: Heartbeat, - pub key_values: BTreeMap, + pub key_values: Vec<(String, VersionedValue)>, } #[cfg(test)] @@ -112,164 +313,141 @@ impl NodeDelta { } } -pub struct DeltaWriter { +#[derive(Default)] +struct DeltaBuilder { + existing_nodes: HashSet, delta: Delta, - mtu: usize, - num_bytes: usize, - current_chitchat_id: Option, - current_node_delta: NodeDelta, - reached_capacity: bool, + current_node_delta: Option, } -impl DeltaWriter { - pub fn with_mtu(mtu: usize) -> Self { - DeltaWriter { - delta: Delta::default(), - mtu, - num_bytes: 2 + 2, /* 2 bytes for `nodes_to_reset.len()` + 2 bytes for - * `node_deltas.len()` */ - current_chitchat_id: None, - current_node_delta: NodeDelta::default(), - reached_capacity: false, - } +impl DeltaBuilder { + fn finish(mut self, len: usize) -> Delta { + self.flush(); + self.delta.serialized_len = len; + self.delta } - fn flush(&mut self) { - let chitchat_id_opt = mem::take(&mut self.current_chitchat_id); - let node_delta = mem::take(&mut self.current_node_delta); - if let Some(chitchat_id) = chitchat_id_opt { - self.delta.node_deltas.insert(chitchat_id, node_delta); + + fn apply_op(&mut self, op: DeltaOp) -> anyhow::Result<()> { + match op { + DeltaOp::Node { + chitchat_id, + heartbeat, + } => { + self.flush(); + anyhow::ensure!(!self.existing_nodes.contains(&chitchat_id)); + self.existing_nodes.insert(chitchat_id.clone()); + self.current_node_delta = Some(NodeDelta { + chitchat_id, + heartbeat, + key_values: Vec::new(), + }); + } + DeltaOp::KeyValue { + key, + versioned_value, + } => { + let Some(current_node_delta) = self.current_node_delta.as_mut() else { + anyhow::bail!("received a key-value op without a node op before."); + }; + if let Some((_last_key, last_versioned_value)) = + current_node_delta.key_values.last() + { + anyhow::ensure!( + last_versioned_value.version < versioned_value.version, + "kv version should be increasing" + ); + } + current_node_delta + .key_values + .push((key.to_string(), versioned_value)); + } + DeltaOp::NodeToReset(chitchat_id) => { + anyhow::ensure!( + self.delta.node_deltas.is_empty(), + "nodes_to_reset should be encoded before node_deltas" + ); + self.delta.nodes_to_reset.push(chitchat_id); + } } + Ok(()) } - pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) -> bool { - assert!(!self.delta.nodes_to_reset.contains(&chitchat_id)); - if !self.attempt_add_bytes(chitchat_id.serialized_len()) { - return false; - } - 
self.delta.nodes_to_reset.insert(chitchat_id); - true + fn flush(&mut self) { + let Some(node_delta) = self.current_node_delta.take() else { + // There are no nodes in the builder. + // (this happens when the delta builder is freshly created and no ops have been received + // yet.) + return; + }; + self.delta.node_deltas.push(node_delta); } +} - pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) -> bool { - assert!(self.current_chitchat_id.as_ref() != Some(&chitchat_id)); - assert!(!self.delta.node_deltas.contains_key(&chitchat_id)); - self.flush(); - // Reserve bytes for [`ChitchatId`], [`Hearbeat`], and for an empty [`NodeDelta`] which has - // a size of 2 bytes. - if !self.attempt_add_bytes(chitchat_id.serialized_len() + heartbeat.serialized_len() + 2) { - return false; +/// The delta serializer is just helping us with the task of serializing +/// part of a delta, while respecting a given `mtu`. +/// +/// We do it by calling `try_add_node_reset`, `try_add_node`, and `try_add_kv` +/// and stopping as soon as one of this methods returns `false`. +pub struct DeltaSerializer { + mtu: usize, + delta_builder: DeltaBuilder, + compressed_stream_writer: CompressedStreamWriter, +} + +const BLOCK_THRESHOLD: u16 = 16_384u16; + +impl DeltaSerializer { + pub fn with_mtu(mtu: usize) -> Self { + assert!(mtu >= 100); + let block_threshold = u16::try_from((BLOCK_THRESHOLD as usize).min(mtu)).unwrap(); + DeltaSerializer { + mtu, + delta_builder: DeltaBuilder::default(), + compressed_stream_writer: CompressedStreamWriter::with_block_threshold(block_threshold), } - self.current_chitchat_id = Some(chitchat_id); - self.current_node_delta.heartbeat = heartbeat; - true } - fn attempt_add_bytes(&mut self, num_bytes: usize) -> bool { - assert!(!self.reached_capacity); - let new_num_bytes = self.num_bytes + num_bytes; - if new_num_bytes > self.mtu { - self.reached_capacity = true; - return false; - } - self.num_bytes = new_num_bytes; - true + /// Returns false if the node to reset could not be added because the payload would exceed the + /// mtu. + pub fn try_add_node_to_reset(&mut self, chitchat_id: ChitchatId) -> bool { + let delta_op = DeltaOp::NodeToReset(chitchat_id); + self.try_add_op(delta_op) } - /// Returns false if the KV could not be added because mtu was reached. - pub fn add_kv(&mut self, key: &str, versioned_value: VersionedValue) -> bool { - assert!(!self.current_node_delta.key_values.contains_key(key)); - // Reserve bytes for the key (2 bytes are used to store the key length) and versioned value. 
- if !self.attempt_add_bytes( - 2 + key.len() - + versioned_value.value.serialized_len() - + versioned_value.version.serialized_len() - + versioned_value.tombstone.serialized_len(), - ) { + fn try_add_op(&mut self, delta_op: DeltaOp) -> bool { + if self + .compressed_stream_writer + .serialized_len_upperbound_after(&delta_op) + > self.mtu + { return false; } - self.current_node_delta - .key_values - .insert(key.to_string(), versioned_value); + self.compressed_stream_writer.append(&delta_op); + assert!(self.delta_builder.apply_op(delta_op).is_ok()); true } -} - -impl From for Delta { - fn from(mut delta_writer: DeltaWriter) -> Delta { - delta_writer.flush(); - if cfg!(debug_assertions) { - let mut buf = Vec::new(); - assert_eq!(delta_writer.num_bytes, delta_writer.delta.serialized_len()); - delta_writer.delta.serialize(&mut buf); - assert_eq!(buf.len(), delta_writer.num_bytes); - } - delta_writer.delta - } -} -impl Serializable for NodeDelta { - fn serialize(&self, buf: &mut Vec) { - self.heartbeat.serialize(buf); - (self.key_values.len() as u16).serialize(buf); - for ( - key, - VersionedValue { - value, - version, - tombstone, - }, - ) in &self.key_values - { - key.serialize(buf); - value.serialize(buf); - version.serialize(buf); - tombstone.serialize(buf); - } + /// Returns false if the KV could not be added because the payload would exceed the mtu. + pub fn try_add_kv(&mut self, key: &str, versioned_value: VersionedValue) -> bool { + let key_value_op = DeltaOp::KeyValue { + key: key.to_string(), + versioned_value, + }; + self.try_add_op(key_value_op) } - fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let heartbeat = Heartbeat::deserialize(buf)?; - let mut key_values: BTreeMap = Default::default(); - let num_key_values = u16::deserialize(buf)?; - for _ in 0..num_key_values { - let key = String::deserialize(buf)?; - let value = String::deserialize(buf)?; - let version = u64::deserialize(buf)?; - let tombstone = >::deserialize(buf)?; - key_values.insert( - key, - VersionedValue { - value, - version, - tombstone, - }, - ); - } - Ok(Self { + /// Returns false if the node could not be added because the payload would exceed the mtu. + pub fn try_add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) -> bool { + let new_node_op = DeltaOp::Node { + chitchat_id, heartbeat, - key_values, - }) + }; + self.try_add_op(new_node_op) } - fn serialized_len(&self) -> usize { - let mut len = 2; - len += self.heartbeat.serialized_len(); - - for ( - key, - VersionedValue { - value, - version, - tombstone, - }, - ) in &self.key_values - { - len += key.serialized_len(); - len += value.serialized_len(); - len += version.serialized_len(); - len += tombstone.serialized_len(); - } - len + pub fn finish(self) -> Delta { + let buffer = self.compressed_stream_writer.finish(); + self.delta_builder.finish(buffer.len()) } } @@ -279,23 +457,23 @@ mod tests { #[test] fn test_delta_serialization_default() { - test_serdeser_aux(&Delta::default(), 4); + test_serdeser_aux(&Delta::default(), 1); } #[test] fn test_delta_serialization_simple_foo() { // 4 bytes - let mut delta_writer = DeltaWriter::with_mtu(198); + let mut delta_writer = DeltaSerializer::with_mtu(198); // ChitchatId takes 27 bytes = 15 bytes + 2 bytes for node length + "node-10001".len(). let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (node). 
- assert!(delta_writer.add_node(node1, heartbeat)); + assert!(delta_writer.try_add_node(node1, heartbeat)); // +23 bytes: 2 bytes (key length) + 5 bytes (key) + 7 bytes (values) + 8 bytes (version) + // 1 bytes (empty tombstone). - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -305,7 +483,7 @@ mod tests { )); // +26 bytes: 2 bytes (key length) + 5 bytes (key) + 8 bytes (version) + // 9 bytes (empty tombstone). - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key12", VersionedValue { value: "".to_string(), @@ -317,10 +495,10 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); let heartbeat = Heartbeat(0); // +37 bytes - assert!(delta_writer.add_node(node2, heartbeat)); + assert!(delta_writer.try_add_node(node2, heartbeat)); // +23 bytes. - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key21", VersionedValue { value: "val21".to_string(), @@ -329,7 +507,7 @@ mod tests { }, )); // +23 bytes. - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key22", VersionedValue { value: "val22".to_string(), @@ -337,23 +515,22 @@ mod tests { tombstone: None, }, )); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 173); + test_aux_delta_writer(delta_writer, 99); } #[test] fn test_delta_serialization_simple_node() { - // 4 bytes - let mut delta_writer = DeltaWriter::with_mtu(124); + // 1 bytes (End tag) + let mut delta_writer = DeltaSerializer::with_mtu(128); // ChitchatId takes 27 bytes = 15 bytes + 2 bytes for node length + "node-10001".len(). let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (node). - assert!(delta_writer.add_node(node1, heartbeat)); + // +37 bytes = 8 bytes (heartbeat) + 27 bytes (node) + 2bytes (block length) + assert!(delta_writer.try_add_node(node1, heartbeat)); - // +23 bytes. - assert!(delta_writer.add_kv( + // +24 bytes (kv + op tag) + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -361,8 +538,9 @@ mod tests { tombstone: None, } )); - // +23 bytes. - assert!(delta_writer.add_kv( + + // +24 bytes. (kv + op tag) + assert!(delta_writer.try_add_kv( "key12", VersionedValue { value: "val12".to_string(), @@ -374,26 +552,35 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); let heartbeat = Heartbeat(0); // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (node). - assert!(delta_writer.add_node(node2, heartbeat)); + assert!(delta_writer.try_add_node(node2, heartbeat)); + test_aux_delta_writer(delta_writer, 80); + } - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 124); + #[track_caller] + fn test_aux_delta_writer(delta_writer: DeltaSerializer, expected_len: usize) { + let delta: Delta = delta_writer.finish(); + test_serdeser_aux(&delta, expected_len) } #[test] fn test_delta_serialization_simple_with_nodes_to_reset() { - // 4 bytes - let mut delta_writer = DeltaWriter::with_mtu(151); - // +27 bytes (ChitchatId). 
- assert!(delta_writer.add_node_to_reset(ChitchatId::for_local_test(10_000))); + // 1 bytes (end tag) + let mut delta_writer = DeltaSerializer::with_mtu(155); + + // +27 bytes (ChitchatId) + 1 (op tag) + 3 bytes (block len) + // = 32 bytes + assert!(delta_writer.try_add_node_to_reset(ChitchatId::for_local_test(10_000))); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). - assert!(delta_writer.add_node(node1, heartbeat)); - // +23 bytes. - assert!(delta_writer.add_kv( + // +8 bytes (heartbeat) + 27 bytes (ChitchatId) + (1 op tag) + 3 bytes (pessimistic new + // block) = 71 + assert!(delta_writer.try_add_node(node1, heartbeat)); + + // +23 bytes (kv) + 1 (op tag) + // = 95 + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -401,8 +588,9 @@ mod tests { tombstone: None, } )); - // +23 bytes. - assert!(delta_writer.add_kv( + // +23 bytes (kv) + 1 (op tag) + // = 119 + assert!(delta_writer.try_add_kv( "key12", VersionedValue { value: "val12".to_string(), @@ -413,25 +601,25 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); let heartbeat = Heartbeat(0); - // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). - assert!(delta_writer.add_node(node2, heartbeat)); - - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 151); + // +8 bytes (heartbeat) + 27 bytes (ChitchatId) + 1 byte (op tag) + // = 155 + assert!(delta_writer.try_add_node(node2, heartbeat)); + // The block got compressed. + test_aux_delta_writer(delta_writer, 85); } #[test] fn test_delta_serialization_exceed_mtu_on_add_node() { // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(87); + let mut delta_writer = DeltaSerializer::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). - assert!(delta_writer.add_node(node1, heartbeat)); + assert!(delta_writer.try_add_node(node1, heartbeat)); // +23 bytes. - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -440,7 +628,7 @@ mod tests { } )); // +23 bytes. - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key12", VersionedValue { value: "val12".to_string(), @@ -452,24 +640,24 @@ mod tests { let node2 = ChitchatId::for_local_test(10_002); let heartbeat = Heartbeat(0); // +37 bytes = 8 bytes (heartbeat) + 2 bytes (empty node delta) + 27 bytes (ChitchatId). - assert!(!delta_writer.add_node(node2, heartbeat)); + assert!(!delta_writer.try_add_node(node2, heartbeat)); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 87); + // The block got compressed. + test_aux_delta_writer(delta_writer, 72); } #[test] fn test_delta_serialization_exceed_mtu_on_add_node_to_reset() { // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(90); + let mut delta_writer = DeltaSerializer::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); // +37 bytes. - assert!(delta_writer.add_node(node1, heartbeat)); + assert!(delta_writer.try_add_node(node1, heartbeat)); // +23 bytes. - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -478,7 +666,7 @@ mod tests { } )); // +23 bytes. 
- assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key12", VersionedValue { value: "val12".to_string(), @@ -488,23 +676,27 @@ mod tests { )); let node2 = ChitchatId::for_local_test(10_002); - assert!(!delta_writer.add_node_to_reset(node2)); + assert!(!delta_writer.try_add_node_to_reset(node2)); - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 87); + // The block got compressed. + test_aux_delta_writer(delta_writer, 72); } #[test] fn test_delta_serialization_exceed_mtu_on_add_kv() { - // 4 bytes. - let mut delta_writer = DeltaWriter::with_mtu(86); + // 1 bytes. + let mut delta_writer = DeltaSerializer::with_mtu(100); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - // +37 bytes. - assert!(delta_writer.add_node(node1, heartbeat)); - // +23 bytes. - assert!(delta_writer.add_kv( + + // + 3 bytes (block tag) + 35 bytes (node) + 1 byte (op tag) + // = 40 + assert!(delta_writer.try_add_node(node1, heartbeat)); + + // +23 bytes (kv) + 1 (op tag) + 3 bytes (pessimistic block tag) + // = 67 + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -512,30 +704,30 @@ mod tests { tombstone: None, } )); - // +23 bytes. - assert!(!delta_writer.add_kv( + + // +33 bytes (kv) + 1 (op tag) + // = 101 (exceeding mtu!) + assert!(!delta_writer.try_add_kv( "key12", VersionedValue { - value: "val12".to_string(), + value: "val12aaaaaaaaaabcc".to_string(), version: 2, tombstone: None, } )); - - let delta: Delta = delta_writer.into(); - test_serdeser_aux(&delta, 64); + test_aux_delta_writer(delta_writer, 64); } #[test] #[should_panic] fn test_delta_serialization_panic_if_add_after_exceed() { - let mut delta_writer = DeltaWriter::with_mtu(62); + let mut delta_writer = DeltaSerializer::with_mtu(62); let node1 = ChitchatId::for_local_test(10_001); let heartbeat = Heartbeat(0); - assert!(delta_writer.add_node(node1, heartbeat)); + assert!(delta_writer.try_add_node(node1, heartbeat)); - assert!(delta_writer.add_kv( + assert!(delta_writer.try_add_kv( "key11", VersionedValue { value: "val11".to_string(), @@ -543,7 +735,7 @@ mod tests { tombstone: None, } )); - assert!(!delta_writer.add_kv( + assert!(!delta_writer.try_add_kv( "key12", VersionedValue { value: "val12".to_string(), @@ -551,7 +743,7 @@ mod tests { tombstone: None, } )); - delta_writer.add_kv( + delta_writer.try_add_kv( "key13", VersionedValue { value: "val12".to_string(), @@ -560,4 +752,17 @@ mod tests { }, ); } + + #[test] + fn test_delta_op_tag() { + let mut num_valid_tags = 0; + for b in 0..=u8::MAX { + if let Ok(tag) = DeltaOpTag::try_from(b) { + let tag_byte: u8 = tag.into(); + assert_eq!(b, tag_byte); + num_valid_tags += 1; + } + } + assert_eq!(num_valid_tags, 3); + } } diff --git a/chitchat/src/digest.rs b/chitchat/src/digest.rs index 9e0252b..6cb50e9 100644 --- a/chitchat/src/digest.rs +++ b/chitchat/src/digest.rs @@ -45,7 +45,18 @@ impl Serializable for Digest { node_digest.max_version.serialize(buf); } } + fn serialized_len(&self) -> usize { + let mut len = (self.node_digests.len() as u16).serialized_len(); + for (chitchat_id, node_digest) in &self.node_digests { + len += chitchat_id.serialized_len(); + len += node_digest.heartbeat.serialized_len(); + len += node_digest.max_version.serialized_len(); + } + len + } +} +impl Deserializable for Digest { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let num_nodes = u16::deserialize(buf)?; let mut node_digests: BTreeMap = Default::default(); @@ -59,14 +70,4 @@ impl Serializable for Digest { } 
Ok(Digest { node_digests }) } - - fn serialized_len(&self) -> usize { - let mut len = (self.node_digests.len() as u16).serialized_len(); - for (chitchat_id, node_digest) in &self.node_digests { - len += chitchat_id.serialized_len(); - len += node_digest.heartbeat.serialized_len(); - len += node_digest.max_version.serialized_len(); - } - len - } } diff --git a/chitchat/src/lib.rs b/chitchat/src/lib.rs index e901a85..6249dbd 100644 --- a/chitchat/src/lib.rs +++ b/chitchat/src/lib.rs @@ -29,7 +29,6 @@ use tracing::{error, warn}; pub use self::configuration::ChitchatConfig; pub use self::state::{ClusterStateSnapshot, NodeState}; use crate::digest::Digest; -use crate::message::syn_ack_serialized_len; pub use crate::message::ChitchatMessage; pub use crate::server::{spawn_chitchat, ChitchatHandle}; use crate::state::ClusterState; @@ -95,6 +94,14 @@ impl Chitchat { } } + fn process_delta(&mut self, delta: Delta) { + // Warning: order matters here. + // `report_heartbeats` will compare the current known heartbeat with the one + // in the delta, while `apply_delta` is actually updating this heartbeat. + self.report_heartbeats(&delta); + self.cluster_state.apply_delta(delta); + } + pub(crate) fn process_message(&mut self, msg: ChitchatMessage) -> Option { match msg { ChitchatMessage::Syn { cluster_id, digest } => { @@ -110,10 +117,8 @@ impl Chitchat { let scheduled_for_deletion: HashSet<_> = self.scheduled_for_deletion_nodes().collect(); let self_digest = self.compute_digest(&scheduled_for_deletion); - let empty_delta = Delta::default(); - let delta_mtu = MAX_UDP_DATAGRAM_PAYLOAD_SIZE - - syn_ack_serialized_len(&self_digest, &empty_delta); - let delta = self.cluster_state.compute_delta( + let delta_mtu = MAX_UDP_DATAGRAM_PAYLOAD_SIZE - 1 - digest.serialized_len(); + let delta = self.cluster_state.compute_partial_delta_respecting_mtu( &digest, delta_mtu, &scheduled_for_deletion, @@ -129,7 +134,7 @@ impl Chitchat { self.cluster_state.apply_delta(delta); let scheduled_for_deletion = self.scheduled_for_deletion_nodes().collect::>(); - let delta = self.cluster_state.compute_delta( + let delta = self.cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE - 1, &scheduled_for_deletion, @@ -138,8 +143,7 @@ impl Chitchat { Some(ChitchatMessage::Ack { delta }) } ChitchatMessage::Ack { delta } => { - self.report_heartbeats(&delta); - self.cluster_state.apply_delta(delta); + self.process_delta(delta); None } ChitchatMessage::BadCluster => { @@ -160,14 +164,17 @@ impl Chitchat { /// Reports heartbeats to the failure detector for nodes in the delta for which we received an /// update. 
fn report_heartbeats(&mut self, delta: &Delta) { - for (chitchat_id, node_delta) in &delta.node_deltas { - if let Some(node_state) = self.cluster_state.node_states.get(chitchat_id) { + for node_delta in &delta.node_deltas { + if let Some(node_state) = self.cluster_state.node_states.get(&node_delta.chitchat_id) { if node_state.heartbeat() < node_delta.heartbeat { - self.failure_detector.report_heartbeat(chitchat_id); + self.failure_detector + .report_heartbeat(&node_delta.chitchat_id); } } else { - self.failure_detector.report_unknown(chitchat_id); - self.failure_detector.update_node_liveness(chitchat_id); + self.failure_detector + .report_unknown(&node_delta.chitchat_id); + self.failure_detector + .update_node_liveness(&node_delta.chitchat_id); } } } diff --git a/chitchat/src/message.rs b/chitchat/src/message.rs index 863b040..92e3635 100644 --- a/chitchat/src/message.rs +++ b/chitchat/src/message.rs @@ -4,7 +4,7 @@ use anyhow::Context; use crate::delta::Delta; use crate::digest::Digest; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; /// Chitchat message. /// @@ -73,6 +73,21 @@ impl Serializable for ChitchatMessage { } } + fn serialized_len(&self) -> usize { + match self { + ChitchatMessage::Syn { cluster_id, digest } => { + 1 + cluster_id.serialized_len() + digest.serialized_len() + } + ChitchatMessage::SynAck { digest, delta } => { + 1 + digest.serialized_len() + delta.serialized_len() + } + ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), + ChitchatMessage::BadCluster => 1, + } + } +} + +impl Deserializable for ChitchatMessage { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let code = buf .first() @@ -98,21 +113,6 @@ impl Serializable for ChitchatMessage { MessageType::BadCluster => Ok(Self::BadCluster), } } - - fn serialized_len(&self) -> usize { - match self { - ChitchatMessage::Syn { cluster_id, digest } => { - 1 + cluster_id.serialized_len() + digest.serialized_len() - } - ChitchatMessage::SynAck { digest, delta } => syn_ack_serialized_len(digest, delta), - ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), - ChitchatMessage::BadCluster => 1, - } - } -} - -pub(crate) fn syn_ack_serialized_len(digest: &Digest, delta: &Delta) -> usize { - 1 + digest.serialized_len() + delta.serialized_len() } #[cfg(test)] @@ -149,7 +149,8 @@ mod tests { digest: Digest::default(), delta: Delta::default(), }; - test_serdeser_aux(&syn_ack, 7); + // 1 (message tag) + 2 (digest len) + 1 (delta end op) + test_serdeser_aux(&syn_ack, 4); } { // 2 bytes. @@ -165,6 +166,7 @@ mod tests { delta.add_node(node.clone(), Heartbeat(0)); // +29 bytes. delta.add_kv(&node, "key", "value", 0, Some(5)); + delta.set_serialized_len(70); let syn_ack = ChitchatMessage::SynAck { digest, delta }; // 1 bytes (syn ack message) + 45 bytes (digest) + 69 bytes (delta). @@ -177,7 +179,7 @@ mod tests { { let delta = Delta::default(); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 5); + test_serdeser_aux(&ack, 2); } { // 4 bytes. @@ -187,7 +189,7 @@ mod tests { delta.add_node(node.clone(), Heartbeat(0)); // +29 bytes. 
delta.add_kv(&node, "key", "value", 0, Some(5)); - + delta.set_serialized_len(70); let ack = ChitchatMessage::Ack { delta }; test_serdeser_aux(&ack, 71); } diff --git a/chitchat/src/serialize.rs b/chitchat/src/serialize.rs index 5536383..4de64c2 100644 --- a/chitchat/src/serialize.rs +++ b/chitchat/src/serialize.rs @@ -1,7 +1,8 @@ use std::io::BufRead; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use anyhow::bail; +use anyhow::{bail, Context}; +use bytes::Buf; use crate::{ChitchatId, Heartbeat}; @@ -10,7 +11,7 @@ use crate::{ChitchatId, Heartbeat}; /// Chitchat uses a custom binary serialization format. /// The point of this format is to make it possible /// to truncate the delta payload to a given mtu. -pub trait Serializable: Sized { +pub trait Serializable { fn serialize(&self, buf: &mut Vec); fn serialize_to_vec(&self) -> Vec { @@ -19,19 +20,33 @@ pub trait Serializable: Sized { buf } - fn deserialize(buf: &mut &[u8]) -> anyhow::Result; - fn serialized_len(&self) -> usize; } -impl Serializable for u16 { +pub trait Deserializable: Sized { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result; +} + +impl Serializable for u8 { fn serialize(&self, buf: &mut Vec) { - self.to_le_bytes().serialize(buf); + buf.push(*self) } + fn serialized_len(&self) -> usize { + 1 + } +} + +impl Deserializable for u8 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u16_bytes: [u8; 2] = Serializable::deserialize(buf)?; - Ok(Self::from_le_bytes(u16_bytes)) + let byte: [u8; 1] = Deserializable::deserialize(buf)?; + Ok(byte[0]) + } +} + +impl Serializable for u16 { + fn serialize(&self, buf: &mut Vec) { + self.to_le_bytes().serialize(buf); } fn serialized_len(&self) -> usize { @@ -39,20 +54,44 @@ impl Serializable for u16 { } } -impl Serializable for u64 { +impl Deserializable for u16 { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let u16_bytes: [u8; 2] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u16_bytes)) + } +} + +impl Serializable for u32 { fn serialize(&self, buf: &mut Vec) { self.to_le_bytes().serialize(buf); } + fn serialized_len(&self) -> usize { + 4 + } +} + +impl Deserializable for u32 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u64_bytes: [u8; 8] = Serializable::deserialize(buf)?; - Ok(Self::from_le_bytes(u64_bytes)) + let u32_bytes: [u8; 4] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u32_bytes)) } +} +impl Serializable for u64 { + fn serialize(&self, buf: &mut Vec) { + self.to_le_bytes().serialize(buf); + } fn serialized_len(&self) -> usize { 8 } } +impl Deserializable for u64 { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let u64_bytes: [u8; 8] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u64_bytes)) + } +} impl Serializable for Option { fn serialize(&self, buf: &mut Vec) { @@ -61,16 +100,6 @@ impl Serializable for Option { tombstone.serialize(buf); } } - - fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let is_some: bool = Serializable::deserialize(buf)?; - if is_some { - let u64_value = Serializable::deserialize(buf)?; - return Ok(Some(u64_value)); - } - Ok(None) - } - fn serialized_len(&self) -> usize { if self.is_some() { 9 @@ -80,19 +109,31 @@ impl Serializable for Option { } } +impl Deserializable for Option { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let is_some: bool = Deserializable::deserialize(buf)?; + if is_some { + let u64_value = Deserializable::deserialize(buf)?; + return Ok(Some(u64_value)); + } + Ok(None) + } +} + impl Serializable for bool { fn 
serialize(&self, buf: &mut Vec) { buf.push(*self as u8); } + fn serialized_len(&self) -> usize { + 1 + } +} +impl Deserializable for bool { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let bool_byte: [u8; 1] = Serializable::deserialize(buf)?; + let bool_byte: [u8; 1] = Deserializable::deserialize(buf)?; Ok(bool_byte[0] != 0) } - - fn serialized_len(&self) -> usize { - 1 - } } #[repr(u8)] @@ -129,42 +170,55 @@ impl Serializable for IpAddr { } } + fn serialized_len(&self) -> usize { + 1 + match self { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } +} + +impl Deserializable for IpAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let ip_version_byte: [u8; 1] = Serializable::deserialize(buf)?; + let ip_version_byte: [u8; 1] = Deserializable::deserialize(buf)?; let ip_version = IpVersion::try_from(ip_version_byte[0])?; - match ip_version { IpVersion::V4 => { - let bytes: [u8; 4] = Serializable::deserialize(buf)?; + let bytes: [u8; 4] = Deserializable::deserialize(buf)?; Ok(Ipv4Addr::from(bytes).into()) } IpVersion::V6 => { - let bytes: [u8; 16] = Serializable::deserialize(buf)?; + let bytes: [u8; 16] = Deserializable::deserialize(buf)?; Ok(Ipv6Addr::from(bytes).into()) } } } - - fn serialized_len(&self) -> usize { - 1 + match self { - IpAddr::V4(_) => 4, - IpAddr::V6(_) => 16, - } - } } impl Serializable for String { fn serialize(&self, buf: &mut Vec) { - (self.len() as u16).serialize(buf); - buf.extend(self.as_bytes()) + self.as_str().serialize(buf) + } + + fn serialized_len(&self) -> usize { + self.as_str().serialized_len() } +} +impl Deserializable for String { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let len: usize = u16::deserialize(buf)? as usize; let s = std::str::from_utf8(&buf[..len])?.to_string(); buf.consume(len); Ok(s) } +} + +impl Serializable for str { + fn serialize(&self, buf: &mut Vec) { + (self.len() as u16).serialize(buf); + buf.extend(self.as_bytes()) + } fn serialized_len(&self) -> usize { 2 + self.len() @@ -175,7 +229,12 @@ impl Serializable for [u8; N] { fn serialize(&self, buf: &mut Vec) { buf.extend_from_slice(&self[..]); } + fn serialized_len(&self) -> usize { + N + } +} +impl Deserializable for [u8; N] { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { if buf.len() < N { bail!("Buffer too short"); @@ -184,10 +243,6 @@ impl Serializable for [u8; N] { buf.consume(N); Ok(val_bytes) } - - fn serialized_len(&self) -> usize { - N - } } impl Serializable for SocketAddr { @@ -196,15 +251,17 @@ impl Serializable for SocketAddr { self.port().serialize(buf); } + fn serialized_len(&self) -> usize { + self.ip().serialized_len() + self.port().serialized_len() + } +} + +impl Deserializable for SocketAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let ip_addr = IpAddr::deserialize(buf)?; let port = u16::deserialize(buf)?; Ok(SocketAddr::new(ip_addr, port)) } - - fn serialized_len(&self) -> usize { - self.ip().serialized_len() + self.port().serialized_len() - } } impl Serializable for ChitchatId { @@ -214,6 +271,14 @@ impl Serializable for ChitchatId { self.gossip_advertise_addr.serialize(buf) } + fn serialized_len(&self) -> usize { + self.node_id.serialized_len() + + self.generation_id.serialized_len() + + self.gossip_advertise_addr.serialized_len() + } +} + +impl Deserializable for ChitchatId { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let node_id = String::deserialize(buf)?; let generation_id = u64::deserialize(buf)?; @@ -224,12 +289,6 @@ impl Serializable for ChitchatId { gossip_advertise_addr, }) } - - fn 
serialized_len(&self) -> usize { - self.node_id.serialized_len() - + self.generation_id.serialized_len() - + self.gossip_advertise_addr.serialized_len() - } } impl Serializable for Heartbeat { @@ -237,19 +296,190 @@ impl Serializable for Heartbeat { self.0.serialize(buf); } + fn serialized_len(&self) -> usize { + self.0.serialized_len() + } +} + +impl Deserializable for Heartbeat { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let heartbeat = u64::deserialize(buf)?; Ok(Self(heartbeat)) } +} + +/// A compressed stream writer receives a sequence of `Serializable` and +/// serialize/compresses into blocks of a configurable size. +/// +/// Block are tagged, so that blocks with a high entropy are stored kept "uncompressed". +/// +/// The stream gives the client an upperbound of what the overall payload length would +/// be if another item was appended. +/// This makes it possible to enforce a `mtu`. +pub struct CompressedStreamWriter { + output: Vec, + // temporary buffer used for block compression. + uncompressed_block: Vec, + // ongoing block being serialized. + compressed_block: Vec, + block_threshold: usize, +} + +impl CompressedStreamWriter { + pub fn with_block_threshold(block_threshold: u16) -> CompressedStreamWriter { + let block_threshold = block_threshold as usize; + let output = Vec::with_capacity(block_threshold); + CompressedStreamWriter { + output, + uncompressed_block: Vec::with_capacity(block_threshold * 2), + compressed_block: Vec::with_capacity(block_threshold), + block_threshold, + } + } + + /// Returns an upperbound of the serialized len after appending `s` + pub fn serialized_len_upperbound_after(&self, item: &S) -> usize { + let new_item_serialized_len = item.serialized_len(); + assert!(new_item_serialized_len > 0); + const BLOCK_META_LEN: usize = 3; + let new_block_needed = + self.uncompressed_block.len() + new_item_serialized_len > self.block_threshold; + if new_block_needed { + BLOCK_META_LEN + self.output.len() + self.uncompressed_block.len() + // current block + BLOCK_META_LEN + new_item_serialized_len // new block + + 1 // No more blocks tag. + } else { + BLOCK_META_LEN + self.output.len() + self.uncompressed_block.len() + new_item_serialized_len // current block + + 1 // No more blocks tag. + } + } + + /// Appends a new item to the stream. Items must be at most `u16::MAX` bytes long. + pub fn append(&mut self, item: &S) { + let item_len = item.serialized_len(); + assert!(item_len <= u16::MAX as usize); + item.serialize(&mut self.uncompressed_block); + while self.uncompressed_block.len() > self.block_threshold { + // time to flush our current block. + self.flush_block(); + } + } + + /// Flush the first `block_threshold` bytes from the `uncompressed_block` buffer, + /// and remove them from the buffer. + /// + /// If the buffer less bytes than `block_threshold`, compress the bytes available. + /// (This happens for the last block.) 
+ fn flush_block(&mut self) { + if self.uncompressed_block.is_empty() { + return; + } + let num_bytes_to_compress = self.uncompressed_block.len().min(self.block_threshold); + + self.compressed_block.resize(num_bytes_to_compress, 0u8); + match zstd::bulk::compress_to_buffer( + &self.uncompressed_block[..num_bytes_to_compress], + &mut self.compressed_block[..], + 0, // default compression level + ) { + Ok(compressed_len) => { + BlockType::Compressed.serialize(&mut self.output); + let compressed_len_u16 = u16::try_from(compressed_len).unwrap(); + compressed_len_u16.serialize(&mut self.output); + self.output.extend(&self.compressed_block[..compressed_len]); + } + // The compressed version was actually longer than the decompressed one. + // Let's keep the block uncomopressed + Err(_) => { + BlockType::Uncompressed.serialize(&mut self.output); + let num_bytes_to_compress_u16 = + u16::try_from(num_bytes_to_compress).expect("uncompressed block too big"); + num_bytes_to_compress_u16.serialize(&mut self.output); + self.output + .extend(&self.uncompressed_block[..num_bytes_to_compress]); + } + } + self.uncompressed_block.drain(..num_bytes_to_compress); + } + pub fn finish(mut self) -> Vec { + self.flush_block(); + BlockType::NoMoreBlocks.serialize(&mut self.output); + self.output + } +} + +pub fn deserialize_stream(buf: &mut &[u8]) -> anyhow::Result> { + let mut decompressed_data = Vec::new(); + let mut decompressed_buffer = vec![0; u16::MAX as usize]; + loop { + let block_type = BlockType::deserialize(buf)?; + match block_type { + BlockType::Compressed => { + let len = u16::deserialize(buf)? as usize; + let compressed_block_bytes = &buf[..len]; + let uncompressed_len = zstd::bulk::decompress_to_buffer( + compressed_block_bytes, + &mut decompressed_buffer[..u16::MAX as usize], + ) + .context("failed to decompress block")?; + buf.advance(len); + decompressed_data.extend_from_slice(&decompressed_buffer[..uncompressed_len]); + } + BlockType::Uncompressed => { + let len = u16::deserialize(buf)? 
as usize; + decompressed_data.extend_from_slice(&buf[..len]); + buf.advance(len); + } + BlockType::NoMoreBlocks => { + break; + } + } + } + let mut decompressed_cursor = &decompressed_data[..]; + let mut items = Vec::new(); + while !decompressed_cursor.is_empty() { + let item = D::deserialize(&mut decompressed_cursor)?; + items.push(item); + } + Ok(items) +} + +#[repr(u8)] +#[derive(Copy, Clone)] +enum BlockType { + NoMoreBlocks, + Compressed, + Uncompressed, +} + +impl Serializable for BlockType { + fn serialize(&self, buf: &mut Vec) { + (*self as u8).serialize(buf) + } fn serialized_len(&self) -> usize { - self.0.serialized_len() + 1 + } +} + +impl Deserializable for BlockType { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let byte = u8::deserialize(buf)?; + match byte { + 0 => Ok(BlockType::NoMoreBlocks), + 1 => Ok(BlockType::Compressed), + 2 => Ok(BlockType::Uncompressed), + _ => anyhow::bail!("invalid block type"), + } } } #[cfg(test)] #[track_caller] -pub fn test_serdeser_aux(obj: &T, num_bytes: usize) { +pub fn test_serdeser_aux( + obj: &T, + num_bytes: usize, +) { let mut buf = Vec::new(); obj.serialize(&mut buf); assert_eq!(buf.len(), obj.serialized_len()); @@ -260,6 +490,10 @@ pub fn test_serdeser_aux(obj: &T, #[cfg(test)] mod tests { + use proptest::proptest; + use rand::distributions::Alphanumeric; + use rand::Rng; + use super::*; #[test] @@ -296,4 +530,144 @@ mod tests { test_serdeser_aux(&Some(1), 9); test_serdeser_aux(&None, 1); } + + #[test] + fn test_serialize_block_type() { + let mut valid_vals_count = 0; + for b in 0..=u8::MAX { + if let Ok(block_type) = BlockType::deserialize(&mut &[b][..]) { + valid_vals_count += 1; + let serialized = block_type.serialize_to_vec(); + assert_eq!(&serialized, &[b]); + } + } + assert_eq!(valid_vals_count, 3); + } + + // An array of 10 small sentences for tests. 
+ const TEXT_SAMPLES: [&str; 10] = [ + "I'm happy.", + "She exercises every morning.", + "His dog barks loudly.", + "My school starts at 8:00.", + "We always eat dinner together.", + "They take the bus to work.", + "He doesn't like vegetables.", + "I don't want anything to drink.", + "hello Happy tax payer", + "do you like tea?", + ]; + + #[test] + fn test_compressed_serialized_stream() { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(1_000); + let mut uncompressed_len = 0; + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + compressed_stream_writer.append(sentence); + uncompressed_len += sentence.len(); + } + let buf = compressed_stream_writer.finish(); + let mut cursor = &buf[..]; + assert!(buf.len() * 3 < uncompressed_len); + let vals: Vec = super::deserialize_stream(&mut cursor).unwrap(); + assert_eq!(vals.len(), 100); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + assert_eq!(&vals[i], sentence); + } + assert!(cursor.is_empty()); + } + + #[test] + fn test_compressed_serialized_stream_with_random_data() { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(200); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + compressed_stream_writer.append(sentence); + } + for _ in 0..30 { + let rng = rand::thread_rng(); + let random_sentence: String = rng + .sample_iter(&Alphanumeric) + .take(30) + .map(char::from) + .collect(); + compressed_stream_writer.append(random_sentence.as_str()); + } + let buf = compressed_stream_writer.finish(); + let mut cursor = &buf[..]; + let vals: Vec = deserialize_stream(&mut cursor).unwrap(); + assert_eq!(vals.len(), 130); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + assert_eq!(&vals[i], sentence); + } + assert!(cursor.is_empty()); + } + + #[test] + fn test_compressed_serialized_stream_len_when_there_are_no_compressed_blocks() { + { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(10); + let random_string = "xu1Y3l"; + let serialized_len_after_hello = + compressed_stream_writer.serialized_len_upperbound_after(random_string); + compressed_stream_writer.append(random_string); + let buffer = compressed_stream_writer.finish(); + // There are no compression opportunity here. The foreseen serialized len should be the + // same as the actual length. + assert_eq!(buffer.len(), serialized_len_after_hello); + } + { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(10); + let random_string = "pTs2yYd"; + compressed_stream_writer.append(random_string); + let random_string2 = "vLQRFPN6"; + let serialized_len_after_hello = + compressed_stream_writer.serialized_len_upperbound_after(random_string2); + compressed_stream_writer.append(random_string2); + let buffer = compressed_stream_writer.finish(); + // There are no compression opportunity here. The foreseen serialized len should be the + // same as the actual length. 
+ assert_eq!(buffer.len(), serialized_len_after_hello); + } + } + + #[test] + fn test_empty() { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(1_000); + let len_upper_bound = compressed_stream_writer.serialized_len_upperbound_after(""); + compressed_stream_writer.append(""); + let buf = compressed_stream_writer.finish(); + let vals: Vec = deserialize_stream(&mut &buf[..]).unwrap(); + assert!(buf.len() <= len_upper_bound); + assert_eq!(vals.len(), 1); + assert_eq!(vals[0], ""); + } + + proptest! { + #[test] + fn test_proptest_compressed_stream(payload in proptest::collection::vec(".{0,1000}", 1..100)) { + let mut compressed_stream_writer: CompressedStreamWriter = + CompressedStreamWriter::with_block_threshold(1_000); + for s in &payload[..payload.len() - 1] { + compressed_stream_writer.append(s); + } + let len_upper_bound = compressed_stream_writer.serialized_len_upperbound_after(&payload[payload.len() - 1]); + compressed_stream_writer.append(&payload[payload.len() - 1]); + let buf = compressed_stream_writer.finish(); + let vals: Vec = deserialize_stream(&mut &buf[..]).unwrap(); + assert!(buf.len() <= len_upper_bound); + assert_eq!(vals.len(), payload.len()); + for (left, right) in vals.iter().zip(payload.iter()) { + assert_eq!(left, right); + } + } + } } diff --git a/chitchat/src/server.rs b/chitchat/src/server.rs index 51676db..4124607 100644 --- a/chitchat/src/server.rs +++ b/chitchat/src/server.rs @@ -601,7 +601,7 @@ mod tests { }; }; - let node_delta = &delta.node_deltas.get(&server_id).unwrap(); + let node_delta = delta.get(&server_id).unwrap(); let heartbeat = node_delta.heartbeat; assert_eq!(heartbeat, Heartbeat(3)); diff --git a/chitchat/src/state.rs b/chitchat/src/state.rs index bb96649..d6b78da 100644 --- a/chitchat/src/state.rs +++ b/chitchat/src/state.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use tokio::sync::watch; use tracing::warn; -use crate::delta::{Delta, DeltaWriter}; +use crate::delta::{Delta, DeltaSerializer, NodeDelta}; use crate::digest::{Digest, NodeDigest}; use crate::listener::Listeners; use crate::{ChitchatId, Heartbeat, KeyChangeEvent, Version, VersionedValue}; @@ -296,16 +296,21 @@ impl ClusterState { .retain(|chitchat_id, _| !delta.nodes_to_reset.contains(chitchat_id)); // Apply delta. - for (chitchat_id, node_delta) in delta.node_deltas { + for node_delta in delta.node_deltas { + let NodeDelta { + chitchat_id, + heartbeat, + key_values, + } = node_delta; let node_state = self .node_states .entry(chitchat_id.clone()) .or_insert_with(|| NodeState::new(chitchat_id, self.listeners.clone())); - if node_state.heartbeat < node_delta.heartbeat { - node_state.heartbeat = node_delta.heartbeat; + if node_state.heartbeat < heartbeat { + node_state.heartbeat = heartbeat; node_state.last_heartbeat = Instant::now(); } - for (key, versioned_value) in node_delta.key_values { + for (key, versioned_value) in key_values { node_state.max_version = node_state.max_version.max(versioned_value.version); node_state.set_versioned_value(key, versioned_value); } @@ -339,7 +344,7 @@ impl ClusterState { /// Implements the Scuttlebutt reconciliation with the scuttle-depth ordering. /// /// Nodes that are scheduled for deletion (as passed by argument) are not shared. 
- pub fn compute_delta( + pub fn compute_partial_delta_respecting_mtu( &self, digest: &Digest, mtu: usize, @@ -365,30 +370,31 @@ impl ClusterState { } stale_nodes.offer(chitchat_id, node_state, node_digest); } - let mut delta_writer = DeltaWriter::with_mtu(mtu); + let mut delta_serializer = DeltaSerializer::with_mtu(mtu); for chitchat_id in &nodes_to_reset { - if !delta_writer.add_node_to_reset((*chitchat_id).clone()) { + if !delta_serializer.try_add_node_to_reset((*chitchat_id).clone()) { break; } } + for stale_node in stale_nodes.into_iter() { - if !delta_writer.add_node(stale_node.chitchat_id.clone(), stale_node.heartbeat) { + if !delta_serializer.try_add_node(stale_node.chitchat_id.clone(), stale_node.heartbeat) + { break; } let mut added_something = false; for (key, versioned_value) in stale_node.stale_key_values() { added_something = true; - if !delta_writer.add_kv(key, versioned_value.clone()) { - let delta: Delta = delta_writer.into(); - return delta; + if !delta_serializer.try_add_kv(key, versioned_value.clone()) { + return delta_serializer.finish(); } } if !added_something && nodes_to_reset.contains(&stale_node.chitchat_id) { // send a sentinel element to update the max_version. Otherwise the node's vision // of max_version will be 0, and it may accept writes that are supposed to be // stale, but it can tell they are. - if !delta_writer.add_kv( + if !delta_serializer.try_add_kv( "__reset_sentinel", VersionedValue { value: String::new(), @@ -396,12 +402,11 @@ impl ClusterState { tombstone: Some(0), }, ) { - let delta: Delta = delta_writer.into(); - return delta; + return delta_serializer.finish(); } } } - delta_writer.into() + delta_serializer.finish() } } @@ -948,12 +953,18 @@ mod tests { dead_nodes: &HashSet<&ChitchatId>, expected_delta_atoms: &[(&ChitchatId, &str, &str, Version, Option)], ) { - let max_delta = cluster_state.compute_delta(digest, usize::MAX, dead_nodes, 10_000); + let max_delta = cluster_state.compute_partial_delta_respecting_mtu( + digest, + usize::MAX, + dead_nodes, + 10_000, + ); let mut buf = Vec::new(); max_delta.serialize(&mut buf); let mut mtu_per_num_entries = Vec::new(); - for mtu in 2..buf.len() { - let delta = cluster_state.compute_delta(digest, mtu, dead_nodes, 10_000); + for mtu in 100..buf.len() { + let delta = + cluster_state.compute_partial_delta_respecting_mtu(digest, mtu, dead_nodes, 10_000); let num_tuples = delta.num_tuples(); if mtu_per_num_entries.len() == num_tuples + 1 { continue; @@ -969,11 +980,17 @@ mod tests { expected_delta.add_kv(node, key, val, version, tombstone); } { - let delta = cluster_state.compute_delta(digest, mtu, dead_nodes, 10_000); + let delta = cluster_state + .compute_partial_delta_respecting_mtu(digest, mtu, dead_nodes, 10_000); assert_eq!(&delta, &expected_delta); } { - let delta = cluster_state.compute_delta(digest, mtu + 1, dead_nodes, 10_000); + let delta = cluster_state.compute_partial_delta_respecting_mtu( + digest, + mtu + 1, + dead_nodes, + 10_000, + ); assert_eq!(&delta, &expected_delta); } } @@ -1102,7 +1119,7 @@ mod tests { let node1 = ChitchatId::for_local_test(10_001); digest.add_node(node1.clone(), Heartbeat(0), 1); { - let delta = cluster_state.compute_delta( + let delta = cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE, &HashSet::new(), @@ -1114,13 +1131,14 @@ mod tests { expected_delta.add_kv(&node1, "key_b", "2", 2, None); expected_delta.add_node(node2.clone(), Heartbeat(0)); expected_delta.add_kv(&node2.clone(), "key_c", "3", 2, None); + 
expected_delta.set_serialized_len(78); assert_eq!(delta, expected_delta); } { // Node 1 heartbeat in digest + grace period (9_999) is inferior to the // node1's hearbeat in the cluster state. Thus we expect the cluster to compute a // delta that will reset node 1. - let delta = cluster_state.compute_delta( + let delta = cluster_state.compute_partial_delta_respecting_mtu( &digest, MAX_UDP_DATAGRAM_PAYLOAD_SIZE, &HashSet::new(), @@ -1133,6 +1151,7 @@ mod tests { expected_delta.add_kv(&node1, "key_b", "2", 2, None); expected_delta.add_node(node2.clone(), Heartbeat(0)); expected_delta.add_kv(&node2.clone(), "key_c", "3", 2, None); + expected_delta.set_serialized_len(91); assert_eq!(delta, expected_delta); } } diff --git a/chitchat/src/transport/channel.rs b/chitchat/src/transport/channel.rs index 11ba515..137076c 100644 --- a/chitchat/src/transport/channel.rs +++ b/chitchat/src/transport/channel.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex}; use anyhow::{bail, Context}; use async_trait::async_trait; use tokio::sync::mpsc::{Receiver, Sender}; -use tracing::{debug, info}; +use tracing::info; use crate::serialize::Serializable; use crate::transport::{Socket, Transport}; @@ -98,7 +98,6 @@ impl ChannelTransport { bail!("Serialized message size exceeds MTU."); } } - debug!(num_bytes = num_bytes, "send"); let mut inner_lock = self.inner.lock().unwrap(); inner_lock.statistics.record_message_len(num_bytes); if let Some(to_addrs) = inner_lock.removed_links.get(&from_addr) { diff --git a/chitchat/src/transport/udp.rs b/chitchat/src/transport/udp.rs index 59d5e89..d1d8fb4 100644 --- a/chitchat/src/transport/udp.rs +++ b/chitchat/src/transport/udp.rs @@ -4,7 +4,7 @@ use anyhow::Context; use async_trait::async_trait; use tracing::warn; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; use crate::transport::{Socket, Transport}; use crate::{ChitchatMessage, MAX_UDP_DATAGRAM_PAYLOAD_SIZE}; diff --git a/chitchat/tests/cluster_test.rs b/chitchat/tests/cluster_test.rs index ccaf816..f493a8a 100644 --- a/chitchat/tests/cluster_test.rs +++ b/chitchat/tests/cluster_test.rs @@ -9,7 +9,7 @@ use chitchat::{ }; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; -use tracing::{debug, info}; +use tracing::{debug, error, info}; #[derive(Debug)] enum Operation { @@ -158,7 +158,7 @@ impl Simulator { } assert!(predicate_value); } else { - info!(node_id=%chitchat_id.node_id, state_snapshot=?chitchat_guard.state_snapshot(), "Node state missing."); + error!(node_id=%chitchat_id.node_id, state_snapshot=?chitchat_guard.state_snapshot(), "Node state missing."); panic!("Node state missing"); } }
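
A few illustrative sketches of how the new pieces introduced by this patch fit
together follow; they are not part of the patch itself, and the helper names
(`write_items_respecting_mtu`, `read_items`) are assumptions made for the
example. The first one shows the intended use of `CompressedStreamWriter` from
`chitchat/src/serialize.rs`: the caller asks for an upper bound of the payload
length before appending each item, which is what makes it possible to respect
an MTU. It assumes crate-internal access, since the `serialize` module is not
publicly exported.

use crate::serialize::{deserialize_stream, CompressedStreamWriter};

fn write_items_respecting_mtu(items: &[String], mtu: usize) -> Vec<u8> {
    let mut writer = CompressedStreamWriter::with_block_threshold(1_000);
    for item in items {
        // Ask for an upper bound of the payload length *before* appending,
        // so the finished buffer can never exceed `mtu`.
        if writer.serialized_len_upperbound_after(item.as_str()) > mtu {
            break;
        }
        writer.append(item.as_str());
    }
    // `finish` flushes the last block and appends the `NoMoreBlocks` tag.
    writer.finish()
}

fn read_items(payload: &[u8]) -> anyhow::Result<Vec<String>> {
    // Decompresses the blocks and deserializes the concatenated bytes.
    let mut cursor = payload;
    let items: Vec<String> = deserialize_stream(&mut cursor)?;
    Ok(items)
}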
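
The payload produced by `CompressedStreamWriter::finish` is a sequence of
tagged blocks. The sketch below walks that framing as defined by the patch's
`BlockType` encoding (tag `0` = no more blocks, `1` = zstd-compressed,
`2` = uncompressed; data blocks carry a little-endian `u16` length followed by
that many bytes). `describe_blocks` is a hypothetical helper for illustration,
not an API added by the patch.

use anyhow::Context;

/// Walks the block framing emitted by `CompressedStreamWriter` without
/// decompressing anything: `[tag: u8][len: u16 LE][len bytes]`, terminated
/// by a single `0` tag (`NoMoreBlocks`).
fn describe_blocks(mut payload: &[u8]) -> anyhow::Result<()> {
    loop {
        let (tag, rest) = payload.split_first().context("missing block tag")?;
        payload = rest;
        match *tag {
            0 => return Ok(()), // NoMoreBlocks
            1 | 2 => {
                anyhow::ensure!(payload.len() >= 2, "truncated block length");
                let len = u16::from_le_bytes([payload[0], payload[1]]) as usize;
                anyhow::ensure!(payload.len() >= 2 + len, "truncated block");
                let kind = if *tag == 1 { "compressed" } else { "uncompressed" };
                println!("{kind} block: {len} bytes");
                payload = &payload[2 + len..];
            }
            _ => anyhow::bail!("invalid block type: {tag}"),
        }
    }
}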
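
Finally, a sketch of the calling pattern for `DeltaSerializer`, mirroring what
`ClusterState::compute_partial_delta_respecting_mtu` does in the patch:
operations are added through the `try_add_*` methods, and the caller stops at
the first `false`. It assumes a crate-internal test context
(`ChitchatId::for_local_test` is a test-only helper, and `with_mtu` asserts an
MTU of at least 100 bytes); `build_small_delta` is an illustrative name.

use crate::delta::{Delta, DeltaSerializer};
use crate::{ChitchatId, Heartbeat, VersionedValue};

fn build_small_delta(mtu: usize) -> Delta {
    // `with_mtu` asserts `mtu >= 100`.
    let mut serializer = DeltaSerializer::with_mtu(mtu);

    // Nodes to reset must come first; each `try_add_*` method returns `false`
    // once the (pessimistically estimated) compressed payload would exceed
    // the mtu, and the caller is expected to stop adding operations.
    if !serializer.try_add_node_to_reset(ChitchatId::for_local_test(10_000)) {
        return serializer.finish();
    }
    let node = ChitchatId::for_local_test(10_001);
    if !serializer.try_add_node(node, Heartbeat(0)) {
        return serializer.finish();
    }
    let _fits = serializer.try_add_kv(
        "key",
        VersionedValue {
            value: "value".to_string(),
            version: 1,
            tombstone: None,
        },
    );
    // `finish` records the exact serialized length, so serializing the
    // resulting `Delta` later yields a payload of exactly that size.
    serializer.finish()
}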