From 94144bf8fd5faa144ded1e4c4d65f00ccdf7ad6c Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Sat, 2 Nov 2024 16:45:55 -0600 Subject: [PATCH] const-oid: fix (and simplify) base 128 encoder This changes the base 128 encoder to calculate the length of a base 128-encoded arc and then iterate over each byte, computing the value for that byte, without any mutable state other than the position. It also refactors the unit tests and adds an example extracted from proptest failures. The new implementation passes that test. --- const-oid/src/arcs.rs | 4 +- const-oid/src/encoder.rs | 90 ++++++++++++------------------ const-oid/src/parser.rs | 2 +- const-oid/tests/oid.rs | 115 ++++++++++++++++++++++----------------- 4 files changed, 103 insertions(+), 108 deletions(-) diff --git a/const-oid/src/arcs.rs b/const-oid/src/arcs.rs index f245845f..7e5056ce 100644 --- a/const-oid/src/arcs.rs +++ b/const-oid/src/arcs.rs @@ -26,8 +26,8 @@ pub(crate) const ARC_MAX_SECOND: Arc = 39; /// Maximum number of bytes supported in an arc. /// -/// Note that OIDs are LEB128 encoded (i.e. base 128), so we must consider how many bytes are -/// required when each byte can only represent 7-bits of the input. +/// Note that OIDs are base 128 encoded (with continuation bits), so we must consider how many bytes +/// are required when each byte can only represent 7-bits of the input. const ARC_MAX_BYTES: usize = (Arc::BITS as usize).div_ceil(7); /// Maximum value of the last byte in an arc. diff --git a/const-oid/src/encoder.rs b/const-oid/src/encoder.rs index 1081dcd8..297cc35f 100644 --- a/const-oid/src/encoder.rs +++ b/const-oid/src/encoder.rs @@ -24,7 +24,7 @@ enum State { /// Initial state - no arcs yet encoded. Initial, - /// First arc parsed. + /// First arc has been supplied and stored as the wrapped [`Arc`]. FirstArc(Arc), /// Encoding base 128 body of the OID. 
@@ -83,10 +83,7 @@ impl Encoder { self.cursor = 1; Ok(self) } - State::Body => { - let nbytes = base128_len(arc); - self.encode_base128(arc, nbytes) - } + State::Body => self.encode_base128(arc), } } @@ -104,64 +101,44 @@ impl Encoder { Ok(ObjectIdentifier { ber }) } - /// Encode a single byte of a Base 128 value. - const fn encode_base128(mut self, n: u32, remaining_len: usize) -> Result { - if self.cursor >= MAX_SIZE { + /// Encode base 128. + const fn encode_base128(mut self, arc: Arc) -> Result { + let nbytes = base128_len(arc); + let end_pos = checked_add!(self.cursor, nbytes); + + if end_pos > MAX_SIZE { return Err(Error::Length); } - let mask = if remaining_len > 0 { 0b10000000 } else { 0 }; - let (hi, lo) = split_hi_bits(n); - self.bytes[self.cursor] = hi | mask; - self.cursor = checked_add!(self.cursor, 1); - - match remaining_len.checked_sub(1) { - Some(len) => self.encode_base128(lo, len), - None => Ok(self), + let mut i = 0; + while i < nbytes { + self.bytes[self.cursor] = base128_byte(arc, i, nbytes); + self.cursor = checked_add!(self.cursor, 1); + i = checked_add!(i, 1); } + + Ok(self) } } -/// Compute the length - 1 of an arc when encoded in base 128. +/// Compute the length of an arc when encoded in base 128. const fn base128_len(arc: Arc) -> usize { match arc { - 0..=0x7f => 0, - 0x80..=0x3fff => 1, - 0x4000..=0x1fffff => 2, - 0x200000..=0x1fffffff => 3, - _ => 4, + 0..=0x7f => 1, + 0x80..=0x3fff => 2, + 0x4000..=0x1fffff => 3, + 0x200000..=0x1fffffff => 4, + _ => 5, } } -/// Split the highest 7-bits of an [`Arc`] from the rest of an arc. 
-/// -/// Returns: `(hi, lo)` -#[inline] -const fn split_hi_bits(arc: Arc) -> (u8, Arc) { - if arc < 0x80 { - return (arc as u8, 0); - } - - let hi_bit = match 32u32.checked_sub(arc.leading_zeros()) { - Some(bit) => bit, - None => unreachable!(), - }; - - let hi_bit_mod7 = hi_bit % 7; - let upper_bit_offset = if hi_bit > 0 && hi_bit_mod7 == 0 { - 7 - } else { - hi_bit_mod7 - }; - - let upper_bit_pos = match hi_bit.checked_sub(upper_bit_offset) { - Some(bit) => bit, - None => unreachable!(), - }; - - let upper_bits = arc >> upper_bit_pos; - let lower_bits = arc ^ (upper_bits << upper_bit_pos); - (upper_bits as u8, lower_bits) +/// Compute the big endian base 128 encoding of the given [`Arc`] at the given byte. +const fn base128_byte(arc: Arc, pos: usize, total: usize) -> u8 { + debug_assert!(pos < total); + let last_byte = (pos + 1) == total; + let mask = if last_byte { 0 } else { 0b10000000 }; + let shift = (total - pos - 1) * 7; + ((arc >> shift) & 0b1111111) as u8 | mask } #[cfg(test)] @@ -174,9 +151,14 @@ mod tests { const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201"); #[test] - fn split_hi_bits_with_gaps() { - assert_eq!(super::split_hi_bits(0x3a00002), (0x1d, 0x2)); - assert_eq!(super::split_hi_bits(0x3a08000), (0x1d, 0x8000)); + fn base128_byte() { + let example_arc = 0x44332211; + assert_eq!(super::base128_len(example_arc), 5); + assert_eq!(super::base128_byte(example_arc, 0, 5), 0b10000100); + assert_eq!(super::base128_byte(example_arc, 1, 5), 0b10100001); + assert_eq!(super::base128_byte(example_arc, 2, 5), 0b11001100); + assert_eq!(super::base128_byte(example_arc, 3, 5), 0b11000100); + assert_eq!(super::base128_byte(example_arc, 4, 5), 0b10001); } #[test] diff --git a/const-oid/src/parser.rs b/const-oid/src/parser.rs index 4810294d..5b5155b3 100644 --- a/const-oid/src/parser.rs +++ b/const-oid/src/parser.rs @@ -63,7 +63,7 @@ impl Parser { self.current_arc = match arc.checked_mul(10) { Some(arc) => match arc.checked_add(digit as Arc) { None => return 
Err(Error::ArcTooBig), - arc => arc, + Some(arc) => Some(arc), }, None => return Err(Error::ArcTooBig), }; diff --git a/const-oid/tests/oid.rs b/const-oid/tests/oid.rs index ad7a0f8e..92bfc49c 100644 --- a/const-oid/tests/oid.rs +++ b/const-oid/tests/oid.rs @@ -29,8 +29,8 @@ const EXAMPLE_OID_LARGE_ARC_0: ObjectIdentifier = ObjectIdentifier::new_unwrap(crate::EXAMPLE_OID_LARGE_ARC_0_STR); /// Example OID value with a large arc -const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "0.9.2342.19200300.100.1.1"; -const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("0992268993F22C640101"); +const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "1.1.1.60817410.1"; +const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("29019D80800201"); const EXAMPLE_OID_LARGE_ARC_1: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_LARGE_ARC_1_STR); @@ -45,54 +45,69 @@ pub fn oid(s: &str) -> ObjectIdentifier { ObjectIdentifier::new(s).unwrap() } +/// 0.9.2342.19200300.100.1.1 #[test] -fn from_bytes() { - // 0.9.2342.19200300.100.1.1 - let oid0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap(); - assert_eq!(oid0.arc(0).unwrap(), 0); - assert_eq!(oid0.arc(1).unwrap(), 9); - assert_eq!(oid0.arc(2).unwrap(), 2342); - assert_eq!(oid0, EXAMPLE_OID_0); +fn from_bytes_oid_0() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_0); + assert_eq!(oid.arc(0).unwrap(), 0); + assert_eq!(oid.arc(1).unwrap(), 9); + assert_eq!(oid.arc(2).unwrap(), 2342); +} - // 1.2.840.10045.2.1 - let oid1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap(); - assert_eq!(oid1.arc(0).unwrap(), 1); - assert_eq!(oid1.arc(1).unwrap(), 2); - assert_eq!(oid1.arc(2).unwrap(), 840); - assert_eq!(oid1, EXAMPLE_OID_1); +/// 1.2.840.10045.2.1 +#[test] +fn from_bytes_oid_1() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_1); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 
840); +} - // 2.16.840.1.101.3.4.1.42 - let oid2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap(); - assert_eq!(oid2.arc(0).unwrap(), 2); - assert_eq!(oid2.arc(1).unwrap(), 16); - assert_eq!(oid2.arc(2).unwrap(), 840); - assert_eq!(oid2, EXAMPLE_OID_2); +/// 2.16.840.1.101.3.4.1.42 +#[test] +fn from_bytes_oid_2() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_2); + assert_eq!(oid.arc(0).unwrap(), 2); + assert_eq!(oid.arc(1).unwrap(), 16); + assert_eq!(oid.arc(2).unwrap(), 840); +} - // 1.2.16384 - let oid_largearc0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap(); - assert_eq!(oid_largearc0.arc(0).unwrap(), 1); - assert_eq!(oid_largearc0.arc(1).unwrap(), 2); - assert_eq!(oid_largearc0.arc(2).unwrap(), 16384); - assert_eq!(oid_largearc0.arc(3), None); - assert_eq!(oid_largearc0, EXAMPLE_OID_LARGE_ARC_0); +/// 1.2.16384 +#[test] +fn from_bytes_oid_largearc_0() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_0); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 16384); + assert_eq!(oid.arc(3), None); +} - // 0.9.2342.19200300.100.1.1 - let oid_largearc1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap(); - assert_eq!(oid_largearc1.arc(0).unwrap(), 0); - assert_eq!(oid_largearc1.arc(1).unwrap(), 9); - assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); - assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); - assert_eq!(oid_largearc1.arc(4).unwrap(), 100); - assert_eq!(oid_largearc1.arc(5).unwrap(), 1); - assert_eq!(oid_largearc1.arc(6).unwrap(), 1); - assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); +/// 1.1.1.60817410.1 +#[test] +fn from_bytes_oid_largearc_1() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_1); + assert_eq!(oid.arc(0).unwrap(), 1); + 
assert_eq!(oid.arc(1).unwrap(), 1); + assert_eq!(oid.arc(2).unwrap(), 1); + assert_eq!(oid.arc(3).unwrap(), 60817410); + assert_eq!(oid.arc(4).unwrap(), 1); + assert_eq!(oid.arc(5), None); +} - // 1.2.4294967295 - let oid_largearc2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap(); - assert_eq!(oid_largearc2.arc(0).unwrap(), 1); - assert_eq!(oid_largearc2.arc(1).unwrap(), 2); - assert_eq!(oid_largearc2.arc(2).unwrap(), 4294967295); - assert_eq!(oid_largearc2, EXAMPLE_OID_LARGE_ARC_2); +/// 1.2.4294967295 +#[test] +fn from_bytes_oid_largearc_2() { + let oid = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_2_BER).unwrap(); + assert_eq!(oid, EXAMPLE_OID_LARGE_ARC_2); + assert_eq!(oid.arc(0).unwrap(), 1); + assert_eq!(oid.arc(1).unwrap(), 2); + assert_eq!(oid.arc(2).unwrap(), 4294967295); + assert_eq!(oid.arc(3), None); // Empty assert_eq!(ObjectIdentifier::from_bytes(&[]), Err(Error::Empty)); @@ -126,13 +141,11 @@ fn from_str() { let oid_largearc1 = EXAMPLE_OID_LARGE_ARC_1_STR .parse::() .unwrap(); - assert_eq!(oid_largearc1.arc(0).unwrap(), 0); - assert_eq!(oid_largearc1.arc(1).unwrap(), 9); - assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); - assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); - assert_eq!(oid_largearc1.arc(4).unwrap(), 100); - assert_eq!(oid_largearc1.arc(5).unwrap(), 1); - assert_eq!(oid_largearc1.arc(6).unwrap(), 1); + assert_eq!(oid_largearc1.arc(0).unwrap(), 1); + assert_eq!(oid_largearc1.arc(1).unwrap(), 1); + assert_eq!(oid_largearc1.arc(2).unwrap(), 1); + assert_eq!(oid_largearc1.arc(3).unwrap(), 60817410); + assert_eq!(oid_largearc1.arc(4).unwrap(), 1); assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); let oid_largearc2 = EXAMPLE_OID_LARGE_ARC_2_STR