Skip to content

Commit

Permalink
Decoder improvements (#2259)
Browse files Browse the repository at this point in the history
* Add a benchmark for the decoder

* Make the bench more rigorous

* Dramatic improvements in varint decoding

* Why was I able to compile without importing size_of?

* I kept too many bits

* Improve benchmark performance

* Make it build

* skip rather than decode

Co-authored-by: Lars Eggert <[email protected]>
Signed-off-by: Martin Thomson <[email protected]>

* Comment about not using Decoder

---------

Signed-off-by: Lars Eggert <[email protected]>
Signed-off-by: Martin Thomson <[email protected]>
Co-authored-by: Lars Eggert <[email protected]>
  • Loading branch information
martinthomson and larseggert authored Dec 8, 2024
1 parent a758177 commit dd8e801
Show file tree
Hide file tree
Showing 16 changed files with 118 additions and 94 deletions.
86 changes: 49 additions & 37 deletions neqo-common/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::fmt::Debug;
use std::{fmt::Debug, mem::size_of};

use crate::hex_with_len;

Expand Down Expand Up @@ -54,7 +54,7 @@ impl<'a> Decoder<'a> {
/// Only use this for tests because we panic rather than reporting a result.
#[cfg(any(test, feature = "test-fixture"))]
pub fn skip_vec(&mut self, n: usize) {
let len = self.decode_uint(n);
let len = self.decode_n(n);
self.skip_inner(len);
}

Expand All @@ -66,16 +66,6 @@ impl<'a> Decoder<'a> {
self.skip_inner(len);
}

/// Decodes (reads) a single byte.
pub fn decode_byte(&mut self) -> Option<u8> {
if self.remaining() < 1 {
return None;
}
let b = self.buf[self.offset];
self.offset += 1;
Some(b)
}

/// Provides the next byte without moving the read position.
#[must_use]
pub const fn peek_byte(&self) -> Option<u8> {
Expand All @@ -96,33 +86,43 @@ impl<'a> Decoder<'a> {
Some(res)
}

/// Decodes an unsigned integer of length 1..=8.
///
/// # Panics
///
/// This panics if `n` is not in the range `1..=8`.
pub fn decode_uint(&mut self, n: usize) -> Option<u64> {
assert!(n > 0 && n <= 8);
#[inline]
pub(crate) fn decode_n(&mut self, n: usize) -> Option<u64> {
debug_assert!(n > 0 && n <= 8);
if self.remaining() < n {
return None;
}
let mut v = 0_u64;
for i in 0..n {
let b = self.buf[self.offset + i];
v = v << 8 | u64::from(b);
}
self.offset += n;
Some(v)
Some(if n == 1 {
let v = u64::from(self.buf[self.offset]);
self.offset += 1;
v
} else {
let mut buf = [0; 8];
buf[8 - n..].copy_from_slice(&self.buf[self.offset..self.offset + n]);
self.offset += n;
u64::from_be_bytes(buf)
})
}

/// Decodes a big-endian, unsigned integer value into the target type.
/// This returns `None` if there is not enough data remaining
/// or if the conversion to the identified type fails.
/// Conversion is via `u64`, so failures are impossible for
/// unsigned integer types: `u8`, `u16`, `u32`, or `u64`.
/// Signed types will fail if the high bit is set.
pub fn decode_uint<T: TryFrom<u64>>(&mut self) -> Option<T> {
let v = self.decode_n(size_of::<T>());
v.and_then(|v| T::try_from(v).ok())
}

/// Decodes a QUIC varint.
pub fn decode_varint(&mut self) -> Option<u64> {
let b1 = self.decode_byte()?;
let b1 = self.decode_n(1)?;
match b1 >> 6 {
0 => Some(u64::from(b1 & 0x3f)),
1 => Some((u64::from(b1 & 0x3f) << 8) | self.decode_uint(1)?),
2 => Some((u64::from(b1 & 0x3f) << 24) | self.decode_uint(3)?),
3 => Some((u64::from(b1 & 0x3f) << 56) | self.decode_uint(7)?),
0 => Some(b1),
1 => Some((b1 & 0x3f) << 8 | self.decode_n(1)?),
2 => Some((b1 & 0x3f) << 24 | self.decode_n(3)?),
3 => Some((b1 & 0x3f) << 56 | self.decode_n(7)?),
_ => unreachable!(),
}
}
Expand All @@ -147,7 +147,7 @@ impl<'a> Decoder<'a> {

/// Decodes a TLS-style length-prefixed buffer.
pub fn decode_vec(&mut self, n: usize) -> Option<&'a [u8]> {
let len = self.decode_uint(n);
let len = self.decode_n(n);
self.decode_checked(len)
}

Expand Down Expand Up @@ -486,16 +486,28 @@ mod tests {
let enc = Encoder::from_hex("0123");
let mut dec = enc.as_decoder();

assert_eq!(dec.decode_byte().unwrap(), 0x01);
assert_eq!(dec.decode_byte().unwrap(), 0x23);
assert!(dec.decode_byte().is_none());
assert_eq!(dec.decode_uint::<u8>().unwrap(), 0x01);
assert_eq!(dec.decode_uint::<u8>().unwrap(), 0x23);
assert!(dec.decode_uint::<u8>().is_none());
}

#[test]
fn peek_byte() {
let enc = Encoder::from_hex("01");
let mut dec = enc.as_decoder();

assert_eq!(dec.offset(), 0);
assert_eq!(dec.peek_byte().unwrap(), 0x01);
dec.skip(1);
assert_eq!(dec.offset(), 1);
assert!(dec.peek_byte().is_none());
}

#[test]
fn decode_byte_short() {
let enc = Encoder::from_hex("");
let mut dec = enc.as_decoder();
assert!(dec.decode_byte().is_none());
assert!(dec.decode_uint::<u8>().is_none());
}

#[test]
Expand All @@ -506,7 +518,7 @@ mod tests {
assert!(dec.decode(2).is_none());

let mut dec = Decoder::from(&[]);
assert_eq!(dec.decode_remainder().len(), 0);
assert!(dec.decode_remainder().is_empty());
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions neqo-common/src/incrdecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ impl IncrementalDecoderUint {
if amount < 8 {
self.v <<= amount * 8;
}
self.v |= dv.decode_uint(amount).unwrap();
self.v |= dv.decode_n(amount).unwrap();
*r -= amount;
if *r == 0 {
Some(self.v)
} else {
None
}
} else {
let (v, remaining) = dv.decode_byte().map_or_else(
let (v, remaining) = dv.decode_uint::<u8>().map_or_else(
|| unreachable!(),
|b| {
(
Expand Down
4 changes: 3 additions & 1 deletion neqo-http3/src/frames/hframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ impl HFrame {
Self::PriorityUpdateRequest { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_REQUEST,
Self::PriorityUpdatePush { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_PUSH,
Self::Grease => {
HFrameType(Decoder::from(&random::<7>()).decode_uint(7).unwrap() * 0x1f + 0x21)
let r = Decoder::from(&random::<8>()).decode_uint::<u64>().unwrap();
// Zero out the top 7 bits: 2 for being a varint; 5 to account for the *0x1f.
HFrameType((r >> 7) * 0x1f + 0x21)
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions neqo-http3/src/frames/wtframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ impl FrameDecoder<Self> for WebTransportFrame {
if frame_len > WT_FRAME_CLOSE_MAX_MESSAGE_SIZE + 4 {
return Err(Error::HttpMessageError);
}
let error =
u32::try_from(dec.decode_uint(4).ok_or(Error::HttpMessageError)?).unwrap();
let error = dec.decode_uint().ok_or(Error::HttpMessageError)?;
let Ok(message) = String::from_utf8(dec.decode_remainder().to_vec()) else {
return Err(Error::HttpMessageError);
};
Expand Down
4 changes: 2 additions & 2 deletions neqo-transport/src/addr_valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ impl AddressValidation {
let peer_addr = Self::encode_aad(peer_address, retry);
let data = self.self_encrypt.open(peer_addr.as_ref(), token).ok()?;
let mut dec = Decoder::new(&data);
match dec.decode_uint(4) {
match dec.decode_uint::<u32>() {
Some(d) => {
let end = self.start_time + Duration::from_millis(d);
let end = self.start_time + Duration::from_millis(u64::from(d));
if end < now {
qtrace!("Expired token: {:?} vs. {:?}", end, now);
return None;
Expand Down
7 changes: 4 additions & 3 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,10 @@ impl Connection {
);
let mut dec = Decoder::from(token.as_ref());

let version = Version::try_from(u32::try_from(
dec.decode_uint(4).ok_or(Error::InvalidResumptionToken)?,
)?)?;
let version = Version::try_from(
dec.decode_uint::<WireVersion>()
.ok_or(Error::InvalidResumptionToken)?,
)?;
qtrace!([self], " version {:?}", version);
if !self.conn_params.get_versions().all().contains(&version) {
return Err(Error::DisabledVersion);
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ impl<'a> Frame<'a> {
return Err(Error::FrameEncodingError);
}
let delay = dv(dec)?;
let ignore_order = match d(dec.decode_uint(1))? {
let ignore_order = match d(dec.decode_uint::<u8>())? {
0 => false,
1 => true,
_ => return Err(Error::FrameEncodingError),
Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ impl<'a> PublicPacket<'a> {
#[allow(clippy::similar_names)]
pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
let mut decoder = Decoder::new(data);
let first = Self::opt(decoder.decode_byte())?;
let first = Self::opt(decoder.decode_uint::<u8>())?;

if first & 0x80 == PACKET_BIT_SHORT {
// Conveniently, this also guarantees that there is enough space
Expand Down Expand Up @@ -638,7 +638,7 @@ impl<'a> PublicPacket<'a> {
}

// Generic long header.
let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
let version = Self::opt(decoder.decode_uint())?;
let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);

Expand Down Expand Up @@ -893,7 +893,7 @@ impl<'a> PublicPacket<'a> {
let mut decoder = Decoder::new(&self.data[self.header_len..]);
let mut res = Vec::new();
while decoder.remaining() > 0 {
let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
let version = Self::opt(decoder.decode_uint::<WireVersion>())?;
res.push(version);
}
Ok(res)
Expand Down
24 changes: 14 additions & 10 deletions neqo-transport/src/shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,33 @@ pub fn find_sni(buf: &[u8]) -> Option<Range<usize>> {
}

#[must_use]
fn skip_vec<const N: usize>(dec: &mut Decoder) -> Option<()> {
let len = dec.decode_uint(N)?.try_into().ok()?;
skip(dec, len)
fn skip_vec<T>(dec: &mut Decoder) -> Option<()>
where
T: TryFrom<u64>,
usize: TryFrom<T>,
{
let len = dec.decode_uint::<T>()?;
skip(dec, usize::try_from(len).ok()?)
}

let mut dec = Decoder::from(buf);

// Return if buf is empty or does not contain a ClientHello (first byte == 1)
if buf.is_empty() || dec.decode_byte()? != 1 {
if buf.is_empty() || dec.decode_uint::<u8>()? != 1 {
return None;
}
skip(&mut dec, 3 + 2 + 32)?; // Skip length, version, random
skip_vec::<1>(&mut dec)?; // Skip session_id
skip_vec::<2>(&mut dec)?; // Skip cipher_suites
skip_vec::<1>(&mut dec)?; // Skip compression_methods
skip_vec::<u8>(&mut dec)?; // Skip session_id
skip_vec::<u16>(&mut dec)?; // Skip cipher_suites
skip_vec::<u8>(&mut dec)?; // Skip compression_methods
skip(&mut dec, 2)?;

while dec.remaining() >= 4 {
let ext_type: u16 = dec.decode_uint(2)?.try_into().ok()?;
let ext_len: u16 = dec.decode_uint(2)?.try_into().ok()?;
let ext_type: u16 = dec.decode_uint()?;
let ext_len: u16 = dec.decode_uint()?;
if ext_type == 0 {
// SNI!
let sni_len: u16 = dec.decode_uint(2)?.try_into().ok()?;
let sni_len: u16 = dec.decode_uint()?;
skip(&mut dec, 3)?; // Skip name_type and host_name length
let start = dec.offset();
let end = start + usize::from(sni_len) - 3;
Expand Down
11 changes: 5 additions & 6 deletions neqo-transport/src/tparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl TransportParameter {
fn decode_preferred_address(d: &mut Decoder) -> Res<Self> {
// IPv4 address (maybe)
let v4ip = Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?)?);
let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?;
let v4port = d.decode_uint::<u16>().ok_or(Error::NoMoreData)?;
// Can't have non-zero IP and zero port, or vice versa.
if v4ip.is_unspecified() ^ (v4port == 0) {
return Err(Error::TransportParameterError);
Expand All @@ -200,7 +200,7 @@ impl TransportParameter {
let v6ip = Ipv6Addr::from(<[u8; 16]>::try_from(
d.decode(16).ok_or(Error::NoMoreData)?,
)?);
let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?;
let v6port = d.decode_uint().ok_or(Error::NoMoreData)?;
if v6ip.is_unspecified() ^ (v6port == 0) {
return Err(Error::TransportParameterError);
}
Expand Down Expand Up @@ -229,11 +229,11 @@ impl TransportParameter {

fn decode_versions(dec: &mut Decoder) -> Res<Self> {
fn dv(dec: &mut Decoder) -> Res<WireVersion> {
let v = dec.decode_uint(4).ok_or(Error::NoMoreData)?;
let v = dec.decode_uint::<WireVersion>().ok_or(Error::NoMoreData)?;
if v == 0 {
Err(Error::TransportParameterError)
} else {
Ok(WireVersion::try_from(v)?)
Ok(v)
}
}

Expand Down Expand Up @@ -457,8 +457,7 @@ impl TransportParameters {
let rbuf = random::<4>();
let mut other = Vec::with_capacity(versions.all().len() + 1);
let mut dec = Decoder::new(&rbuf);
let grease =
(u32::try_from(dec.decode_uint(4).unwrap()).unwrap()) & 0xf0f0_f0f0 | 0x0a0a_0a0a;
let grease = dec.decode_uint::<u32>().unwrap() & 0xf0f0_f0f0 | 0x0a0a_0a0a;
other.push(grease);
for &v in versions.all() {
if role == Role::Client && !versions.initial().is_compatible(v) {
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ mod tests {
assert_eq!(stats.ack, 1);

let mut dec = builder.as_decoder();
_ = dec.decode_byte().unwrap(); // Skip the short header.
dec.skip(1); // Skip the short header.
let frame = Frame::decode(&mut dec).unwrap();
if let Frame::Ack { ack_ranges, .. } = frame {
assert_eq!(ack_ranges.len(), 0);
Expand Down
7 changes: 4 additions & 3 deletions neqo-transport/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use neqo_crypto::{
};
use neqo_transport::{
server::{ConnectionRef, Server, ValidateAddress},
version::WireVersion,
CloseReason, Connection, ConnectionParameters, Error, Output, State, StreamType, Version,
MIN_INITIAL_PACKET_SIZE,
};
Expand Down Expand Up @@ -584,13 +585,13 @@ fn version_negotiation_ignored() {
let vn = vn.expect("a vn packet");
let mut dec = Decoder::from(&vn[1..]); // Skip first byte.

assert_eq!(dec.decode_uint(4).expect("VN"), 0);
assert_eq!(dec.decode_uint::<u32>().expect("VN"), 0);
assert_eq!(dec.decode_vec(1).expect("VN DCID"), &s_cid[..]);
assert_eq!(dec.decode_vec(1).expect("VN SCID"), &d_cid[..]);
let mut found = false;
while dec.remaining() > 0 {
let v = dec.decode_uint(4).expect("supported version");
found |= v == u64::from(Version::default().wire_version());
let v = dec.decode_uint::<WireVersion>().expect("supported version");
found |= v == Version::default().wire_version();
}
assert!(found, "valid version not found");

Expand Down
Loading

0 comments on commit dd8e801

Please sign in to comment.