From 0f60a617704c8e5725012631f71d889b8ddd072b Mon Sep 17 00:00:00 2001 From: Mateusz Date: Thu, 23 Nov 2023 06:34:29 +0100 Subject: [PATCH] fix parsing trailing message content from CMGR command (#182) --- serde_at/src/de/mod.rs | 141 +++++++++++++++++++++++++++++------------ serde_at/src/de/seq.rs | 30 +++++++-- 2 files changed, 125 insertions(+), 46 deletions(-) diff --git a/serde_at/src/de/mod.rs b/serde_at/src/de/mod.rs index 460e7ea8..83cb33e1 100644 --- a/serde_at/src/de/mod.rs +++ b/serde_at/src/de/mod.rs @@ -67,11 +67,18 @@ pub enum Error { pub(crate) struct Deserializer<'b> { slice: &'b [u8], index: usize, + struct_size_hint: Option, + is_trailing_parsing: bool, } impl<'a> Deserializer<'a> { const fn new(slice: &'a [u8]) -> Deserializer<'_> { - Deserializer { slice, index: 0 } + Deserializer { + slice, + index: 0, + struct_size_hint: None, + is_trailing_parsing: false, + } } fn eat_char(&mut self) { @@ -93,6 +100,14 @@ impl<'a> Deserializer<'a> { None } + fn set_is_trailing_parsing(&mut self) { + self.is_trailing_parsing = true; + } + + fn struct_size_hint(&self) -> Option { + self.struct_size_hint + } + fn parse_ident(&mut self, ident: &[u8]) -> Result<()> { for c in ident { if Some(c) != self.next_char() { @@ -105,43 +120,51 @@ impl<'a> Deserializer<'a> { fn parse_str(&mut self) -> Result<&'a str> { let start = self.index; - loop { - match self.peek() { - Some(b'"') => { - // Counts the number of backslashes in front of the current index. - // - // "some string with \\\" included." - // ^^^^^ - // ||||| - // loop run: 4321| - // | - // `index` - // - // Since we only get in this code branch if we found a " starting the string and `index` is greater - // than the start position, we know the loop will end no later than this point. - let leading_backslashes = |index: usize| -> usize { - let mut count = 0; - loop { - if self.slice[index - count - 1] == b'\\' { - count += 1; - } else { - return count; + if self.is_trailing_parsing { + self.index = self.slice.len(); + return str::from_utf8(&self.slice[start..]) + .map_err(|_| Error::InvalidUnicodeCodePoint); + } else { + loop { + match self.peek() { + Some(b'"') => { + // Counts the number of backslashes in front of the current index. + // + // "some string with \\\" included." + // ^^^^^ + // ||||| + // loop run: 4321| + // | + // `index` + // + // Since we only get in this code branch if we found a " starting the string and `index` is greater + // than the start position, we know the loop will end no later than this point. + let leading_backslashes = |index: usize| -> usize { + let mut count = 0; + loop { + if self.slice[index - count - 1] == b'\\' { + count += 1; + } else { + return count; + } } - } - }; + }; - let is_escaped = leading_backslashes(self.index) % 2 == 1; - if is_escaped { - self.eat_char(); // just continue - } else { - let end = self.index; - self.eat_char(); - return str::from_utf8(&self.slice[start..end]) - .map_err(|_| Error::InvalidUnicodeCodePoint); + let is_escaped = leading_backslashes(self.index) % 2 == 1; + if is_escaped { + self.eat_char(); // just continue + } else { + let end = self.index; + self.eat_char(); + return str::from_utf8(&self.slice[start..end]) + .map_err(|_| Error::InvalidUnicodeCodePoint); + } + } + Some(_) => self.eat_char(), + None => { + return Err(Error::EofWhileParsingString); } } - Some(_) => self.eat_char(), - None => return Err(Error::EofWhileParsingString), } } } @@ -149,14 +172,19 @@ impl<'a> Deserializer<'a> { fn parse_bytes(&mut self) -> Result<&'a [u8]> { let start = self.index; loop { - if let Some(c) = self.peek() { - if (c as char).is_alphanumeric() || (c as char).is_whitespace() { - self.eat_char(); + if self.is_trailing_parsing { + self.index = self.slice.len(); + return Ok(&self.slice[start..]); + } else { + if let Some(c) = self.peek() { + if (c as char).is_alphanumeric() || (c as char).is_whitespace() { + self.eat_char(); + } else { + return Err(Error::EofWhileParsingString); + } } else { - return Err(Error::EofWhileParsingString); + return Ok(&self.slice[start..self.index]); } - } else { - return Ok(&self.slice[start..self.index]); } } } @@ -561,7 +589,7 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_struct( self, _name: &'static str, - _fields: &'static [&'static str], + fields: &'static [&'static str], visitor: V, ) -> Result where @@ -577,7 +605,11 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { if self.index == self.slice.len() && self.index > 0 { return Err(Error::EofWhileParsingObject); } - self.deserialize_seq(visitor) + self.struct_size_hint = Some(fields.len()); + let result = self.deserialize_seq(visitor); + self.struct_size_hint = None; + + result } fn deserialize_enum( @@ -832,4 +864,29 @@ mod tests { assert_eq!(&res, b"IMP_"); } + + #[test] + fn trailing_cmgr_parsing() { + #[derive(Clone, Debug, Deserialize, PartialEq)] + pub struct Message { + state: String<256>, + sender: String<256>, + size: Option, + date: String<256>, + message: String<256>, + } + + assert_eq!( + crate::from_str( + "+CMGR: \"REC UNREAD\",\"+48788899722\",12,\"23/11/21,13:31:39+04\"\r\nINFO,WWW\"\"a" + ), + Ok(Message { + state: String::try_from("REC UNREAD").unwrap(), + sender: String::try_from("+48788899722").unwrap(), + size: Some(12), + date: String::try_from("23/11/21,13:31:39+04").unwrap(), + message: String::try_from("INFO,WWW\"\"a").unwrap(), + }) + ); + } } diff --git a/serde_at/src/de/seq.rs b/serde_at/src/de/seq.rs index d9ed5f13..e36cd146 100644 --- a/serde_at/src/de/seq.rs +++ b/serde_at/src/de/seq.rs @@ -5,12 +5,20 @@ use crate::de::{Deserializer, Error, Result}; #[allow(clippy::module_name_repetitions)] pub struct SeqAccess<'a, 'b> { first: bool, + count: usize, + len: Option, de: &'a mut Deserializer<'b>, } impl<'a, 'b> SeqAccess<'a, 'b> { pub(crate) fn new(de: &'a mut Deserializer<'b>) -> Self { - SeqAccess { de, first: true } + let len = de.struct_size_hint(); + SeqAccess { + de, + first: true, + len, + count: 0, + } } } @@ -32,7 +40,15 @@ impl<'a, 'de> de::SeqAccess<'de> for SeqAccess<'a, 'de> { if self.first { self.first = false; } else if c != b'+' { - return Ok(None); + if let Some(len) = self.len { + if self.count == len - 1 { + self.de.set_is_trailing_parsing(); + } else { + return Ok(None); + } + } else { + return Ok(None); + } } } None => { @@ -44,9 +60,15 @@ impl<'a, 'de> de::SeqAccess<'de> for SeqAccess<'a, 'de> { match seed.deserialize(&mut *self.de) { // Misuse EofWhileParsingObject here to indicate finished object in vec cases. // See matching TODO in `de::mod`.. - Err(Error::EofWhileParsingObject) => Ok(None), + Err(Error::EofWhileParsingObject) => { + self.count += 1; + Ok(None) + } Err(e) => Err(e), - Ok(v) => Ok(Some(v)), + Ok(v) => { + self.count += 1; + Ok(Some(v)) + } } } }