Skip to content

Commit

Permalink
add skipping limit
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyan-dfinity committed Feb 22, 2024
1 parent 3279530 commit 6146965
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 66 deletions.
21 changes: 12 additions & 9 deletions rust/bench/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod nns;

const N: usize = 2097152;
const COST: usize = 20_000_000;
const SKIP: usize = 10_000;

#[bench(raw)]
fn blob() -> BenchResult {
Expand All @@ -19,7 +20,7 @@ fn blob() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, ByteBuf).unwrap();
Decode!([COST; SKIP]; &bytes, ByteBuf).unwrap();
}
})
}
Expand All @@ -35,7 +36,7 @@ fn text() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, String).unwrap();
Decode!([COST; SKIP]; &bytes, String).unwrap();
}
})
}
Expand All @@ -50,7 +51,7 @@ fn vec_int16() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, Vec<i16>).unwrap();
Decode!([COST; SKIP]; &bytes, Vec<i16>).unwrap();
}
})
}
Expand All @@ -68,7 +69,7 @@ fn btreemap() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, BTreeMap<String, Nat>).unwrap();
Decode!([COST; SKIP]; &bytes, BTreeMap<String, Nat>).unwrap();
}
})
}
Expand All @@ -94,7 +95,7 @@ fn option_list() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, Option<Box<List>>).unwrap();
Decode!([COST; SKIP]; &bytes, Option<Box<List>>).unwrap();
}
})
}
Expand All @@ -117,7 +118,7 @@ fn variant_list() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, VariantList).unwrap();
Decode!([COST; SKIP]; &bytes, VariantList).unwrap();
}
})
}
Expand Down Expand Up @@ -208,16 +209,18 @@ fn nns() -> BenchResult {
};
{
let _p = bench_scope("2. Decoding");
Decode!([COST]; &bytes, nns::ManageNeuron).unwrap();
Decode!([COST; SKIP]; &bytes, nns::ManageNeuron).unwrap();
}
})
}

#[bench(raw)]
fn extra_args() -> BenchResult {
let bytes = hex::decode("4449444c036c01d6fca702016d026c00010080ade204").unwrap();
let vec_null = hex::decode("4449444c036c01d6fca702016d026c00010080ade204").unwrap();
let vec_opt_record = hex::decode("4449444c176c02017f027f6c02010002006c02000101016c02000201026c02000301036c02000401046c02000501056c02000601066c02000701076c02000801086c02000901096c02000a010a6c02000b010b6c02000c010c6c02000d020d6c02000e010e6c02000f010f6c02001001106c02001101116c02001201126c02001301136e146d150116050101010101").unwrap();
bench_fn(|| {
let _ = Decode!([COST]; &bytes);
//assert!(Decode!([COST; SKIP]; &vec_null).is_err());
assert!(Decode!([COST; SKIP]; &vec_opt_record).is_err());
})
}

Expand Down
78 changes: 38 additions & 40 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use std::fmt::Write;
use std::{collections::VecDeque, io::Cursor, mem::replace};

const MAX_TYPE_LEN: i32 = 500;
const DEFAULT_DECODING_COST: usize = 20_000_000;

/// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message.
pub struct IDLDeserialize<'de> {
Expand All @@ -29,27 +28,19 @@ pub struct IDLDeserialize<'de> {
impl<'de> IDLDeserialize<'de> {
/// Create a new deserializer with IDL binary message.
pub fn new(bytes: &'de [u8]) -> Result<Self> {
let mut de = Deserializer::from_bytes(bytes).with_context(|| {
if bytes.len() <= 500 {
format!("Cannot parse header {}", &hex::encode(bytes))
} else {
"Cannot parse header".to_string()
}
})?;
de.add_cost(de.input.position() as usize * 2)?;
Ok(IDLDeserialize { de })
let config = DecoderConfig::new();
Self::new_with_config(bytes, config)
}
/// Create a new deserializer with IDL binary message. The config is used to adjust some parameters in the deserializer.
pub fn new_with_config(bytes: &'de [u8], config: DecoderConfig) -> Result<Self> {
let mut de = Deserializer::from_bytes(bytes).with_context(|| {
if config.full_error_message || bytes.len() <= 500 {
let full_error_message = config.full_error_message;
let mut de = Deserializer::from_bytes(bytes, config).with_context(|| {
if full_error_message || bytes.len() <= 500 {
format!("Cannot parse header {}", &hex::encode(bytes))
} else {
"Cannot parse header".to_string()
}
})?;
de.decoding_cost = config.decoding_cost;
de.full_error_message = config.full_error_message;
de.add_cost(de.input.position() as usize * 2)?;
Ok(IDLDeserialize { de })
}
Expand Down Expand Up @@ -86,7 +77,8 @@ impl<'de> IDLDeserialize<'de> {
self.de.expect_type = expected_type;
self.de.wire_type = TypeInner::Reserved.into();
return T::deserialize(&mut self.de);
} else if self.de.full_error_message || text_size(&expected_type, MAX_TYPE_LEN).is_ok()
} else if self.de.config.full_error_message
|| text_size(&expected_type, MAX_TYPE_LEN).is_ok()
{
return Err(Error::msg(format!(
"No more values on the wire, the expected type {expected_type} is not opt, null, or reserved"
Expand All @@ -106,7 +98,7 @@ impl<'de> IDLDeserialize<'de> {
self.de.wire_type = ty.clone();

let mut v = T::deserialize(&mut self.de).with_context(|| {
if self.de.full_error_message
if self.de.config.full_error_message
|| (text_size(&ty, MAX_TYPE_LEN).is_ok()
&& text_size(&expected_type, MAX_TYPE_LEN).is_ok())
{
Expand All @@ -115,7 +107,7 @@ impl<'de> IDLDeserialize<'de> {
format!("Fail to decode argument {ind}")
}
});
if self.de.full_error_message {
if self.de.config.full_error_message {
v = v.with_context(|| self.de.dump_state());
}
Ok(v?)
Expand All @@ -132,7 +124,7 @@ impl<'de> IDLDeserialize<'de> {
let ind = self.de.input.position() as usize;
let rest = &self.de.input.get_ref()[ind..];
if !rest.is_empty() {
if !self.de.full_error_message {
if !self.de.config.full_error_message {
return Err(Error::msg("Trailing value after finishing deserialization"));
} else {
return Err(anyhow!(self.de.dump_state()))
Expand All @@ -143,25 +135,29 @@ impl<'de> IDLDeserialize<'de> {
}
}

/// Tunable limits and error-reporting options handed to the deserializer.
#[derive(Clone)]
pub struct DecoderConfig {
decoding_cost: Option<usize>,
/// Remaining decoding budget; `None` means no decoding limit is enforced.
decoding_quota: Option<usize>,
/// Remaining budget that is charged only while the deserializer is in untyped mode;
/// `None` means no skipping limit is enforced.
skipping_quota: Option<usize>,
/// When `true`, errors carry full detail (e.g. the hex of the input or a decoder state dump)
/// instead of the short message.
full_error_message: bool,
}
impl DecoderConfig {
pub fn new() -> Self {
Self {
decoding_cost: Some(DEFAULT_DECODING_COST),
decoding_quota: None,
skipping_quota: None,
#[cfg(not(target_arch = "wasm32"))]
full_error_message: true,
#[cfg(target_arch = "wasm32")]
full_error_message: false,
}
}
pub fn new_cost(cost: usize) -> Self {
Self {
decoding_cost: Some(cost),
full_error_message: false,
}
pub fn set_decoding_quota(&mut self, n: usize) -> &mut Self {
self.decoding_quota = Some(n);
self
}
pub fn set_decoding_cost(&mut self, n: Option<usize>) -> &mut Self {
self.decoding_cost = n;
pub fn set_skipping_quota(&mut self, n: usize) -> &mut Self {
self.skipping_quota = Some(n);
self
}
pub fn set_full_error_message(&mut self, n: bool) -> &mut Self {
Expand Down Expand Up @@ -245,14 +241,13 @@ struct Deserializer<'de> {
// Indicates whether to deserialize with IDLValue.
// It only affects the field id generation in enum type.
is_untyped: bool,
decoding_cost: Option<usize>,
full_error_message: bool,
config: DecoderConfig,
#[cfg(not(target_arch = "wasm32"))]
recursion_depth: u16,
}

impl<'de> Deserializer<'de> {
fn from_bytes(bytes: &'de [u8]) -> Result<Self> {
fn from_bytes(bytes: &'de [u8], config: DecoderConfig) -> Result<Self> {
let mut reader = Cursor::new(bytes);
let header = Header::read(&mut reader)?;
let (env, types) = header.to_types()?;
Expand All @@ -265,11 +260,7 @@ impl<'de> Deserializer<'de> {
gamma: Gamma::default(),
field_name: None,
is_untyped: false,
decoding_cost: None,
#[cfg(not(target_arch = "wasm32"))]
full_error_message: true,
#[cfg(target_arch = "wasm32")]
full_error_message: false,
config,
#[cfg(not(target_arch = "wasm32"))]
recursion_depth: 0,
})
Expand Down Expand Up @@ -314,7 +305,7 @@ impl<'de> Deserializer<'de> {
&self.expect_type,
)
.with_context(|| {
if self.full_error_message
if self.config.full_error_message
|| (text_size(&self.wire_type, MAX_TYPE_LEN).is_ok()
&& text_size(&self.expect_type, MAX_TYPE_LEN).is_ok())
{
Expand Down Expand Up @@ -347,13 +338,20 @@ impl<'de> Deserializer<'de> {
Ok(())
}
/// Charges `cost` against the configured quotas and returns an error as soon as a quota
/// would be exceeded. A quota that is `None` is not enforced.
fn add_cost(&mut self, cost: usize) -> Result<()> {
if let Some(n) = self.decoding_cost {
// Charge 50x the cost against the decoding quota when `is_untyped` is set
// (untyped decoding; presumably also set while skipping values -- confirm)
if let Some(n) = self.config.decoding_quota {
let cost = if self.is_untyped { cost * 50 } else { cost };
if n < cost {
return Err(Error::msg("Decoding cost exceeds the limit"));
}
self.decoding_cost = Some(n - cost);
self.config.decoding_quota = Some(n - cost);
}
// In untyped mode the unscaled `cost` is additionally charged against the skipping quota.
if self.is_untyped {
if let Some(n) = self.config.skipping_quota {
if n < cost {
return Err(Error::msg("Skipping cost exceeds the limit"));
}
self.config.skipping_quota = Some(n - cost);
}
}
Ok(())
}
Expand Down Expand Up @@ -536,7 +534,7 @@ impl<'de> Deserializer<'de> {
}
Err(Error::Subtype(_)) => {
// Remember the backtracking cost
self.decoding_cost = self_clone.decoding_cost;
self.config = self_clone.config;
self.add_cost(10)?;
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
Expand Down
6 changes: 4 additions & 2 deletions rust/candid/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ macro_rules! Decode {
.and_then(|mut de| Decode!(@GetValue [] de $($ty,)*)
.and_then(|res| de.done().and(Ok(res))))
}};
( [ $cost:expr ] ; $hex:expr $(,$ty:ty)* ) => {{
$crate::de::IDLDeserialize::new_with_config($hex, $crate::de::DecoderConfig::new_cost($cost))
( [ $cost:expr; $skip:expr ] ; $hex:expr $(,$ty:ty)* ) => {{
let mut config = $crate::de::DecoderConfig::new();
config.set_decoding_quota($cost).set_skipping_quota($skip);
$crate::de::IDLDeserialize::new_with_config($hex, config)
.and_then(|mut de| Decode!(@GetValue [] de $($ty,)*)
.and_then(|res| de.done().and(Ok(res))))
}};
Expand Down
7 changes: 5 additions & 2 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,11 @@ where
T: PartialEq + serde::de::Deserialize<'de> + std::fmt::Debug + CandidType,
{
let cost = 20_000_000;
let decoded_one = decode_one_with_config::<T>(bytes, DecoderConfig::new_cost(cost)).unwrap();
let decoded_macro = Decode!([cost]; bytes, T).unwrap();
let skip = 10_000;
let mut config = DecoderConfig::new();
config.set_decoding_quota(cost).set_skipping_quota(skip);
let decoded_one = decode_one_with_config::<T>(bytes, config).unwrap();
let decoded_macro = Decode!([cost; skip]; bytes, T).unwrap();
assert_eq!(decoded_one, *expected);
assert_eq!(decoded_macro, *expected);
}
Expand Down
25 changes: 12 additions & 13 deletions rust/candid_parser/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ impl Input {
pub fn parse(&self, env: &TypeEnv, types: &[Type]) -> Result<IDLArgs> {
match self {
Input::Text(ref s) => Ok(super::parse_idl_args(s)?.annotate_types(true, env, types)?),
Input::Blob(ref bytes) => Ok(IDLArgs::from_bytes_with_types_with_config(
bytes,
env,
types,
DecoderConfig::new_cost(DECODING_COST),
)?),
Input::Blob(ref bytes) => {
let mut config = DecoderConfig::new();
config.set_decoding_quota(DECODING_COST);
Ok(IDLArgs::from_bytes_with_types_with_config(
bytes, env, types, config,
)?)
}
}
}
fn check_round_trip(&self, v: &IDLArgs, env: &TypeEnv, types: &[Type]) -> Result<bool> {
Expand Down Expand Up @@ -113,13 +114,11 @@ impl HostTest {
if !assert.pass && assert.right.is_none() {
asserts.push(NotDecode(bytes, types));
} else {
let args = IDLArgs::from_bytes_with_types_with_config(
&bytes,
env,
&types,
DecoderConfig::new_cost(DECODING_COST),
)
.unwrap();
let mut config = DecoderConfig::new();
config.set_decoding_quota(DECODING_COST);
let args =
IDLArgs::from_bytes_with_types_with_config(&bytes, env, &types, config)
.unwrap();
asserts.push(Decode(bytes.clone(), types.clone(), true, args));
// round tripping
// asserts.push(Encode(args, types.clone(), true, bytes.clone()));
Expand Down

0 comments on commit 6146965

Please sign in to comment.