diff --git a/rust/bench/bench.rs b/rust/bench/bench.rs index c32a4097..d2985f67 100644 --- a/rust/bench/bench.rs +++ b/rust/bench/bench.rs @@ -7,6 +7,7 @@ mod nns; const N: usize = 2097152; const COST: usize = 20_000_000; +const SKIP: usize = 10_000; #[bench(raw)] fn blob() -> BenchResult { @@ -19,7 +20,7 @@ fn blob() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, ByteBuf).unwrap(); + Decode!([COST; SKIP]; &bytes, ByteBuf).unwrap(); } }) } @@ -35,7 +36,7 @@ fn text() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, String).unwrap(); + Decode!([COST; SKIP]; &bytes, String).unwrap(); } }) } @@ -50,7 +51,7 @@ fn vec_int16() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, Vec).unwrap(); + Decode!([COST; SKIP]; &bytes, Vec).unwrap(); } }) } @@ -68,7 +69,7 @@ fn btreemap() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, BTreeMap).unwrap(); + Decode!([COST; SKIP]; &bytes, BTreeMap).unwrap(); } }) } @@ -94,7 +95,7 @@ fn option_list() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, Option>).unwrap(); + Decode!([COST; SKIP]; &bytes, Option>).unwrap(); } }) } @@ -117,7 +118,7 @@ fn variant_list() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, VariantList).unwrap(); + Decode!([COST; SKIP]; &bytes, VariantList).unwrap(); } }) } @@ -208,16 +209,18 @@ fn nns() -> BenchResult { }; { let _p = bench_scope("2. Decoding"); - Decode!([COST]; &bytes, nns::ManageNeuron).unwrap(); + Decode!([COST; SKIP]; &bytes, nns::ManageNeuron).unwrap(); } }) } #[bench(raw)] fn extra_args() -> BenchResult { - let bytes = hex::decode("4449444c036c01d6fca702016d026c00010080ade204").unwrap(); + let vec_null = hex::decode("4449444c036c01d6fca702016d026c00010080ade204").unwrap(); + let vec_opt_record = hex::decode("4449444c176c02017f027f6c02010002006c02000101016c02000201026c02000301036c02000401046c02000501056c02000601066c02000701076c02000801086c02000901096c02000a010a6c02000b010b6c02000c010c6c02000d020d6c02000e010e6c02000f010f6c02001001106c02001101116c02001201126c02001301136e146d150116050101010101").unwrap(); bench_fn(|| { - let _ = Decode!([COST]; &bytes); + //assert!(Decode!([COST; SKIP]; &vec_null).is_err()); + assert!(Decode!([COST; SKIP]; &vec_opt_record).is_err()); }) } diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 110e14ed..ed6ee7bd 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -20,7 +20,6 @@ use std::fmt::Write; use std::{collections::VecDeque, io::Cursor, mem::replace}; const MAX_TYPE_LEN: i32 = 500; -const DEFAULT_DECODING_COST: usize = 20_000_000; /// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message. pub struct IDLDeserialize<'de> { @@ -29,27 +28,19 @@ pub struct IDLDeserialize<'de> { impl<'de> IDLDeserialize<'de> { /// Create a new deserializer with IDL binary message. pub fn new(bytes: &'de [u8]) -> Result { - let mut de = Deserializer::from_bytes(bytes).with_context(|| { - if bytes.len() <= 500 { - format!("Cannot parse header {}", &hex::encode(bytes)) - } else { - "Cannot parse header".to_string() - } - })?; - de.add_cost(de.input.position() as usize * 2)?; - Ok(IDLDeserialize { de }) + let config = DecoderConfig::new(); + Self::new_with_config(bytes, config) } /// Create a new deserializer with IDL binary message. The config is used to adjust some parameters in the deserializer. pub fn new_with_config(bytes: &'de [u8], config: DecoderConfig) -> Result { - let mut de = Deserializer::from_bytes(bytes).with_context(|| { - if config.full_error_message || bytes.len() <= 500 { + let full_error_message = config.full_error_message; + let mut de = Deserializer::from_bytes(bytes, config).with_context(|| { + if full_error_message || bytes.len() <= 500 { format!("Cannot parse header {}", &hex::encode(bytes)) } else { "Cannot parse header".to_string() } })?; - de.decoding_cost = config.decoding_cost; - de.full_error_message = config.full_error_message; de.add_cost(de.input.position() as usize * 2)?; Ok(IDLDeserialize { de }) } @@ -86,7 +77,8 @@ impl<'de> IDLDeserialize<'de> { self.de.expect_type = expected_type; self.de.wire_type = TypeInner::Reserved.into(); return T::deserialize(&mut self.de); - } else if self.de.full_error_message || text_size(&expected_type, MAX_TYPE_LEN).is_ok() + } else if self.de.config.full_error_message + || text_size(&expected_type, MAX_TYPE_LEN).is_ok() { return Err(Error::msg(format!( "No more values on the wire, the expected type {expected_type} is not opt, null, or reserved" @@ -106,7 +98,7 @@ impl<'de> IDLDeserialize<'de> { self.de.wire_type = ty.clone(); let mut v = T::deserialize(&mut self.de).with_context(|| { - if self.de.full_error_message + if self.de.config.full_error_message || (text_size(&ty, MAX_TYPE_LEN).is_ok() && text_size(&expected_type, MAX_TYPE_LEN).is_ok()) { @@ -115,7 +107,7 @@ impl<'de> IDLDeserialize<'de> { format!("Fail to decode argument {ind}") } }); - if self.de.full_error_message { + if self.de.config.full_error_message { v = v.with_context(|| self.de.dump_state()); } Ok(v?) @@ -132,7 +124,7 @@ impl<'de> IDLDeserialize<'de> { let ind = self.de.input.position() as usize; let rest = &self.de.input.get_ref()[ind..]; if !rest.is_empty() { - if !self.de.full_error_message { + if !self.de.config.full_error_message { return Err(Error::msg("Trailing value after finishing deserialization")); } else { return Err(anyhow!(self.de.dump_state())) @@ -143,25 +135,29 @@ impl<'de> IDLDeserialize<'de> { } } +#[derive(Clone)] pub struct DecoderConfig { - decoding_cost: Option, + decoding_quota: Option, + skipping_quota: Option, full_error_message: bool, } impl DecoderConfig { pub fn new() -> Self { Self { - decoding_cost: Some(DEFAULT_DECODING_COST), + decoding_quota: None, + skipping_quota: None, + #[cfg(not(target_arch = "wasm32"))] + full_error_message: true, + #[cfg(target_arch = "wasm32")] full_error_message: false, } } - pub fn new_cost(cost: usize) -> Self { - Self { - decoding_cost: Some(cost), - full_error_message: false, - } + pub fn set_decoding_quota(&mut self, n: usize) -> &mut Self { + self.decoding_quota = Some(n); + self } - pub fn set_decoding_cost(&mut self, n: Option) -> &mut Self { - self.decoding_cost = n; + pub fn set_skipping_quota(&mut self, n: usize) -> &mut Self { + self.skipping_quota = Some(n); self } pub fn set_full_error_message(&mut self, n: bool) -> &mut Self { @@ -245,14 +241,13 @@ struct Deserializer<'de> { // Indicates whether to deserialize with IDLValue. // It only affects the field id generation in enum type. is_untyped: bool, - decoding_cost: Option, - full_error_message: bool, + config: DecoderConfig, #[cfg(not(target_arch = "wasm32"))] recursion_depth: u16, } impl<'de> Deserializer<'de> { - fn from_bytes(bytes: &'de [u8]) -> Result { + fn from_bytes(bytes: &'de [u8], config: DecoderConfig) -> Result { let mut reader = Cursor::new(bytes); let header = Header::read(&mut reader)?; let (env, types) = header.to_types()?; @@ -265,11 +260,7 @@ impl<'de> Deserializer<'de> { gamma: Gamma::default(), field_name: None, is_untyped: false, - decoding_cost: None, - #[cfg(not(target_arch = "wasm32"))] - full_error_message: true, - #[cfg(target_arch = "wasm32")] - full_error_message: false, + config, #[cfg(not(target_arch = "wasm32"))] recursion_depth: 0, }) @@ -314,7 +305,7 @@ impl<'de> Deserializer<'de> { &self.expect_type, ) .with_context(|| { - if self.full_error_message + if self.config.full_error_message || (text_size(&self.wire_type, MAX_TYPE_LEN).is_ok() && text_size(&self.expect_type, MAX_TYPE_LEN).is_ok()) { @@ -347,13 +338,20 @@ impl<'de> Deserializer<'de> { Ok(()) } fn add_cost(&mut self, cost: usize) -> Result<()> { - if let Some(n) = self.decoding_cost { - // Double the cost when untyped or skipping values + if let Some(n) = self.config.decoding_quota { let cost = if self.is_untyped { cost * 50 } else { cost }; if n < cost { return Err(Error::msg("Decoding cost exceeds the limit")); } - self.decoding_cost = Some(n - cost); + self.config.decoding_quota = Some(n - cost); + } + if self.is_untyped { + if let Some(n) = self.config.skipping_quota { + if n < cost { + return Err(Error::msg("Skipping cost exceeds the limit")); + } + self.config.skipping_quota = Some(n - cost); + } } Ok(()) } @@ -536,7 +534,7 @@ impl<'de> Deserializer<'de> { } Err(Error::Subtype(_)) => { // Remember the backtracking cost - self.decoding_cost = self_clone.decoding_cost; + self.config = self_clone.config; self.add_cost(10)?; self.deserialize_ignored_any(serde::de::IgnoredAny)?; visitor.visit_none() diff --git a/rust/candid/src/utils.rs b/rust/candid/src/utils.rs index 9ffd5db9..5bdb060a 100644 --- a/rust/candid/src/utils.rs +++ b/rust/candid/src/utils.rs @@ -61,8 +61,10 @@ macro_rules! Decode { .and_then(|mut de| Decode!(@GetValue [] de $($ty,)*) .and_then(|res| de.done().and(Ok(res)))) }}; - ( [ $cost:expr ] ; $hex:expr $(,$ty:ty)* ) => {{ - $crate::de::IDLDeserialize::new_with_config($hex, $crate::de::DecoderConfig::new_cost($cost)) + ( [ $cost:expr; $skip:expr ] ; $hex:expr $(,$ty:ty)* ) => {{ + let mut config = $crate::de::DecoderConfig::new(); + config.set_decoding_quota($cost).set_skipping_quota($skip); + $crate::de::IDLDeserialize::new_with_config($hex, config) .and_then(|mut de| Decode!(@GetValue [] de $($ty,)*) .and_then(|res| de.done().and(Ok(res)))) }}; diff --git a/rust/candid/tests/serde.rs b/rust/candid/tests/serde.rs index d1bb4c50..ac23d327 100644 --- a/rust/candid/tests/serde.rs +++ b/rust/candid/tests/serde.rs @@ -775,8 +775,11 @@ where T: PartialEq + serde::de::Deserialize<'de> + std::fmt::Debug + CandidType, { let cost = 20_000_000; - let decoded_one = decode_one_with_config::(bytes, DecoderConfig::new_cost(cost)).unwrap(); - let decoded_macro = Decode!([cost]; bytes, T).unwrap(); + let skip = 10_000; + let mut config = DecoderConfig::new(); + config.set_decoding_quota(cost).set_skipping_quota(skip); + let decoded_one = decode_one_with_config::(bytes, config).unwrap(); + let decoded_macro = Decode!([cost; skip]; bytes, T).unwrap(); assert_eq!(decoded_one, *expected); assert_eq!(decoded_macro, *expected); } diff --git a/rust/candid_parser/src/test.rs b/rust/candid_parser/src/test.rs index e6729116..8bca2bb7 100644 --- a/rust/candid_parser/src/test.rs +++ b/rust/candid_parser/src/test.rs @@ -54,12 +54,13 @@ impl Input { pub fn parse(&self, env: &TypeEnv, types: &[Type]) -> Result { match self { Input::Text(ref s) => Ok(super::parse_idl_args(s)?.annotate_types(true, env, types)?), - Input::Blob(ref bytes) => Ok(IDLArgs::from_bytes_with_types_with_config( - bytes, - env, - types, - DecoderConfig::new_cost(DECODING_COST), - )?), + Input::Blob(ref bytes) => { + let mut config = DecoderConfig::new(); + config.set_decoding_quota(DECODING_COST); + Ok(IDLArgs::from_bytes_with_types_with_config( + bytes, env, types, config, + )?) + } } } fn check_round_trip(&self, v: &IDLArgs, env: &TypeEnv, types: &[Type]) -> Result { @@ -113,13 +114,11 @@ impl HostTest { if !assert.pass && assert.right.is_none() { asserts.push(NotDecode(bytes, types)); } else { - let args = IDLArgs::from_bytes_with_types_with_config( - &bytes, - env, - &types, - DecoderConfig::new_cost(DECODING_COST), - ) - .unwrap(); + let mut config = DecoderConfig::new(); + config.set_decoding_quota(DECODING_COST); + let args = + IDLArgs::from_bytes_with_types_with_config(&bytes, env, &types, config) + .unwrap(); asserts.push(Decode(bytes.clone(), types.clone(), true, args)); // round tripping // asserts.push(Encode(args, types.clone(), true, bytes.clone()));