diff --git a/src/array/mod.rs b/src/array/mod.rs index 02735c3d0b..8302057eda 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -16,10 +16,10 @@ //! //! Most arrays contain a [`MutableArray`] counterpart that is neither clonable nor slicable, but //! can be operated in-place. -use std::any::Any; +use std::any::{type_name, Any}; use std::sync::Arc; -use crate::error::Result; +use crate::error::{Error, Result}; use crate::{ bitmap::{Bitmap, MutableBitmap}, datatypes::DataType, @@ -142,6 +142,12 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// Clone a `&dyn Array` to an owned `Box`. fn to_boxed(&self) -> Box; + + #[doc(hidden)] + #[inline] + fn as_mut_any(&mut self) -> &mut dyn Any { + self.as_any_mut() + } } dyn_clone::clone_trait_object!(Array); @@ -703,6 +709,93 @@ impl<'a> AsRef<(dyn Array + 'a)> for dyn Array { } } +mod downcast { + use super::{Array, DataType, MutableArray}; + use std::any::Any; + + /// Arrays that can be downcasted to a concrete type ([`Array`] and [`MutableArray`]). + pub trait ArrayAny { + /// The [`DataType`] of the array. + fn data_type(&self) -> &DataType; + /// Converts itself to a reference of [`Any`]. + fn as_any(&self) -> &dyn Any; + /// Converts itself to a mutable reference of [`Any`]. + fn as_mut_any(&mut self) -> &mut dyn Any; + } + + macro_rules! impl_array_any { + ($ty:ident) => { + impl ArrayAny for dyn $ty { + #[inline] + fn data_type(&self) -> &DataType { + $ty::data_type(self) + } + + #[inline] + fn as_any(&self) -> &dyn Any { + $ty::as_any(self) + } + + #[inline] + fn as_mut_any(&mut self) -> &mut dyn Any { + $ty::as_mut_any(self) + } + } + + impl ArrayAny for Box { + #[inline] + fn data_type(&self) -> &DataType { + $ty::data_type(self.as_ref()) + } + + #[inline] + fn as_any(&self) -> &dyn Any { + $ty::as_any(self.as_ref()) + } + + #[inline] + fn as_mut_any(&mut self) -> &mut dyn Any { + $ty::as_mut_any(self.as_mut()) + } + } + }; + } + + impl_array_any!(Array); + impl_array_any!(MutableArray); +} + +/// Downcast an array reference to a concrete type. +#[inline] +pub fn downcast_ref(array: &(impl downcast::ArrayAny + ?Sized)) -> Result<&M> { + array.as_any().downcast_ref().ok_or_else(|| { + Error::oos(format!( + "Unable to downcast array of data type {:?} into {}", + array.data_type(), + type_name::() + )) + }) +} + +/// Downcast a mutable array reference to a concrete type. +#[inline] +pub fn downcast_mut(array: &mut (impl downcast::ArrayAny + ?Sized)) -> Result<&mut M> { + let arr_ptr = array.as_mut_any() as *mut dyn Any; + // Safety: this is sound and is only to avoid non-polonius borrow checker which erroneously + // thinks that array will be mutable borrowed even past the return point; we know that the + // pointer comes from a mutable reference and we are returning a reference bound to the same + // lifetime. + if let Some(array) = unsafe { (*arr_ptr).downcast_mut::() } { + Ok(array) + } else { + Err(Error::oos(format!( + "Unable to downcast array of data type {:?} into {}", + array.data_type(), + type_name::() + ))) + } +} + mod binary; mod boolean; mod dictionary; diff --git a/tests/it/array/mod.rs b/tests/it/array/mod.rs index 85318ba628..02a100d444 100644 --- a/tests/it/array/mod.rs +++ b/tests/it/array/mod.rs @@ -13,7 +13,10 @@ mod struct_; mod union; mod utf8; -use arrow2::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray}; +use arrow2::array::{ + clone, downcast_mut, downcast_ref, new_empty_array, new_null_array, Array, MutableArray, + MutablePrimitiveArray, PrimitiveArray, +}; use arrow2::bitmap::Bitmap; use arrow2::datatypes::{DataType, Field, UnionMode}; @@ -140,3 +143,27 @@ fn test_with_validity() { struct A { array: Box, } + +#[test] +fn test_downcast() { + let arr = PrimitiveArray::from_slice([1i32, 2, 3]); + let arr_box: Box = Box::new(arr.clone()); + assert_eq!(downcast_ref::>(&arr_box).unwrap(), &arr); + assert_eq!( + downcast_ref::>(arr_box.as_ref()).unwrap(), + &arr + ); + assert!(downcast_ref::>(&arr_box).is_err()); + + let mut_arr = MutablePrimitiveArray::from_slice([1i32, 2, 3]); + let mut mut_arr_box: Box = Box::new(mut_arr.clone()); + assert_eq!( + downcast_mut::>(&mut mut_arr_box).unwrap(), + &mut_arr + ); + assert_eq!( + downcast_mut::>(mut_arr_box.as_mut()).unwrap(), + &mut_arr + ); + assert!(downcast_mut::>(&mut mut_arr_box).is_err()); +}