Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyan-dfinity committed Feb 4, 2024
1 parent 0498d89 commit a33dcc9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
15 changes: 10 additions & 5 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::fmt::Write;
use std::{collections::VecDeque, io::Cursor, mem::replace};

const MAX_TYPE_LEN: i32 = 500;
const DEFAULT_VALUE_COST: usize = 3_000_000;

/// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message.
pub struct IDLDeserialize<'de> {
Expand Down Expand Up @@ -138,6 +139,10 @@ impl<'de> IDLDeserialize<'de> {
}
Ok(())
}
/// Return the value cost on the wire
pub fn value_cost(&self) -> usize {
DEFAULT_VALUE_COST - self.de.value_cost
}
}

pub struct Config {
Expand All @@ -147,7 +152,7 @@ pub struct Config {
impl Config {
pub fn new() -> Self {
Self {
value_cost: 2_000_000,
value_cost: DEFAULT_VALUE_COST,
full_error_message: true,
}
}
Expand Down Expand Up @@ -256,7 +261,7 @@ impl<'de> Deserializer<'de> {
gamma: Gamma::default(),
field_name: None,
is_untyped: false,
value_cost: 2_000_000,
value_cost: DEFAULT_VALUE_COST,
#[cfg(not(target_arch = "wasm32"))]
full_error_message: true,
#[cfg(target_arch = "wasm32")]
Expand Down Expand Up @@ -440,7 +445,7 @@ impl<'de> Deserializer<'de> {
let id = PrincipalBytes::read(&mut self.input)?;
let len = Len::read(&mut self.input)?.0;
let meth = self.borrow_bytes(len)?;
self.dec_value_cost(id.len as usize + len + 1)?;
self.dec_value_cost(id.len as usize + len + 3)?;
// TODO find a better way
leb128::write::unsigned(&mut bytes, len as u64)?;
bytes.extend_from_slice(meth);
Expand Down Expand Up @@ -804,7 +809,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
"vec nat8"
);
let len = Len::read(&mut self.input)?.0;
self.dec_value_cost(len as usize + 1)?;
self.dec_value_cost(len + 1)?;
let bytes = self.borrow_bytes(len)?.to_owned();
visitor.visit_byte_buf(bytes)
}
Expand All @@ -814,7 +819,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
TypeInner::Principal => self.deserialize_principal(visitor),
TypeInner::Vec(t) if **t == TypeInner::Nat8 => {
let len = Len::read(&mut self.input)?.0;
self.dec_value_cost(len as usize + 1)?;
self.dec_value_cost(len + 1)?;
let slice = self.borrow_bytes(len)?;
visitor.visit_borrowed_bytes(slice)
}
Expand Down
42 changes: 42 additions & 0 deletions rust/candid/src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ impl IDLArgs {
de.done()?;
Ok(IDLArgs { args })
}
pub fn cost(&self) -> usize {
self.args.iter().map(IDLValue::cost).sum::<usize>()
}
}

impl IDLValue {
Expand Down Expand Up @@ -366,6 +369,45 @@ impl IDLValue {
let args = IDLArgs::from_bytes_with_types(&blob, &TypeEnv::default(), &[T::ty()])?;
Ok(args.args[0].clone())
}
/// Return the value cost of the IDLValue. It can only match the cost reported by the deserializer when decoded untyped and the LEB128 encoding is compact.
pub fn cost(&self) -> usize {
match self {
IDLValue::Null => 1,
IDLValue::Bool(_) => 1,
IDLValue::Text(s) => 1 + s.len(),
IDLValue::Int(i) => {
let mut leb = Vec::new();
i.encode(&mut leb).unwrap();
leb.len()
}
IDLValue::Nat(n) => {
let mut leb = Vec::new();
n.encode(&mut leb).unwrap();
leb.len()
}
IDLValue::Nat8(_) => 1,
IDLValue::Nat16(_) => 2,
IDLValue::Nat32(_) => 4,
IDLValue::Nat64(_) => 8,
IDLValue::Int8(_) => 1,
IDLValue::Int16(_) => 2,
IDLValue::Int32(_) => 4,
IDLValue::Int64(_) => 8,
IDLValue::Float32(_) => 4,
IDLValue::Float64(_) => 8,
IDLValue::None => 1,
IDLValue::Opt(v) => 1 + v.cost(),
IDLValue::Vec(vec) => 1 + vec.iter().map(IDLValue::cost).sum::<usize>(),
IDLValue::Record(vec) => 1 + vec.iter().map(|f| 1 + f.val.cost()).sum::<usize>(),
IDLValue::Variant(v) => 1 + v.0.val.cost(),
IDLValue::Blob(blob) => 1 + blob.len(),
IDLValue::Principal(id) => 1 + id.as_slice().len(),
IDLValue::Service(id) => 1 + id.as_slice().len(),
IDLValue::Func(id, meth) => 3 + id.as_slice().len() + meth.len(),
IDLValue::Reserved => 1,
IDLValue::Number(_) => unreachable!(),
}
}
}

impl crate::CandidType for IDLValue {
Expand Down
9 changes: 9 additions & 0 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,15 @@ where
let decoded_macro = Decode!(bytes, T).unwrap();
assert_eq!(decoded_one, *expected);
assert_eq!(decoded_macro, *expected);
#[cfg(feature = "value")]
{
use candid::types::value::IDLValue;
let mut deserializer = candid::de::IDLDeserialize::new(bytes).unwrap();
let decoded = deserializer.get_value::<IDLValue>().unwrap();
let cost = deserializer.value_cost();
let _ = deserializer.done().unwrap();
assert_eq!(cost, decoded.cost());
}
}

fn encode<T: CandidType>(value: &T) -> Vec<u8> {
Expand Down

0 comments on commit a33dcc9

Please sign in to comment.