From ab91b4c50da21b857c3806ebb00fcde98845ba74 Mon Sep 17 00:00:00 2001 From: Yan Chen <48968912+chenyan-dfinity@users.noreply.github.com> Date: Thu, 15 Feb 2024 15:31:35 -0800 Subject: [PATCH] add metering for deserialization --- rust/candid/src/de.rs | 133 ++++++++++++++++++++++++------------- rust/candid/tests/serde.rs | 2 +- 2 files changed, 89 insertions(+), 46 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 6a8dd624..1dc1c25e 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -20,6 +20,7 @@ use std::fmt::Write; use std::{collections::VecDeque, io::Cursor, mem::replace}; const MAX_TYPE_LEN: i32 = 500; +const DEFAULT_DECODING_COST: usize = 3_000_000; /// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message. pub struct IDLDeserialize<'de> { @@ -46,7 +47,7 @@ impl<'de> IDLDeserialize<'de> { "Cannot parse header".to_string() } })?; - de.zero_sized_values = config.zero_sized_values; + de.decoding_cost = config.decoding_cost; de.full_error_message = config.full_error_message; Ok(IDLDeserialize { de }) } @@ -141,18 +142,18 @@ impl<'de> IDLDeserialize<'de> { } pub struct Config { - zero_sized_values: usize, + decoding_cost: Option, full_error_message: bool, } impl Config { pub fn new() -> Self { Self { - zero_sized_values: 2_000_000, + decoding_cost: Some(DEFAULT_DECODING_COST), full_error_message: true, } } - pub fn set_zero_sized_values(&mut self, n: usize) -> &mut Self { - self.zero_sized_values = n; + pub fn set_decoding_cost(&mut self, n: Option) -> &mut Self { + self.decoding_cost = n; self } pub fn set_full_error_message(&mut self, n: bool) -> &mut Self { @@ -236,7 +237,7 @@ struct Deserializer<'de> { // Indicates whether to deserialize with IDLValue. // It only affects the field id generation in enum type. is_untyped: bool, - zero_sized_values: usize, + decoding_cost: Option, full_error_message: bool, #[cfg(not(target_arch = "wasm32"))] recursion_depth: u16, @@ -256,10 +257,7 @@ impl<'de> Deserializer<'de> { gamma: Gamma::default(), field_name: None, is_untyped: false, - #[cfg(not(target_arch = "wasm32"))] - zero_sized_values: 2_000_000, - #[cfg(target_arch = "wasm32")] - zero_sized_values: 0, + decoding_cost: Some(DEFAULT_DECODING_COST), #[cfg(not(target_arch = "wasm32"))] full_error_message: true, #[cfg(target_arch = "wasm32")] @@ -299,6 +297,7 @@ impl<'de> Deserializer<'de> { Ok(res) } fn check_subtype(&mut self) -> Result<()> { + self.add_cost(self.table.0.len())?; subtype_with_config( OptReport::Silence, &mut self.gamma, @@ -327,26 +326,26 @@ impl<'de> Deserializer<'de> { self.expect_type.as_ref(), TypeInner::Var(_) | TypeInner::Knot(_) ) { + self.add_cost(1)?; self.expect_type = self.table.trace_type(&self.expect_type)?; } if matches!( self.wire_type.as_ref(), TypeInner::Var(_) | TypeInner::Knot(_) ) { + self.add_cost(1)?; self.wire_type = self.table.trace_type(&self.wire_type)?; } Ok(()) } - fn is_zero_sized_type(&self, t: &Type) -> bool { - match t.as_ref() { - TypeInner::Null | TypeInner::Reserved => true, - TypeInner::Record(fs) => fs.iter().all(|f| { - let t = self.table.trace_type(&f.ty).unwrap(); - // recursive records have been replaced with empty already, it's safe to call without memoization. - self.is_zero_sized_type(&t) - }), - _ => false, + fn add_cost(&mut self, cost: usize) -> Result<()> { + if let Some(n) = self.decoding_cost { + if n < cost { + return Err(Error::msg("Decoding cost exceeds the limit")); + } + self.decoding_cost = Some(n - cost); } + Ok(()) } // Should always call set_field_name to set the field_name. After deserialize_identifier // processed the field_name, field_name will be reset to None. @@ -371,11 +370,13 @@ impl<'de> Deserializer<'de> { self.unroll_type()?; assert!(*self.expect_type == TypeInner::Int); let mut bytes = vec![0u8]; + let pos = self.input.position(); let int = match self.wire_type.as_ref() { TypeInner::Int => Int::decode(&mut self.input).map_err(Error::msg)?, TypeInner::Nat => Int(Nat::decode(&mut self.input).map_err(Error::msg)?.0.into()), t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))), }; + self.add_cost((self.input.position() - pos) as usize)?; bytes.extend_from_slice(&int.0.to_signed_bytes_le()); visitor.visit_byte_buf(bytes) } @@ -391,7 +392,9 @@ impl<'de> Deserializer<'de> { "nat" ); let mut bytes = vec![1u8]; + let pos = self.input.position(); let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; + self.add_cost((self.input.position() - pos) as usize)?; bytes.extend_from_slice(&nat.0.to_bytes_le()); visitor.visit_byte_buf(bytes) } @@ -405,14 +408,16 @@ impl<'de> Deserializer<'de> { "principal" ); let mut bytes = vec![2u8]; - let id = PrincipalBytes::read(&mut self.input)?.inner; - bytes.extend_from_slice(&id); + let id = PrincipalBytes::read(&mut self.input)?; + self.add_cost(id.len as usize + 1)?; + bytes.extend_from_slice(&id.inner); visitor.visit_byte_buf(bytes) } fn deserialize_reserved<'a, V>(&'a mut self, visitor: V) -> Result where V: Visitor<'de>, { + self.add_cost(1)?; let bytes = vec![3u8]; visitor.visit_byte_buf(bytes) } @@ -423,8 +428,9 @@ impl<'de> Deserializer<'de> { self.unroll_type()?; self.check_subtype()?; let mut bytes = vec![4u8]; - let id = PrincipalBytes::read(&mut self.input)?.inner; - bytes.extend_from_slice(&id); + let id = PrincipalBytes::read(&mut self.input)?; + self.add_cost(id.len as usize + 1)?; + bytes.extend_from_slice(&id.inner); visitor.visit_byte_buf(bytes) } fn deserialize_function<'a, V>(&'a mut self, visitor: V) -> Result @@ -437,13 +443,14 @@ impl<'de> Deserializer<'de> { return Err(Error::msg("Opaque reference not supported")); } let mut bytes = vec![5u8]; - let id = PrincipalBytes::read(&mut self.input)?.inner; + let id = PrincipalBytes::read(&mut self.input)?; let len = Len::read(&mut self.input)?.0; let meth = self.borrow_bytes(len)?; + self.add_cost(id.len as usize + len + 3)?; // TODO find a better way leb128::write::unsigned(&mut bytes, len as u64)?; bytes.extend_from_slice(meth); - bytes.extend_from_slice(&id); + bytes.extend_from_slice(&id.inner); visitor.visit_byte_buf(bytes) } fn deserialize_blob<'a, V>(&'a mut self, visitor: V) -> Result @@ -456,6 +463,7 @@ impl<'de> Deserializer<'de> { "blob" ); let len = Len::read(&mut self.input)?.0; + self.add_cost(len + 1)?; let blob = self.borrow_bytes(len)?; let mut bytes = Vec::with_capacity(len + 1); bytes.push(6u8); @@ -477,6 +485,7 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { let len = Len::read(&mut self.input)?.0 as u64; + self.add_cost(len as usize + 1)?; Len::read(&mut self.input)?; let slice_len = self.input.get_ref().len() as u64; let pos = self.input.position(); @@ -516,6 +525,9 @@ impl<'de> Deserializer<'de> { Ok(v) } Err(Error::Subtype(_)) => { + // Remember the backtracking cost + self.decoding_cost = self_clone.decoding_cost; + self.add_cost(10)?; self.deserialize_ignored_any(serde::de::IgnoredAny)?; visitor.visit_none() } @@ -525,12 +537,13 @@ impl<'de> Deserializer<'de> { } macro_rules! primitive_impl { - ($ty:ident, $type:expr, $($value:tt)*) => { + ($ty:ident, $type:expr, $cost:literal, $($value:tt)*) => { paste::item! { fn [](self, visitor: V) -> Result where V: Visitor<'de> { self.unroll_type()?; check!(*self.expect_type == $type && *self.wire_type == $type, stringify!($type)); + self.add_cost($cost)?; let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?; visitor.[](val) } @@ -604,16 +617,16 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { v } - primitive_impl!(i8, TypeInner::Int8, read_i8); - primitive_impl!(i16, TypeInner::Int16, read_i16::); - primitive_impl!(i32, TypeInner::Int32, read_i32::); - primitive_impl!(i64, TypeInner::Int64, read_i64::); - primitive_impl!(u8, TypeInner::Nat8, read_u8); - primitive_impl!(u16, TypeInner::Nat16, read_u16::); - primitive_impl!(u32, TypeInner::Nat32, read_u32::); - primitive_impl!(u64, TypeInner::Nat64, read_u64::); - primitive_impl!(f32, TypeInner::Float32, read_f32::); - primitive_impl!(f64, TypeInner::Float64, read_f64::); + primitive_impl!(i8, TypeInner::Int8, 1, read_i8); + primitive_impl!(i16, TypeInner::Int16, 2, read_i16::); + primitive_impl!(i32, TypeInner::Int32, 4, read_i32::); + primitive_impl!(i64, TypeInner::Int64, 8, read_i64::); + primitive_impl!(u8, TypeInner::Nat8, 1, read_u8); + primitive_impl!(u16, TypeInner::Nat16, 2, read_u16::); + primitive_impl!(u32, TypeInner::Nat32, 4, read_u32::); + primitive_impl!(u64, TypeInner::Nat64, 8, read_u64::); + primitive_impl!(f32, TypeInner::Float32, 4, read_f32::); + primitive_impl!(f64, TypeInner::Float64, 8, read_f64::); fn is_human_readable(&self) -> bool { false @@ -625,12 +638,14 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { use crate::types::leb128::{decode_int, decode_nat}; self.unroll_type()?; assert!(*self.expect_type == TypeInner::Int); + let pos = self.input.position(); let value: i128 = match self.wire_type.as_ref() { TypeInner::Int => decode_int(&mut self.input)?, TypeInner::Nat => i128::try_from(decode_nat(&mut self.input)?) .map_err(|_| Error::msg("Cannot convert nat to i128"))?, t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))), }; + self.add_cost((self.input.position() - pos) as usize)?; visitor.visit_i128(value) } fn deserialize_u128(self, visitor: V) -> Result @@ -642,7 +657,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat, "nat" ); + let pos = self.input.position(); let value = crate::types::leb128::decode_nat(&mut self.input)?; + self.add_cost((self.input.position() - pos) as usize)?; visitor.visit_u128(value) } fn deserialize_unit(self, visitor: V) -> Result @@ -655,6 +672,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { && matches!(*self.wire_type, TypeInner::Null | TypeInner::Reserved), "unit" ); + self.add_cost(1)?; visitor.visit_unit() } fn deserialize_bool(self, visitor: V) -> Result @@ -666,6 +684,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { *self.expect_type == TypeInner::Bool && *self.wire_type == TypeInner::Bool, "bool" ); + self.add_cost(1)?; let res = BoolValue::read(&mut self.input)?; visitor.visit_bool(res.0) } @@ -679,6 +698,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { "text" ); let len = Len::read(&mut self.input)?.0; + self.add_cost(len + 1)?; let bytes = self.borrow_bytes(len)?.to_owned(); let value = String::from_utf8(bytes).map_err(Error::msg)?; visitor.visit_string(value) @@ -693,6 +713,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { "text" ); let len = Len::read(&mut self.input)?.0; + self.add_cost(len + 1)?; let slice = self.borrow_bytes(len)?; let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?; visitor.visit_borrowed_str(value) @@ -701,12 +722,14 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { + self.add_cost(1)?; self.deserialize_unit(visitor) } fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { + self.add_cost(1)?; visitor.visit_newtype_struct(self) } fn deserialize_option(self, visitor: V) -> Result @@ -714,6 +737,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { self.unroll_type()?; + self.add_cost(1)?; match (self.wire_type.as_ref(), self.expect_type.as_ref()) { (TypeInner::Null | TypeInner::Reserved, TypeInner::Opt(_)) => visitor.visit_none(), (TypeInner::Opt(t1), TypeInner::Opt(t2)) => { @@ -750,17 +774,12 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { check_recursion! { self.unroll_type()?; + self.add_cost(1)?; match (self.expect_type.as_ref(), self.wire_type.as_ref()) { (TypeInner::Vec(e), TypeInner::Vec(w)) => { let expect = e.clone(); let wire = self.table.trace_type(w)?; let len = Len::read(&mut self.input)?.0; - if self.is_zero_sized_type(&wire) { - if self.zero_sized_values < len { - return Err(Error::msg("vec length of zero sized values too large")); - } - self.zero_sized_values -= len; - } visitor.visit_seq(Compound::new(self, Style::Vector { len, expect, wire })) } (TypeInner::Record(e), TypeInner::Record(w)) => { @@ -789,6 +808,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { "vec nat8" ); let len = Len::read(&mut self.input)?.0; + self.add_cost(len + 1)?; let bytes = self.borrow_bytes(len)?.to_owned(); visitor.visit_byte_buf(bytes) } @@ -798,6 +818,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { TypeInner::Principal => self.deserialize_principal(visitor), TypeInner::Vec(t) if **t == TypeInner::Nat8 => { let len = Len::read(&mut self.input)?.0; + self.add_cost(len + 1)?; let slice = self.borrow_bytes(len)?; visitor.visit_borrowed_bytes(slice) } @@ -810,6 +831,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { check_recursion! { self.unroll_type()?; + self.add_cost(1)?; match (self.expect_type.as_ref(), self.wire_type.as_ref()) { (TypeInner::Vec(e), TypeInner::Vec(w)) => { let e = self.table.trace_type(e)?; @@ -848,6 +870,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { check_recursion! { + self.add_cost(1)?; self.deserialize_seq(visitor) } } @@ -861,6 +884,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { V: Visitor<'de>, { check_recursion! { + self.add_cost(1)?; self.deserialize_seq(visitor) } } @@ -875,6 +899,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { check_recursion! { self.unroll_type()?; + self.add_cost(1)?; match (self.expect_type.as_ref(), self.wire_type.as_ref()) { (TypeInner::Record(e), TypeInner::Record(w)) => { let expect = e.clone().into(); @@ -898,6 +923,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { check_recursion! { self.unroll_type()?; + self.add_cost(1)?; match (self.expect_type.as_ref(), self.wire_type.as_ref()) { (TypeInner::Variant(e), TypeInner::Variant(w)) => { let index = Len::read(&mut self.input)?.0; @@ -926,8 +952,14 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { match self.field_name.take() { Some(l) => match l.as_ref() { - Label::Named(name) => visitor.visit_string(name.to_string()), - Label::Id(hash) | Label::Unnamed(hash) => visitor.visit_u32(*hash), + Label::Named(name) => { + self.add_cost(name.len())?; + visitor.visit_string(name.to_string()) + } + Label::Id(hash) | Label::Unnamed(hash) => { + self.add_cost(4)?; + visitor.visit_u32(*hash) + } }, None => assert!(false), } @@ -978,6 +1010,7 @@ impl<'de, 'a> de::SeqAccess<'de> for Compound<'a, 'de> { where T: de::DeserializeSeed<'de>, { + self.de.add_cost(3)?; match self.style { Style::Vector { ref mut len, @@ -1020,6 +1053,7 @@ impl<'de, 'a> de::MapAccess<'de> for Compound<'a, 'de> { where K: de::DeserializeSeed<'de>, { + self.de.add_cost(4)?; match self.style { Style::Struct { ref mut expect, @@ -1093,11 +1127,15 @@ impl<'de, 'a> de::MapAccess<'de> for Compound<'a, 'de> { { match &self.style { Style::Map { expect, wire, .. } => { + self.de.add_cost(3)?; self.de.expect_type = expect.1.clone(); self.de.wire_type = wire.1.clone(); seed.deserialize(&mut *self.de) } - _ => seed.deserialize(&mut *self.de), + _ => { + self.de.add_cost(1)?; + seed.deserialize(&mut *self.de) + } } } } @@ -1110,6 +1148,7 @@ impl<'de, 'a> de::EnumAccess<'de> for Compound<'a, 'de> { where V: de::DeserializeSeed<'de>, { + self.de.add_cost(4)?; match &self.style { Style::Enum { expect, wire } => { self.de.expect_type = expect.ty.clone(); @@ -1143,6 +1182,7 @@ impl<'de, 'a> de::VariantAccess<'de> for Compound<'a, 'de> { *self.de.expect_type == TypeInner::Null && *self.de.wire_type == TypeInner::Null, "unit_variant" ); + self.de.add_cost(1)?; Ok(()) } @@ -1150,6 +1190,7 @@ impl<'de, 'a> de::VariantAccess<'de> for Compound<'a, 'de> { where T: de::DeserializeSeed<'de>, { + self.de.add_cost(1)?; seed.deserialize(self.de) } @@ -1157,6 +1198,7 @@ impl<'de, 'a> de::VariantAccess<'de> for Compound<'a, 'de> { where V: Visitor<'de>, { + self.de.add_cost(1)?; de::Deserializer::deserialize_tuple(self.de, len, visitor) } @@ -1164,6 +1206,7 @@ impl<'de, 'a> de::VariantAccess<'de> for Compound<'a, 'de> { where V: Visitor<'de>, { + self.de.add_cost(1)?; de::Deserializer::deserialize_struct(self.de, "_", fields, visitor) } } diff --git a/rust/candid/tests/serde.rs b/rust/candid/tests/serde.rs index e33bcf59..9c6380c7 100644 --- a/rust/candid/tests/serde.rs +++ b/rust/candid/tests/serde.rs @@ -531,7 +531,7 @@ fn test_vector() { let bytes = hex("4449444c036c01d6fca702016d026c00010080ade204"); check_error( || test_decode(&bytes, &candid::Reserved), - "zero sized values too large", + "Decoding cost exceeds the limit", ); }