diff --git a/src/concat_source.rs b/src/concat_source.rs index a2b60f4..76be9f1 100644 --- a/src/concat_source.rs +++ b/src/concat_source.rs @@ -54,7 +54,7 @@ use crate::{ /// .unwrap() /// ); /// ``` -#[derive(Debug, Default, Clone, Eq)] +#[derive(Debug, Default, Clone)] pub struct ConcatSource { children: Vec, } @@ -66,26 +66,46 @@ impl ConcatSource { T: Source + 'static, S: IntoIterator, { - Self { - children: sources.into_iter().map(|s| SourceExt::boxed(s)).collect(), - } + // Flatten the children + let children = sources + .into_iter() + .flat_map( + |source| match source.as_any().downcast_ref::() { + // This clone is cheap because `BoxSource` is `Arc`ed for each child. + Some(concat_source) => concat_source.children.clone(), + None => { + vec![SourceExt::boxed(source)] + } + }, + ) + .collect(); + Self { children } + } + + fn children(&self) -> &Vec { + &self.children } /// Add a [Source] to concat. - pub fn add(&mut self, item: S) { - self.children.push(SourceExt::boxed(item)); + pub fn add(&mut self, source: S) { + if let Some(concat_source) = source.as_any().downcast_ref::() + { + self.children.extend(concat_source.children.clone()); + } else { + self.children.push(SourceExt::boxed(source)); + } } } impl Source for ConcatSource { fn source(&self) -> Cow { - let all = self.children.iter().map(|child| child.source()).collect(); + let all = self.children().iter().map(|child| child.source()).collect(); Cow::Owned(all) } fn buffer(&self) -> Cow<[u8]> { let all = self - .children + .children() .iter() .map(|child| child.buffer()) .collect::>() @@ -94,7 +114,7 @@ impl Source for ConcatSource { } fn size(&self) -> usize { - self.children.iter().map(|child| child.size()).sum() + self.children().iter().map(|child| child.size()).sum() } fn map(&self, options: &MapOptions) -> Option { @@ -102,7 +122,7 @@ impl Source for ConcatSource { } fn to_writer(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { - for child in &self.children { + for child in self.children() { child.to_writer(writer)?; } Ok(()) @@ -112,7 +132,7 @@ impl Source for ConcatSource { impl Hash for ConcatSource { fn hash(&self, state: &mut H) { "ConcatSource".hash(state); - for child in self.children.iter() { + for child in self.children().iter() { child.hash(state); } } @@ -120,9 +140,10 @@ impl Hash for ConcatSource { impl PartialEq for ConcatSource { fn eq(&self, other: &Self) -> bool { - self.children == other.children + self.children() == other.children() } } +impl Eq for ConcatSource {} impl StreamChunks for ConcatSource { fn stream_chunks( @@ -132,8 +153,8 @@ impl StreamChunks for ConcatSource { on_source: OnSource, on_name: OnName, ) -> crate::helpers::GeneratedInfo { - if self.children.len() == 1 { - return self.children[0] + if self.children().len() == 1 { + return self.children()[0] .stream_chunks(options, on_chunk, on_source, on_name); } let mut current_line_offset = 0; @@ -141,7 +162,7 @@ impl StreamChunks for ConcatSource { let mut source_mapping: HashMap = HashMap::default(); let mut name_mapping: HashMap = HashMap::default(); let mut need_to_cloas_mapping = false; - for item in &self.children { + for item in self.children() { let source_index_mapping: RefCell> = RefCell::new(HashMap::default()); let name_index_mapping: RefCell> =