diff --git a/commons/zenoh-buffers/src/lib.rs b/commons/zenoh-buffers/src/lib.rs index eae7f1715c..117fb412b7 100644 --- a/commons/zenoh-buffers/src/lib.rs +++ b/commons/zenoh-buffers/src/lib.rs @@ -199,6 +199,18 @@ pub mod reader { fn rewind(&mut self, mark: Self::Mark) -> bool; } + pub trait AdvanceableReader: Reader { + fn skip(&mut self, offset: usize) -> Result<(), DidntRead>; + fn backtrack(&mut self, offset: usize) -> Result<(), DidntRead>; + fn advance(&mut self, offset: isize) -> Result<(), DidntRead> { + if offset > 0 { + self.skip(offset as usize) + } else { + self.backtrack((-offset) as usize) + } + } + } + #[derive(Debug, Clone, Copy)] pub struct DidntSiphon; diff --git a/commons/zenoh-buffers/src/zbuf.rs b/commons/zenoh-buffers/src/zbuf.rs index fd86f454af..5c96c156cc 100644 --- a/commons/zenoh-buffers/src/zbuf.rs +++ b/commons/zenoh-buffers/src/zbuf.rs @@ -15,12 +15,21 @@ use crate::ZSliceKind; use crate::{ buffer::{Buffer, SplitBuffer}, - reader::{BacktrackableReader, DidntRead, DidntSiphon, HasReader, Reader, SiphonableReader}, + reader::{ + AdvanceableReader, BacktrackableReader, DidntRead, DidntSiphon, HasReader, Reader, + SiphonableReader, + }, writer::{BacktrackableWriter, DidntWrite, HasWriter, Writer}, ZSlice, }; use alloc::{sync::Arc, vec::Vec}; -use core::{cmp, iter, mem, num::NonZeroUsize, ops::RangeBounds, ptr}; +use core::{ + cmp::{self, min}, + isize, iter, mem, + num::NonZeroUsize, + ops::RangeBounds, + ptr, +}; use zenoh_collections::SingleOrVec; fn get_mut_unchecked(arc: &mut Arc) -> &mut T { @@ -355,6 +364,48 @@ impl<'a> BacktrackableReader for ZBufReader<'a> { } } +impl<'a> AdvanceableReader for ZBufReader<'a> { + fn skip(&mut self, offset: usize) -> Result<(), DidntRead> { + let mut remianing_offset = offset; + while remianing_offset > 0 { + if let Some(s) = self.inner.slices.get(self.cursor.slice) { + let remains_in_current_slice = s.len() - self.cursor.byte; + let advance = min(remianing_offset, remains_in_current_slice); + remianing_offset -= advance; + self.cursor.byte += advance; + if self.cursor.byte == s.len() { + self.cursor.slice += 1; + self.cursor.byte = 0; + } + } else { + return Err(DidntRead); + } + } + Ok(()) + } + + fn backtrack(&mut self, offset: usize) -> Result<(), DidntRead> { + let mut remianing_offset = offset; + while remianing_offset > 0 { + let backtrack = min(remianing_offset, self.cursor.byte); + remianing_offset -= backtrack; + self.cursor.byte -= backtrack; + if self.cursor.byte == 0 { + if self.cursor.slice == 0 { + break; + } + self.cursor.slice -= 1; + self.cursor.byte = self.inner.slices.get(self.cursor.slice).unwrap().len(); + } + } + if remianing_offset == 0 { + Ok(()) + } else { + Err(DidntRead) + } + } +} + impl<'a> SiphonableReader for ZBufReader<'a> { fn siphon(&mut self, writer: &mut W) -> Result where @@ -400,6 +451,28 @@ impl<'a> std::io::Read for ZBufReader<'a> { } } +#[cfg(feature = "std")] +impl<'a> std::io::Seek for ZBufReader<'a> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + let mut current_pos = self.cursor.byte as i64; + for i in 0..self.cursor.slice { + current_pos += self.inner.slices.get(i).unwrap().len() as i64; + } + let offset = match pos { + std::io::SeekFrom::Start(s) => s as i64 - current_pos, + std::io::SeekFrom::Current(s) => s, + std::io::SeekFrom::End(s) => self.inner.len() as i64 + s - current_pos, + }; + match self.advance(offset as isize) { + Ok(()) => Ok((offset + current_pos) as u64), + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "InvalidInput", + )), + } + } +} + // ZSlice iterator pub struct ZBufSliceIterator<'a, 'b> { reader: &'a mut ZBufReader<'b>, @@ -640,6 +713,7 @@ impl ZBuf { } mod tests { + #[test] fn zbuf_eq() { use super::{ZBuf, ZSlice}; @@ -668,4 +742,45 @@ mod tests { assert_eq!(zbuf1, zbuf2); } + + #[test] + fn zbuf_seek() { + use super::{HasReader, ZBuf}; + use crate::reader::Reader; + use std::io::Seek; + + let mut buf = ZBuf::empty(); + buf.push_zslice([0u8, 1u8, 2u8, 3u8].into()); + buf.push_zslice([4u8, 5u8, 6u8, 7u8, 8u8].into()); + buf.push_zslice([9u8, 10u8, 11u8, 12u8, 13u8, 14u8].into()); + let mut reader = buf.reader(); + assert!(reader.stream_position().unwrap() == 0); + assert!(reader.read_u8().unwrap() == 0); + assert!(reader.seek(std::io::SeekFrom::Current(6)).unwrap() == 7); + assert!(reader.read_u8().unwrap() == 7); + assert!(reader.seek(std::io::SeekFrom::Current(-5)).unwrap() == 3); + assert!(reader.read_u8().unwrap() == 3); + assert!(reader.seek(std::io::SeekFrom::Current(10)).unwrap() == 14); + assert!(reader.read_u8().unwrap() == 14); + assert!(reader.seek(std::io::SeekFrom::Current(100)).is_err()); + + assert!(reader.seek(std::io::SeekFrom::Start(0)).unwrap() == 0); + assert!(reader.read_u8().unwrap() == 0); + assert!(reader.seek(std::io::SeekFrom::Start(12)).unwrap() == 12); + assert!(reader.read_u8().unwrap() == 12); + assert!(reader.seek(std::io::SeekFrom::Start(15)).unwrap() == 15); + assert!(reader.read_u8().is_err()); + assert!(reader.seek(std::io::SeekFrom::Start(100)).is_err()); + + assert!(reader.seek(std::io::SeekFrom::End(0)).unwrap() == 15); + assert!(reader.read_u8().is_err()); + assert!(reader.seek(std::io::SeekFrom::End(-5)).unwrap() == 10); + assert!(reader.read_u8().unwrap() == 10); + assert!(reader.seek(std::io::SeekFrom::End(-15)).unwrap() == 0); + assert!(reader.read_u8().unwrap() == 0); + assert!(reader.seek(std::io::SeekFrom::End(-20)).is_err()); + + assert!(reader.seek(std::io::SeekFrom::Start(10)).is_ok()); + assert!(reader.seek(std::io::SeekFrom::Current(-100)).is_err()); + } } diff --git a/zenoh/src/payload.rs b/zenoh/src/payload.rs index ed2a58145c..57c1126a2d 100644 --- a/zenoh/src/payload.rs +++ b/zenoh/src/payload.rs @@ -53,13 +53,13 @@ impl Payload { self.0.len() } - /// Get a [`PayloadReader`] implementing [`std::io::Read`] trait. + /// Get a [`PayloadReader`] implementing [`std::io::Read`] and [`std::io::Seek`] traits. pub fn reader(&self) -> PayloadReader<'_> { PayloadReader(self.0.reader()) } } -/// A reader that implements [`std::io::Read`] trait to read from a [`Payload`]. +/// A reader that implements [`std::io::Read`] and [`std::io::Seek`] traits to read from a [`Payload`]. pub struct PayloadReader<'a>(ZBufReader<'a>); impl std::io::Read for PayloadReader<'_> { @@ -68,6 +68,12 @@ impl std::io::Read for PayloadReader<'_> { } } +impl std::io::Seek for PayloadReader<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.0.seek(pos) + } +} + /// Provide some facilities specific to the Rust API to encode/decode a [`Value`] with an `Serialize`. impl Payload { /// Encode an object of type `T` as a [`Value`] using the [`ZSerde`].