From 4bd9325c6fc7271deb2f55dbfb02972acc846cc2 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Thu, 17 Oct 2024 16:21:46 +0200 Subject: [PATCH] fix(serialization): serialized_size_limit includes the header --- tfhe/src/safe_serialization.rs | 88 +++++++++++++++++----------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs index 1c0a3b6b37..e62a0a3d74 100644 --- a/tfhe/src/safe_serialization.rs +++ b/tfhe/src/safe_serialization.rs @@ -67,12 +67,6 @@ impl SerializationVersioningMode { } } -/// `HEADER_LENGTH_LIMIT` is the maximum `SerializationHeader` size which -/// `DeserializationConfig::deserialize_from` is going to try to read (it returns an error if -/// it's too big). -/// It helps prevent an attacker passing a very long header to exhaust memory. -const HEADER_LENGTH_LIMIT: u64 = 1000; - /// Header with global metadata about the serialized object. This help checking that we are not /// deserializing data that we can't handle. #[derive(Serialize, Deserialize)] @@ -152,7 +146,7 @@ impl SerializationConfig { /// Creates a new serialization config. The default configuration will serialize the object /// with versioning information for backward compatibility. /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object - /// (excluding the header). + /// (including the header). pub fn new(serialized_size_limit: u64) -> Self { Self { versioned: SerializationVersioningMode::versioned(), @@ -196,15 +190,6 @@ impl SerializationConfig { } } - /// Returns the max length of the serialized header - fn header_length_limit(&self) -> u64 { - if self.serialized_size_limit == 0 { - 0 - } else { - HEADER_LENGTH_LIMIT - } - } - /// Returns the size the object would take if serialized using the current config /// /// The size is returned as a u64 to handle the serialization of large buffers under 32b @@ -236,21 +221,21 @@ impl SerializationConfig { object: &T, mut writer: impl std::io::Write, ) -> bincode::Result<()> { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); + let options = bincode::DefaultOptions::new().with_fixint_encoding(); let header = self.create_header::(); + let header_size = options.serialized_size(&header)?; + options - .with_limit(self.header_length_limit()) + .with_limit(self.serialized_size_limit) .serialize_into(&mut writer, &header)?; match self.versioned { SerializationVersioningMode::Versioned { .. } => options - .with_limit(self.serialized_size_limit) + .with_limit(self.serialized_size_limit - header_size) .serialize_into(&mut writer, &object.versionize())?, SerializationVersioningMode::Unversioned { .. } => options - .with_limit(self.serialized_size_limit) + .with_limit(self.serialized_size_limit - header_size) .serialize_into(&mut writer, &object)?, }; @@ -284,22 +269,17 @@ impl NonConformantDeserializationConfig { self, mut reader: impl std::io::Read, ) -> Result { - if self.serialized_size_limit != 0 && self.serialized_size_limit <= HEADER_LENGTH_LIMIT { - return Err(format!( - "The provided size limit is too small, provide a limit of at least \ -{HEADER_LENGTH_LIMIT} bytes" - )); - } - - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); + let options = bincode::DefaultOptions::new().with_fixint_encoding(); let deserialized_header: SerializationHeader = options - .with_limit(self.header_length_limit()) + .with_limit(self.serialized_size_limit) .deserialize_from(&mut reader) .map_err(|err| err.to_string())?; + let header_size = options + .serialized_size(&deserialized_header) + .map_err(|err| err.to_string())?; + if self.validate_header { deserialized_header.validate::()?; } @@ -307,14 +287,14 @@ impl NonConformantDeserializationConfig { match deserialized_header.versioning_mode { SerializationVersioningMode::Versioned { .. } => { let deser_versioned = options - .with_limit(self.serialized_size_limit - self.header_length_limit()) + .with_limit(self.serialized_size_limit - header_size) .deserialize_from(&mut reader) .map_err(|err| err.to_string())?; T::unversionize(deser_versioned).map_err(|e| e.to_string()) } SerializationVersioningMode::Unversioned { .. } => options - .with_limit(self.serialized_size_limit - self.header_length_limit()) + .with_limit(self.serialized_size_limit - header_size) .deserialize_from(&mut reader) .map_err(|err| err.to_string()), } @@ -327,14 +307,6 @@ impl NonConformantDeserializationConfig { validate_header: self.validate_header, } } - - fn header_length_limit(&self) -> u64 { - if self.serialized_size_limit == 0 { - 0 - } else { - HEADER_LENGTH_LIMIT - } - } } impl DeserializationConfig { @@ -343,7 +315,7 @@ impl DeserializationConfig { /// By default, it will check that the serialization version and the name of the /// deserialized type are correct. /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object - /// (excluding version and name serialization). + /// (include the safe serialization header). /// /// It will also check that the object is conformant with the parameter set given in /// `conformance_params`. Finally, it will check the compatibility of the loaded data with @@ -525,6 +497,34 @@ mod test_shortint { let dec = ck.decrypt(&ct2); assert_eq!(msg, dec); } + + #[test] + fn safe_deserialization_size_limit() { + let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); + + let msg = 2_u64; + + let ct = ck.encrypt(msg); + + let mut buffer = vec![]; + + let config = SerializationConfig::new(1 << 20).disable_versioning(); + + let size = config.serialized_size(&ct).unwrap(); + config.serialize_into(&ct, &mut buffer).unwrap(); + + assert_eq!(size as usize, buffer.len()); + + let ct2 = DeserializationConfig::new(size) + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + ) + .unwrap(); + + let dec = ck.decrypt(&ct2); + assert_eq!(msg, dec); + } } #[cfg(all(test, feature = "integer"))]