diff --git a/arrow2_convert/src/serialize.rs b/arrow2_convert/src/serialize.rs
index b359e0b..c27d867 100644
--- a/arrow2_convert/src/serialize.rs
+++ b/arrow2_convert/src/serialize.rs
@@ -1,5 +1,6 @@
 //! Implementation and traits for serializing to Arrow.
 
+use arrow2::array::Array;
 use arrow2::array::*;
 use arrow2::chunk::Chunk;
 use chrono::{NaiveDate, NaiveDateTime};
@@ -425,6 +426,47 @@ pub fn arrow_serialize_to_mutable_array<
     Ok(arr)
 }
 
+/// API to flatten a Chunk consisting of an `arrow2::array::StructArray` into a `Chunk` consisting of `arrow2::array::Array`s contained by the `StructArray`
+pub trait FlattenChunk {
+    /// Convert an `arrow2::chunk::Chunk` containing a `arrow2::array::StructArray` to an `arrow2::chunk::Chunk` consisting of the
+    /// `arrow::array::Array`s contained by the `StructArray` by consuming the
+    /// original `Chunk`. Returns an error if the `Chunk` cannot be flattened.
+    fn flatten(self) -> Result<Chunk<Arc<dyn Array>>, arrow2::error::Error>;
+}
+
+impl<A> FlattenChunk for Chunk<A>
+where
+    A: AsRef<dyn Array>,
+{
+    fn flatten(self) -> Result<Chunk<Arc<dyn Array>>, arrow2::error::Error> {
+        let arrays = self.into_arrays();
+
+        // we only support flattening of a Chunk containing a single StructArray
+        if arrays.len() != 1 {
+            return Err(arrow2::error::Error::InvalidArgumentError(
+                "Chunk must contain a single Array".to_string(),
+            ));
+        }
+
+        let array = &arrays[0];
+
+        let physical_type = array.as_ref().data_type().to_physical_type();
+        if physical_type != arrow2::datatypes::PhysicalType::Struct {
+            return Err(arrow2::error::Error::InvalidArgumentError(
+                "Array in Chunk must be of type arrow2::datatypes::PhysicalType::Struct"
+                    .to_string(),
+            ));
+        }
+
+        let struct_array = array
+            .as_ref()
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        Ok(Chunk::new(struct_array.values().to_vec()))
+    }
+}
+
 /// Top-level API to serialize to Arrow
 pub trait TryIntoArrow<'a, ArrowArray, Element>
 where
diff --git a/arrow2_convert/tests/test_flatten_chunk.rs b/arrow2_convert/tests/test_flatten_chunk.rs
new file mode 100644
index 0000000..d9ed261
--- /dev/null
+++ b/arrow2_convert/tests/test_flatten_chunk.rs
@@ -0,0 +1,69 @@
+use arrow2::array::*;
+use arrow2::chunk::Chunk;
+use arrow2_convert::serialize::*;
+use arrow2_convert::ArrowField;
+use std::sync::Arc;
+
+#[test]
+fn test_flatten_chunk() {
+    #[derive(Debug, Clone, ArrowField)]
+    struct Struct {
+        a: i64,
+        b: i64,
+    }
+
+    let target = Chunk::new(vec![
+        Int64Array::from(&[Some(1), Some(2)]).arced(),
+        Int64Array::from(&[Some(1), Some(2)]).arced(),
+    ]);
+
+    let array = vec![Struct { a: 1, b: 1 }, Struct { a: 2, b: 2 }];
+
+    let array: Arc<dyn Array> = array.try_into_arrow().unwrap();
+    let chunk = Chunk::new(vec![array]);
+
+    let flattened = chunk.flatten().unwrap();
+
+    assert_eq!(flattened, target);
+}
+
+#[test]
+fn test_flatten_chunk_empty_chunk_error() {
+    let chunk: Chunk<Arc<dyn Array>> = Chunk::new(vec![]);
+    assert!(chunk.flatten().is_err());
+}
+
+#[test]
+fn test_flatten_chunk_no_single_struct_array_error() {
+    #[derive(Debug, Clone, ArrowField)]
+    struct Struct {
+        a: i64,
+        b: String,
+    }
+
+    let array = vec![
+        Struct {
+            a: 1,
+            b: "one".to_string(),
+        },
+        Struct {
+            a: 2,
+            b: "two".to_string(),
+        },
+    ];
+
+    let array: Arc<dyn Array> = array.try_into_arrow().unwrap();
+
+    let arrays = vec![array.clone(), array.clone()];
+    let chunk = Chunk::new(arrays);
+
+    assert!(chunk.flatten().is_err());
+}
+
+#[test]
+fn test_flatten_chunk_type_not_struct_error() {
+    let array: Arc<dyn Array> = Int32Array::from(&[Some(1), None, Some(3)]).arced();
+    let chunk = Chunk::new(vec![array]);
+
+    assert!(chunk.flatten().is_err());
+}