Skip to content

Commit

Permalink
fix(serialization): safe_serialization with unlimited size
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Oct 21, 2024
1 parent 4bd9325 commit dba7f22
Showing 1 changed file with 140 additions and 40 deletions.
180 changes: 140 additions & 40 deletions tfhe/src/safe_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Please use the versioned serialization mode for backward compatibility.",
#[derive(Clone)]
pub struct SerializationConfig {
versioned: SerializationVersioningMode,
serialized_size_limit: u64,
serialized_size_limit: Option<u64>,
}

impl SerializationConfig {
Expand All @@ -150,22 +150,22 @@ impl SerializationConfig {
pub fn new(serialized_size_limit: u64) -> Self {
Self {
versioned: SerializationVersioningMode::versioned(),
serialized_size_limit,
serialized_size_limit: Some(serialized_size_limit),
}
}

/// Creates a new serialization config without any size check.
pub fn new_with_unlimited_size() -> Self {
Self {
versioned: SerializationVersioningMode::versioned(),
serialized_size_limit: 0,
serialized_size_limit: None,
}
}

/// Disables the size limit for serialized objects
pub fn disable_size_limit(self) -> Self {
Self {
serialized_size_limit: 0,
serialized_size_limit: None,
..self
}
}
Expand All @@ -178,6 +178,14 @@ impl SerializationConfig {
}
}

/// Sets the size limit for this serialization config
pub fn with_size_limit(self, size: u64) -> Self {
Self {
serialized_size_limit: Some(size),
..self
}
}

/// Create a serialization header based on the current config
fn create_header<T: Named>(&self) -> SerializationHeader {
match self.versioned {
Expand Down Expand Up @@ -221,22 +229,41 @@ impl SerializationConfig {
object: &T,
mut writer: impl std::io::Write,
) -> bincode::Result<()> {
let options = bincode::DefaultOptions::new().with_fixint_encoding();
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0); // Force to explicitly set the limit for each serialization

let header = self.create_header::<T>();
let header_size = options.serialized_size(&header)?;
let header_size = options.with_no_limit().serialized_size(&header)?;

options
.with_limit(self.serialized_size_limit)
.serialize_into(&mut writer, &header)?;
if let Some(size_limit) = self.serialized_size_limit {
options
.with_limit(size_limit)
.serialize_into(&mut writer, &header)?;

match self.versioned {
SerializationVersioningMode::Versioned { .. } => options
.with_limit(self.serialized_size_limit - header_size)
.serialize_into(&mut writer, &object.versionize())?,
SerializationVersioningMode::Unversioned { .. } => options
.with_limit(self.serialized_size_limit - header_size)
.serialize_into(&mut writer, &object)?,
let options = options.with_limit(size_limit - header_size);

match self.versioned {
SerializationVersioningMode::Versioned { .. } => {
options.serialize_into(&mut writer, &object.versionize())?
}
SerializationVersioningMode::Unversioned { .. } => {
options.serialize_into(&mut writer, &object)?
}
};
} else {
let options = options.with_no_limit();

options.serialize_into(&mut writer, &header)?;

match self.versioned {
SerializationVersioningMode::Versioned { .. } => {
options.serialize_into(&mut writer, &object.versionize())?
}
SerializationVersioningMode::Unversioned { .. } => {
options.serialize_into(&mut writer, &object)?
}
};
};

Ok(())
Expand All @@ -247,7 +274,7 @@ impl SerializationConfig {
/// the various sanity checks that will be performed during deserialization.
#[derive(Copy, Clone)]
pub struct DeserializationConfig {
serialized_size_limit: u64,
serialized_size_limit: Option<u64>,
validate_header: bool,
}

Expand All @@ -257,46 +284,83 @@ pub struct DeserializationConfig {
/// This type should be created with [`DeserializationConfig::disable_conformance`]
#[derive(Copy, Clone)]
pub struct NonConformantDeserializationConfig {
serialized_size_limit: u64,
serialized_size_limit: Option<u64>,
validate_header: bool,
}

impl NonConformantDeserializationConfig {
/// Deserialize a header using the current config
fn deserialize_header(
&self,
reader: &mut impl std::io::Read,
) -> Result<SerializationHeader, String> {
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);

if let Some(size_limit) = self.serialized_size_limit {
options
.with_limit(size_limit)
.deserialize_from(reader)
.map_err(|err| err.to_string())
} else {
options
.with_no_limit()
.deserialize_from(reader)
.map_err(|err| err.to_string())
}
}

/// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a
/// [reader](std::io::Read). Performs various sanity checks based on the deserialization config,
/// but skips conformance checks.
pub fn deserialize_from<T: DeserializeOwned + Unversionize + Named>(
self,
mut reader: impl std::io::Read,
) -> Result<T, String> {
let options = bincode::DefaultOptions::new().with_fixint_encoding();
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0); // Force to explicitly set the limit for each deserialization

let deserialized_header: SerializationHeader = options
.with_limit(self.serialized_size_limit)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;
let deserialized_header: SerializationHeader = self.deserialize_header(&mut reader)?;

let header_size = options
.with_no_limit()
.serialized_size(&deserialized_header)
.map_err(|err| err.to_string())?;

if self.validate_header {
deserialized_header.validate::<T>()?;
}

match deserialized_header.versioning_mode {
SerializationVersioningMode::Versioned { .. } => {
let deser_versioned = options
.with_limit(self.serialized_size_limit - header_size)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;
if let Some(size_limit) = self.serialized_size_limit {
let options = options.with_limit(size_limit - header_size);
match deserialized_header.versioning_mode {
SerializationVersioningMode::Versioned { .. } => {
let deser_versioned = options
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;

T::unversionize(deser_versioned).map_err(|e| e.to_string())
T::unversionize(deser_versioned).map_err(|e| e.to_string())
}
SerializationVersioningMode::Unversioned { .. } => options
.deserialize_from(&mut reader)
.map_err(|err| err.to_string()),
}
} else {
let options = options.with_no_limit();
match deserialized_header.versioning_mode {
SerializationVersioningMode::Versioned { .. } => {
let deser_versioned = options
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;

T::unversionize(deser_versioned).map_err(|e| e.to_string())
}
SerializationVersioningMode::Unversioned { .. } => options
.deserialize_from(&mut reader)
.map_err(|err| err.to_string()),
}
SerializationVersioningMode::Unversioned { .. } => options
.with_limit(self.serialized_size_limit - header_size)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string()),
}
}

Expand All @@ -322,23 +386,31 @@ impl DeserializationConfig {
/// the current *TFHE-rs* version.
pub fn new(serialized_size_limit: u64) -> Self {
Self {
serialized_size_limit,
serialized_size_limit: Some(serialized_size_limit),
validate_header: true,
}
}

/// Creates a new config without any size limit for the deserialized objects.
pub fn new_with_unlimited_size() -> Self {
Self {
serialized_size_limit: 0,
serialized_size_limit: None,
validate_header: true,
}
}

/// Disables the size limit for the serialized objects.
pub fn disable_size_limit(self) -> Self {
Self {
serialized_size_limit: 0,
serialized_size_limit: None,
..self
}
}

/// Sets the size limit for this deserialization config
pub fn with_size_limit(self, size: u64) -> Self {
Self {
serialized_size_limit: Some(size),
..self
}
}
Expand Down Expand Up @@ -429,7 +501,7 @@ mod test_shortint {
use crate::shortint::{gen_keys, Ciphertext};

#[test]
fn safe_deserialization_ct() {
fn safe_deserialization_ct_unversioned() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

let msg = 2_u64;
Expand Down Expand Up @@ -464,7 +536,7 @@ mod test_shortint {
}

#[test]
fn safe_deserialization_ct_versioned() {
fn safe_deserialization_ct() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

let msg = 2_u64;
Expand Down Expand Up @@ -498,6 +570,34 @@ mod test_shortint {
assert_eq!(msg, dec);
}

#[test]
fn safe_deserialization_ct_unlimited_size() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

let msg = 2_u64;

let ct = ck.encrypt(msg);

let mut buffer = vec![];

let config = SerializationConfig::new_with_unlimited_size();

let size = config.serialized_size(&ct).unwrap();
config.serialize_into(&ct, &mut buffer).unwrap();

assert_eq!(size as usize, buffer.len());

let ct2 = DeserializationConfig::new_with_unlimited_size()
.deserialize_from::<Ciphertext>(
buffer.as_slice(),
&PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(),
)
.unwrap();

let dec = ck.decrypt(&ct2);
assert_eq!(msg, dec);
}

#[test]
fn safe_deserialization_size_limit() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
Expand All @@ -508,7 +608,7 @@ mod test_shortint {

let mut buffer = vec![];

let config = SerializationConfig::new(1 << 20).disable_versioning();
let config = SerializationConfig::new_with_unlimited_size().disable_versioning();

let size = config.serialized_size(&ct).unwrap();
config.serialize_into(&ct, &mut buffer).unwrap();
Expand Down

0 comments on commit dba7f22

Please sign in to comment.