From c8699d8a199e9e1d1503664c2c48a7c1aae64da0 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sun, 18 Dec 2022 17:57:37 +0100 Subject: [PATCH] Improved `UnionArray` (#1331) --- src/array/growable/union.rs | 4 +- src/array/union/iterator.rs | 5 +- src/array/union/mod.rs | 175 ++++++++++++++++++-------- src/compute/sort/row/mod.rs | 4 +- src/compute/sort/row/variable.rs | 6 +- src/io/json_integration/read/array.rs | 3 +- tests/it/array/union.rs | 125 ++++++++++++++++-- 7 files changed, 250 insertions(+), 72 deletions(-) diff --git a/src/array/growable/union.rs b/src/array/growable/union.rs index f050d1272e7..8356c228a4b 100644 --- a/src/array/growable/union.rs +++ b/src/array/growable/union.rs @@ -89,11 +89,11 @@ impl<'a> Growable<'a> for GrowableUnion<'a> { fn extend_validity(&mut self, _additional: usize) {} fn as_arc(&mut self) -> Arc { - Arc::new(self.to()) + self.to().arced() } fn as_box(&mut self) -> Box { - Box::new(self.to()) + self.to().boxed() } } diff --git a/src/array/union/iterator.rs b/src/array/union/iterator.rs index e72d4735824..97ae9e26b7e 100644 --- a/src/array/union/iterator.rs +++ b/src/array/union/iterator.rs @@ -8,6 +8,7 @@ pub struct UnionIter<'a> { } impl<'a> UnionIter<'a> { + #[inline] pub fn new(array: &'a UnionArray) -> Self { Self { array, current: 0 } } @@ -16,16 +17,18 @@ impl<'a> UnionIter<'a> { impl<'a> Iterator for UnionIter<'a> { type Item = Box; + #[inline] fn next(&mut self) -> Option { if self.current == self.array.len() { None } else { let old = self.current; self.current += 1; - Some(self.array.value(old)) + Some(unsafe { self.array.value_unchecked(old) }) } } + #[inline] fn size_hint(&self) -> (usize, Option) { let len = self.array.len() - self.current; (len, Some(len)) diff --git a/src/array/union/mod.rs b/src/array/union/mod.rs index 70028d6cf2c..3aa963c00a8 100644 --- a/src/array/union/mod.rs +++ b/src/array/union/mod.rs @@ -1,5 +1,3 @@ -use ahash::AHashMap; - use crate::{ bitmap::Bitmap, buffer::Buffer, @@ -14,7 +12,6 @@ mod ffi; pub(super) mod fmt; mod iterator; -type FieldEntry = (usize, Box); type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode); /// [`UnionArray`] represents an array whose each slot can contain different values. @@ -29,10 +26,13 @@ type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode); // ``` #[derive(Clone)] pub struct UnionArray { + // Invariant: every item in `types` is `> 0 && < fields.len()` types: Buffer, - // None represents when there is no typeid - fields_hash: Option>, + // Invariant: `map.len() == fields.len()` + // Invariant: every item in `map` is `> 0 && < fields.len()` + map: Option<[usize; 127]>, fields: Vec>, + // Invariant: when set, `offsets.len() == types.len()` offsets: Option>, data_type: DataType, offset: usize, @@ -44,6 +44,7 @@ impl UnionArray { /// This function errors iff: /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. /// * the fields's len is different from the `data_type`'s children's length + /// * The number of `fields` is larger than `i8::MAX` /// * any of the values's data type is different from its corresponding children' data type pub fn try_new( data_type: DataType, @@ -58,6 +59,10 @@ impl UnionArray { "The number of `fields` must equal the number of children fields in DataType::Union", )); }; + let number_of_fields: i8 = fields + .len() + .try_into() + .map_err(|_| Error::oos("The number of `fields` cannot be larger than i8::MAX"))?; f .iter().map(|a| a.data_type()) @@ -74,27 +79,75 @@ impl UnionArray { } })?; + if let Some(offsets) = &offsets { + if offsets.len() != types.len() { + return Err(Error::oos( + "In a UnionArray, the offsets' length must be equal to the number of types", + )); + } + } if offsets.is_none() != mode.is_sparse() { return Err(Error::oos( - "The offsets must be set when the Union is dense and vice-versa", + "In a sparse UnionArray, the offsets must be set (and vice-versa)", )); } - let fields_hash = ids.as_ref().map(|ids| { - ids.iter() - .map(|x| *x as i8) - .enumerate() - .zip(fields.iter().cloned()) - .map(|((i, type_), field)| (type_, (i, field))) - .collect() - }); - - // not validated: - // * `offsets` is valid - // * max id < fields.len() + // build hash + let map = if let Some(&ids) = ids.as_ref() { + if ids.len() != fields.len() { + return Err(Error::oos( + "In a union, when the ids are set, their length must be equal to the number of fields", + )); + } + + // example: + // * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5] + // * ids = [5, 7] + // => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...] + let mut hash = [0; 127]; + + for (pos, &id) in ids.iter().enumerate() { + if !(0..=127).contains(&id) { + return Err(Error::oos( + "In a union, when the ids are set, every id must belong to [0, 128[", + )); + } + hash[id as usize] = pos; + } + + types.iter().try_for_each(|&type_| { + if type_ < 0 { + return Err(Error::oos("In a union, when the ids are set, every type must be >= 0")); + } + let id = hash[type_ as usize]; + if id >= fields.len() { + Err(Error::oos("In a union, when the ids are set, each id must be smaller than the number of fields.")) + } else { + Ok(()) + } + })?; + + Some(hash) + } else { + // Safety: every type in types is smaller than number of fields + let mut is_valid = true; + for &type_ in types.iter() { + if type_ < 0 || type_ >= number_of_fields { + is_valid = false + } + } + if !is_valid { + return Err(Error::oos( + "Every type in `types` must be larger than 0 and smaller than the number of fields.", + )); + } + + None + }; + Ok(Self { data_type, - fields_hash, + map, fields, offsets, types, @@ -128,7 +181,7 @@ impl UnionArray { let offsets = if mode.is_sparse() { None } else { - Some((0..length as i32).collect::>()) + Some((0..length as i32).collect::>().into()) }; // all from the same field @@ -151,12 +204,12 @@ impl UnionArray { let offsets = if mode.is_sparse() { None } else { - Some(Buffer::new()) + Some(Buffer::default()) }; Self { data_type, - fields_hash: None, + map: None, fields, offsets, types: Buffer::new(), @@ -186,17 +239,11 @@ impl UnionArray { /// This function panics iff `offset + length >= self.len()`. #[inline] pub fn slice(&self, offset: usize, length: usize) -> Self { - Self { - data_type: self.data_type.clone(), - fields: self.fields.clone(), - fields_hash: self.fields_hash.clone(), - types: self.types.clone().slice(offset, length), - offsets: self - .offsets - .clone() - .map(|offsets| offsets.slice(offset, length)), - offset: self.offset + offset, - } + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } } /// Returns a slice of this [`UnionArray`]. @@ -206,10 +253,11 @@ impl UnionArray { /// The caller must ensure that `offset + length <= self.len()`. #[inline] pub unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Self { + debug_assert!(offset + length <= self.len()); Self { data_type: self.data_type.clone(), fields: self.fields.clone(), - fields_hash: self.fields_hash.clone(), + map: self.map, types: self.types.clone().slice_unchecked(offset, length), offsets: self .offsets @@ -243,38 +291,57 @@ impl UnionArray { } #[inline] - fn field(&self, type_: i8) -> &dyn Array { - self.fields_hash - .as_ref() - .map(|x| x[&type_].1.as_ref()) - .unwrap_or_else(|| self.fields[type_ as usize].as_ref()) - } - - #[inline] - fn field_slot(&self, index: usize) -> usize { + unsafe fn field_slot_unchecked(&self, index: usize) -> usize { self.offsets() .as_ref() - .map(|x| x[index] as usize) + .map(|x| *x.get_unchecked(index) as usize) .unwrap_or(index + self.offset) } /// Returns the index and slot of the field to select from `self.fields`. + #[inline] pub fn index(&self, index: usize) -> (usize, usize) { - let type_ = self.types()[index]; - let field_index = self - .fields_hash + assert!(index < self.len()); + unsafe { self.index_unchecked(index) } + } + + /// Returns the index and slot of the field to select from `self.fields`. + /// The first value is guaranteed to be `< self.fields().len()` + /// # Safety + /// This function is safe iff `index < self.len`. + #[inline] + pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) { + debug_assert!(index < self.len()); + // Safety: assumption of the function + let type_ = unsafe { *self.types.get_unchecked(index) }; + // Safety: assumption of the struct + let type_ = self + .map .as_ref() - .map(|x| x[&type_].0) - .unwrap_or_else(|| type_ as usize); - let index = self.field_slot(index); - (field_index, index) + .map(|map| unsafe { *map.get_unchecked(type_ as usize) }) + .unwrap_or(type_ as usize); + // Safety: assumption of the function + let index = self.field_slot_unchecked(index); + (type_, index) } /// Returns the slot `index` as a [`Scalar`]. + /// # Panics + /// iff `index >= self.len()` pub fn value(&self, index: usize) -> Box { - let type_ = self.types()[index]; - let field = self.field(type_); - let index = self.field_slot(index); + assert!(index < self.len()); + unsafe { self.value_unchecked(index) } + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Safety + /// This function is safe iff `i < self.len`. + pub unsafe fn value_unchecked(&self, index: usize) -> Box { + debug_assert!(index < self.len()); + let (type_, index) = self.index_unchecked(index); + // Safety: assumption of the struct + debug_assert!(type_ < self.fields.len()); + let field = self.fields.get_unchecked(type_).as_ref(); new_scalar(field, index) } } diff --git a/src/compute/sort/row/mod.rs b/src/compute/sort/row/mod.rs index 005e046fc92..e342537d68c 100644 --- a/src/compute/sort/row/mod.rs +++ b/src/compute/sort/row/mod.rs @@ -647,9 +647,9 @@ mod tests { #[test] fn test_fixed_width() { let cols = [ - Int16Array::from_iter([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)]) + Int16Array::from([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)]) .to_boxed(), - Float32Array::from_iter([ + Float32Array::from([ Some(1.3), Some(2.5), None, diff --git a/src/compute/sort/row/variable.rs b/src/compute/sort/row/variable.rs index 40fd5c17735..84a89ce6989 100644 --- a/src/compute/sort/row/variable.rs +++ b/src/compute/sort/row/variable.rs @@ -76,10 +76,9 @@ pub fn encode<'a, I: Iterator>>(out: &mut Rows, i: I, op // Write `2_u8` to demarcate as non-empty, non-null string to_write[0] = NON_EMPTY_SENTINEL; - let chunks = val.chunks_exact(BLOCK_SIZE); - let remainder = chunks.remainder(); + let mut chunks = val.chunks_exact(BLOCK_SIZE); for (input, output) in chunks - .clone() + .by_ref() .zip(to_write[1..].chunks_exact_mut(BLOCK_SIZE + 1)) { let input: &[u8; BLOCK_SIZE] = input.try_into().unwrap(); @@ -92,6 +91,7 @@ pub fn encode<'a, I: Iterator>>(out: &mut Rows, i: I, op output[BLOCK_SIZE] = BLOCK_CONTINUATION; } + let remainder = chunks.remainder(); if !remainder.is_empty() { let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1); to_write[start_offset..start_offset + remainder.len()] diff --git a/src/io/json_integration/read/array.rs b/src/io/json_integration/read/array.rs index 67af89a18b2..bca82eafde4 100644 --- a/src/io/json_integration/read/array.rs +++ b/src/io/json_integration/read/array.rs @@ -414,7 +414,8 @@ pub fn to_array( } _ => panic!(), }) - .collect(), + .collect::>() + .into(), ) }) .unwrap_or_default(); diff --git a/tests/it/array/union.rs b/tests/it/array/union.rs index 69ba887d250..c250c3cceea 100644 --- a/tests/it/array/union.rs +++ b/tests/it/array/union.rs @@ -6,7 +6,7 @@ use arrow2::{ scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}, }; -fn next_unchecked(iter: &mut I) -> T +fn next_unwrap(iter: &mut I) -> T where I: Iterator>, T: Clone + 'static, @@ -105,15 +105,15 @@ fn iter_sparse() -> Result<()> { let mut iter = array.iter(); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &Some(1) ); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &None ); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), Some("c") ); assert_eq!(iter.next(), None); @@ -139,15 +139,15 @@ fn iter_dense() -> Result<()> { let mut iter = array.iter(); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &Some(1) ); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &None ); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), Some("c") ); assert_eq!(iter.next(), None); @@ -173,7 +173,7 @@ fn iter_sparse_slice() -> Result<()> { let mut iter = array_slice.iter(); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &Some(3) ); assert_eq!(iter.next(), None); @@ -200,7 +200,7 @@ fn iter_dense_slice() -> Result<()> { let mut iter = array_slice.iter(); assert_eq!( - next_unchecked::, _>(&mut iter).value(), + next_unwrap::, _>(&mut iter).value(), &Some(3) ); assert_eq!(iter.next(), None); @@ -264,3 +264,110 @@ fn scalar() -> Result<()> { Ok(()) } + +#[test] +fn dense_without_offsets_is_error() { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn fields_must_match() { + let fields = vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn sparse_with_offsets_is_error() { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + let offsets = vec![0, 1, 0].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn offsets_must_be_in_bounds() { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length og types + let offsets = vec![0, 1].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn sparse_with_wrong_offsets1_is_error() { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length of types + let offsets = vec![0, 1, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn types_must_be_in_bounds() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + // 10 > num fields + let types = vec![0, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + Ok(()) +}