diff --git a/const-oid/src/encoder.rs b/const-oid/src/encoder.rs index 5f9401aa6..bd9793305 100644 --- a/const-oid/src/encoder.rs +++ b/const-oid/src/encoder.rs @@ -73,27 +73,9 @@ impl Encoder { self.cursor = 1; Ok(self) } - // TODO(tarcieri): finer-grained overflow safety / checked arithmetic - #[allow(clippy::arithmetic_side_effects)] State::Body => { - // Total number of bytes in encoded arc - 1 let nbytes = base128_len(arc); - - // Shouldn't overflow on any 16-bit+ architectures - if self.cursor + nbytes + 1 > MAX_SIZE { - return Err(Error::Length); - } - - let new_cursor = self.cursor + nbytes + 1; - - // TODO(tarcieri): use `?` when stable in `const fn` - match self.encode_base128_byte(arc, nbytes, false) { - Ok(mut encoder) => { - encoder.cursor = new_cursor; - Ok(encoder) - } - Err(err) => Err(err), - } + self.encode_base128(arc, nbytes) } } } @@ -113,22 +95,19 @@ impl Encoder { } /// Encode a single byte of a Base 128 value. - const fn encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Result { - let mask = if continued { 0b10000000 } else { 0 }; - - // Underflow checked by branch - #[allow(clippy::arithmetic_side_effects)] - if n > 0x80 { - self.bytes[checked_add!(self.cursor, i)] = (n & 0b1111111) as u8 | mask; - n >>= 7; - - if i > 0 { - self.encode_base128_byte(n, i.saturating_sub(1), true) - } else { - Err(Error::Base128) - } + const fn encode_base128(mut self, n: u32, remaining_len: usize) -> Result { + if self.cursor >= MAX_SIZE { + return Err(Error::Length); + } + + let mask = if remaining_len > 0 { 0b10000000 } else { 0 }; + let (hi, lo) = split_high_bits(n); + self.bytes[self.cursor] = hi | mask; + self.cursor = checked_add!(self.cursor, 1); + + if remaining_len > 0 { + self.encode_base128(lo, remaining_len - 1) } else { - self.bytes[self.cursor] = n as u8 | mask; Ok(self) } } @@ -145,6 +124,29 @@ const fn base128_len(arc: Arc) -> usize { } } +/// Split the highest 7-bits of an [`Arc`] from the rest of an arc. +/// +/// Returns: `(hi, lo)` +// TODO(tarcieri): always use checked arithmetic +#[allow(clippy::arithmetic_side_effects)] +const fn split_high_bits(arc: Arc) -> (u8, Arc) { + if arc < 0x80 { + return (arc as u8, 0); + } + + let hi_bit = 32 - arc.leading_zeros(); + let hi_bit_mod7 = hi_bit % 7; + let upper_bit_pos = hi_bit + - if hi_bit > 0 && hi_bit_mod7 == 0 { + 7 + } else { + hi_bit_mod7 + }; + let upper_bits = arc >> upper_bit_pos; + let lower_bits = arc ^ (upper_bits << upper_bit_pos); + (upper_bits as u8, lower_bits) +} + #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { diff --git a/const-oid/tests/oid.rs b/const-oid/tests/oid.rs index 7172fc160..bebbbadf5 100644 --- a/const-oid/tests/oid.rs +++ b/const-oid/tests/oid.rs @@ -23,10 +23,16 @@ const EXAMPLE_OID_2_BER: &[u8] = &hex!("60864801650304012A"); const EXAMPLE_OID_2: ObjectIdentifier = ObjectIdentifier::new_unwrap(EXAMPLE_OID_2_STR); /// Example OID value with a large arc -const EXAMPLE_OID_LARGE_ARC_STR: &str = "0.9.2342.19200300.100.1.1"; -const EXAMPLE_OID_LARGE_ARC_BER: &[u8] = &hex!("0992268993F22C640101"); -const EXAMPLE_OID_LARGE_ARC: ObjectIdentifier = - ObjectIdentifier::new_unwrap("0.9.2342.19200300.100.1.1"); +const EXAMPLE_OID_LARGE_ARC_0_STR: &str = "1.2.16384"; +const EXAMPLE_OID_LARGE_ARC_0_BER: &[u8] = &hex!("2A818000"); +const EXAMPLE_OID_LARGE_ARC_0: ObjectIdentifier = + ObjectIdentifier::new_unwrap(crate::EXAMPLE_OID_LARGE_ARC_0_STR); + +/// Example OID value with a large arc +const EXAMPLE_OID_LARGE_ARC_1_STR: &str = "0.9.2342.19200300.100.1.1"; +const EXAMPLE_OID_LARGE_ARC_1_BER: &[u8] = &hex!("0992268993F22C640101"); +const EXAMPLE_OID_LARGE_ARC_1: ObjectIdentifier = + ObjectIdentifier::new_unwrap(EXAMPLE_OID_LARGE_ARC_1_STR); /// Create an OID from a string. pub fn oid(s: &str) -> ObjectIdentifier { @@ -38,27 +44,37 @@ fn from_bytes() { let oid0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_0_BER).unwrap(); assert_eq!(oid0.arc(0).unwrap(), 0); assert_eq!(oid0.arc(1).unwrap(), 9); + assert_eq!(oid0.arc(2).unwrap(), 2342); assert_eq!(oid0, EXAMPLE_OID_0); let oid1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_1_BER).unwrap(); assert_eq!(oid1.arc(0).unwrap(), 1); assert_eq!(oid1.arc(1).unwrap(), 2); + assert_eq!(oid1.arc(2).unwrap(), 840); assert_eq!(oid1, EXAMPLE_OID_1); let oid2 = ObjectIdentifier::from_bytes(EXAMPLE_OID_2_BER).unwrap(); assert_eq!(oid2.arc(0).unwrap(), 2); assert_eq!(oid2.arc(1).unwrap(), 16); + assert_eq!(oid2.arc(2).unwrap(), 840); assert_eq!(oid2, EXAMPLE_OID_2); - let oid3 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_BER).unwrap(); - assert_eq!(oid3.arc(0).unwrap(), 0); - assert_eq!(oid3.arc(1).unwrap(), 9); - assert_eq!(oid3.arc(2).unwrap(), 2342); - assert_eq!(oid3.arc(3).unwrap(), 19200300); - assert_eq!(oid3.arc(4).unwrap(), 100); - assert_eq!(oid3.arc(5).unwrap(), 1); - assert_eq!(oid3.arc(6).unwrap(), 1); - assert_eq!(oid3, EXAMPLE_OID_LARGE_ARC); + let oid_largearc0 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_0_BER).unwrap(); + assert_eq!(oid_largearc0.arc(0).unwrap(), 1); + assert_eq!(oid_largearc0.arc(1).unwrap(), 2); + assert_eq!(oid_largearc0.arc(2).unwrap(), 16384); + assert_eq!(oid_largearc0.arc(3), None); + assert_eq!(oid_largearc0, EXAMPLE_OID_LARGE_ARC_0); + + let oid_largearc1 = ObjectIdentifier::from_bytes(EXAMPLE_OID_LARGE_ARC_1_BER).unwrap(); + assert_eq!(oid_largearc1.arc(0).unwrap(), 0); + assert_eq!(oid_largearc1.arc(1).unwrap(), 9); + assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); + assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); + assert_eq!(oid_largearc1.arc(4).unwrap(), 100); + assert_eq!(oid_largearc1.arc(5).unwrap(), 1); + assert_eq!(oid_largearc1.arc(6).unwrap(), 1); + assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); // Empty assert_eq!(ObjectIdentifier::from_bytes(&[]), Err(Error::Empty)); @@ -81,17 +97,25 @@ fn from_str() { assert_eq!(oid2.arc(1).unwrap(), 16); assert_eq!(oid2, EXAMPLE_OID_2); - let oid3 = EXAMPLE_OID_LARGE_ARC_STR + let oid_largearc0 = EXAMPLE_OID_LARGE_ARC_0_STR .parse::() .unwrap(); - assert_eq!(oid3.arc(0).unwrap(), 0); - assert_eq!(oid3.arc(1).unwrap(), 9); - assert_eq!(oid3.arc(2).unwrap(), 2342); - assert_eq!(oid3.arc(3).unwrap(), 19200300); - assert_eq!(oid3.arc(4).unwrap(), 100); - assert_eq!(oid3.arc(5).unwrap(), 1); - assert_eq!(oid3.arc(6).unwrap(), 1); - assert_eq!(oid3, EXAMPLE_OID_LARGE_ARC); + assert_eq!(oid_largearc0.arc(0).unwrap(), 1); + assert_eq!(oid_largearc0.arc(1).unwrap(), 2); + assert_eq!(oid_largearc0.arc(2).unwrap(), 16384); + assert_eq!(oid_largearc0, EXAMPLE_OID_LARGE_ARC_0); + + let oid_largearc1 = EXAMPLE_OID_LARGE_ARC_1_STR + .parse::() + .unwrap(); + assert_eq!(oid_largearc1.arc(0).unwrap(), 0); + assert_eq!(oid_largearc1.arc(1).unwrap(), 9); + assert_eq!(oid_largearc1.arc(2).unwrap(), 2342); + assert_eq!(oid_largearc1.arc(3).unwrap(), 19200300); + assert_eq!(oid_largearc1.arc(4).unwrap(), 100); + assert_eq!(oid_largearc1.arc(5).unwrap(), 1); + assert_eq!(oid_largearc1.arc(6).unwrap(), 1); + assert_eq!(oid_largearc1, EXAMPLE_OID_LARGE_ARC_1); // Truncated assert_eq!( @@ -117,7 +141,10 @@ fn display() { assert_eq!(EXAMPLE_OID_0.to_string(), EXAMPLE_OID_0_STR); assert_eq!(EXAMPLE_OID_1.to_string(), EXAMPLE_OID_1_STR); assert_eq!(EXAMPLE_OID_2.to_string(), EXAMPLE_OID_2_STR); - assert_eq!(EXAMPLE_OID_LARGE_ARC.to_string(), EXAMPLE_OID_LARGE_ARC_STR); + assert_eq!( + EXAMPLE_OID_LARGE_ARC_1.to_string(), + EXAMPLE_OID_LARGE_ARC_1_STR + ); } #[test]