diff --git a/daft/arrow_utils.py b/daft/arrow_utils.py index 7f2e8ca40f..31c9a948a9 100644 --- a/daft/arrow_utils.py +++ b/daft/arrow_utils.py @@ -63,7 +63,6 @@ def ensure_array(arr: pa.Array) -> pa.Array: class _FixSliceOffsets: - # TODO(Clark): For pyarrow < 12.0.0, struct array slice offsets are dropped # when converting to record batches. We work around this below by flattening # the field arrays for all struct arrays, which propagates said offsets to diff --git a/daft/datatype.py b/daft/datatype.py index 0488551ab8..c9e32c2d59 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -6,6 +6,30 @@ from daft.daft import PyDataType +_RAY_DATA_EXTENSIONS_AVAILABLE = True +_TENSOR_EXTENSION_TYPES = [] +try: + import ray +except ImportError: + _RAY_DATA_EXTENSIONS_AVAILABLE = False +else: + _RAY_VERSION = tuple(int(s) for s in ray.__version__.split(".")) + try: + # Variable-shaped tensor column support was added in Ray 2.1.0. + if _RAY_VERSION >= (2, 2, 0): + from ray.data.extensions import ( + ArrowTensorType, + ArrowVariableShapedTensorType, + ) + + _TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType] + else: + from ray.data.extensions import ArrowTensorType + + _TENSOR_EXTENSION_TYPES = [ArrowTensorType] + except ImportError: + _RAY_DATA_EXTENSIONS_AVAILABLE = False + class DataType: _dtype: PyDataType @@ -96,6 +120,10 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType: def struct(cls, fields: dict[str, DataType]) -> DataType: return cls._from_pydatatype(PyDataType.struct({name: datatype._dtype for name, datatype in fields.items()})) + @classmethod + def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = None) -> DataType: + return cls._from_pydatatype(PyDataType.extension(name, storage_dtype._dtype, metadata)) + @classmethod def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: if pa.types.is_int8(arrow_type): @@ -140,9 +168,29 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: assert isinstance(arrow_type, pa.StructType) fields = [arrow_type[i] for i in range(arrow_type.num_fields)] return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields}) + elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)): + # TODO(Clark): Add a native cross-lang extension type representation for Ray's tensor extension types. + return cls.python() + elif isinstance(arrow_type, pa.PyExtensionType): + # TODO(Clark): Add a native cross-lang extension type representation for PyExtensionTypes. + raise ValueError( + "pyarrow extension types that subclass pa.PyExtensionType can't be used in Daft, since they can't be " + f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}" + ) + elif isinstance(arrow_type, pa.BaseExtensionType): + name = arrow_type.extension_name + try: + metadata = arrow_type.__arrow_ext_serialize__().decode() + except AttributeError: + metadata = None + return cls.extension( + name, + cls.from_arrow_type(arrow_type.storage_type), + metadata, + ) else: # Fall back to a Python object type. - # TODO(Clark): Add native support for remaining Arrow types and extension types. + # TODO(Clark): Add native support for remaining Arrow types. return cls.python() @classmethod diff --git a/src/array/ops/broadcast.rs b/src/array/ops/broadcast.rs index 027b26fbfc..c5477a13e7 100644 --- a/src/array/ops/broadcast.rs +++ b/src/array/ops/broadcast.rs @@ -1,8 +1,8 @@ use crate::{ array::{BaseArray, DataArray}, datatypes::{ - BinaryArray, BooleanArray, DaftNumericType, DataType, FixedSizeListArray, ListArray, - NullArray, StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeListArray, + ListArray, NullArray, StructArray, Utf8Array, }, error::{DaftError, DaftResult}, }; @@ -251,6 +251,23 @@ impl Broadcastable for StructArray { } } +impl Broadcastable for ExtensionArray { + fn broadcast(&self, num: usize) -> DaftResult { + if self.len() != 1 { + return Err(DaftError::ValueError(format!( + "Attempting to broadcast non-unit length Array named: {}", + self.name() + ))); + } + let array = self.data(); + let mut growable = arrow2::array::growable::make_growable(&[array], true, num); + for _ in 0..num { + growable.extend(0, 0, 1); + } + ExtensionArray::new(self.field.clone(), growable.as_box()) + } +} + #[cfg(feature = "python")] impl Broadcastable for crate::datatypes::PythonArray { fn broadcast(&self, num: usize) -> DaftResult { diff --git a/src/array/ops/filter.rs b/src/array/ops/filter.rs index aa3b32613d..5f272c4334 100644 --- a/src/array/ops/filter.rs +++ b/src/array/ops/filter.rs @@ -1,8 +1,8 @@ use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, NullArray, - StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray, + NullArray, StructArray, Utf8Array, }, error::DaftResult, }; @@ -74,6 +74,13 @@ impl StructArray { } } +impl ExtensionArray { + pub fn filter(&self, mask: &BooleanArray) -> DaftResult { + let result = arrow2::compute::filter::filter(self.data(), mask.downcast())?; + DataArray::try_from((self.name(), result)) + } +} + #[cfg(feature = "python")] impl crate::datatypes::PythonArray { pub fn filter(&self, mask: &BooleanArray) -> DaftResult { diff --git a/src/array/ops/full.rs b/src/array/ops/full.rs index 7e61d88607..1ec30cc368 100644 --- a/src/array/ops/full.rs +++ b/src/array/ops/full.rs @@ -14,10 +14,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { Ok(arrow_dtype) => DataArray::::new( - Arc::new(Field { - name: name.to_string(), - dtype: dtype.clone(), - }), + Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_null_array(arrow_dtype, length), ) .unwrap(), @@ -29,10 +26,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { Ok(arrow_dtype) => DataArray::::new( - Arc::new(Field { - name: name.to_string(), - dtype: dtype.clone(), - }), + Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_empty_array(arrow_dtype), ) .unwrap(), diff --git a/src/array/ops/if_else.rs b/src/array/ops/if_else.rs index ab79bc1a47..80a3058e3b 100644 --- a/src/array/ops/if_else.rs +++ b/src/array/ops/if_else.rs @@ -1,7 +1,7 @@ use crate::array::{BaseArray, DataArray}; use crate::datatypes::{ - BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, NullArray, - PythonArray, StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray, + NullArray, PythonArray, StructArray, Utf8Array, }; use crate::error::{DaftError, DaftResult}; use crate::utils::arrow::arrow_bitmap_and_helper; @@ -224,62 +224,57 @@ impl PythonArray { } } -fn from_arrow_if_then_else(predicate: &BooleanArray, if_true: &T, if_false: &T) -> DaftResult -where - T: BaseArray - + Downcastable - + for<'a> TryFrom<(&'a str, Box), Error = DaftError>, - ::Output: arrow2::array::Array, -{ - let result = arrow2::compute::if_then_else::if_then_else( - predicate.downcast(), - if_true.downcast(), - if_false.downcast(), - )?; - T::try_from((if_true.name(), result)) -} - fn nested_if_then_else(predicate: &BooleanArray, if_true: &T, if_false: &T) -> DaftResult where T: BaseArray + Broadcastable - + Downcastable + for<'a> TryFrom<(&'a str, Box), Error = DaftError>, - ::Output: arrow2::array::Array, { // TODO(Clark): Support streaming broadcasting, i.e. broadcasting without inflating scalars to full array length. - match (predicate.len(), if_true.len(), if_false.len()) { + let result = match (predicate.len(), if_true.len(), if_false.len()) { (predicate_len, if_true_len, if_false_len) - if predicate_len == if_true_len && if_true_len == if_false_len => from_arrow_if_then_else(predicate, if_true, if_false), - (1, if_true_len, 1) => from_arrow_if_then_else( - &predicate.broadcast(if_true_len)?, - if_true, - &if_false.broadcast(if_true_len)?, - ), - (1, 1, if_false_len) => from_arrow_if_then_else( - &predicate.broadcast(if_false_len)?, - &if_true.broadcast(if_false_len)?, - if_false, - ), - (predicate_len, 1, 1) => from_arrow_if_then_else( - predicate, - &if_true.broadcast(predicate_len)?, - &if_false.broadcast(predicate_len)?, - ), - (predicate_len, if_true_len, 1) if predicate_len == if_true_len => from_arrow_if_then_else( - predicate, - if_true, - &if_false.broadcast(predicate_len)?, - ), - (predicate_len, 1, if_false_len) if predicate_len == if_false_len => from_arrow_if_then_else( - predicate, - &if_true.broadcast(predicate_len)?, - if_false, - ), + if predicate_len == if_true_len && if_true_len == if_false_len => + { + arrow2::compute::if_then_else::if_then_else( + predicate.downcast(), + if_true.data(), + if_false.data(), + )? + } + (1, if_true_len, 1) => arrow2::compute::if_then_else::if_then_else( + predicate.broadcast(if_true_len)?.downcast(), + if_true.data(), + if_false.broadcast(if_true_len)?.data(), + )?, + (1, 1, if_false_len) => arrow2::compute::if_then_else::if_then_else( + predicate.broadcast(if_false_len)?.downcast(), + if_true.broadcast(if_false_len)?.data(), + if_false.data(), + )?, + (predicate_len, 1, 1) => arrow2::compute::if_then_else::if_then_else( + predicate.downcast(), + if_true.broadcast(predicate_len)?.data(), + if_false.broadcast(predicate_len)?.data(), + )?, + (predicate_len, if_true_len, 1) if predicate_len == if_true_len => { + arrow2::compute::if_then_else::if_then_else( + predicate.downcast(), + if_true.data(), + if_false.broadcast(predicate_len)?.data(), + )? + } + (predicate_len, 1, if_false_len) if predicate_len == if_false_len => { + arrow2::compute::if_then_else::if_then_else( + predicate.downcast(), + if_true.broadcast(predicate_len)?.data(), + if_false.data(), + )? + } (p, s, o) => { - Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}"))) + return Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}"))); } - } + }; + T::try_from((if_true.name(), result)) } impl ListArray { @@ -307,3 +302,13 @@ impl StructArray { nested_if_then_else(predicate, self, other) } } + +impl ExtensionArray { + pub fn if_else( + &self, + other: &ExtensionArray, + predicate: &BooleanArray, + ) -> DaftResult { + nested_if_then_else(predicate, self, other) + } +} diff --git a/src/array/ops/list.rs b/src/array/ops/list.rs index 4ec0f772ce..f4061aadc5 100644 --- a/src/array/ops/list.rs +++ b/src/array/ops/list.rs @@ -64,10 +64,10 @@ impl ListArray { }; with_match_arrow_daft_types!(child_data_type,|$T| { - let new_data_arr = DataArray::<$T>::new(Arc::new(Field { - name: self.field.name.clone(), - dtype: child_data_type.clone(), - }), growable.as_box())?; + let new_data_arr = DataArray::<$T>::new(Arc::new(Field::new( + self.field.name.clone(), + child_data_type.clone(), + )), growable.as_box())?; Ok(new_data_arr.into_series()) }) } @@ -120,10 +120,10 @@ impl FixedSizeListArray { }; with_match_arrow_daft_types!(child_data_type,|$T| { - let new_data_arr = DataArray::<$T>::new(Arc::new(Field { - name: self.field.name.clone(), - dtype: child_data_type.clone(), - }), growable.as_box())?; + let new_data_arr = DataArray::<$T>::new(Arc::new(Field::new( + self.field.name.clone(), + child_data_type.clone(), + )), growable.as_box())?; Ok(new_data_arr.into_series()) }) } diff --git a/src/array/ops/take.rs b/src/array/ops/take.rs index 88dd6b2019..92462665f7 100644 --- a/src/array/ops/take.rs +++ b/src/array/ops/take.rs @@ -1,8 +1,8 @@ use crate::{ array::{BaseArray, DataArray}, datatypes::{ - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, FixedSizeListArray, ListArray, - NullArray, StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, + FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array, }, error::DaftResult, }; @@ -291,6 +291,41 @@ impl StructArray { } } +impl ExtensionArray { + #[inline] + pub fn get(&self, idx: usize) -> Option> { + if idx >= self.len() { + panic!("Out of bounds: {} vs len: {}", idx, self.len()) + } + let is_valid = self + .data + .validity() + .map_or(true, |validity| validity.get_bit(idx)); + if is_valid { + Some(arrow2::scalar::new_scalar(self.data(), idx)) + } else { + None + } + } + + pub fn take(&self, idx: &DataArray) -> DaftResult + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let result = arrow2::compute::take::take(self.data(), idx.downcast())?; + Self::try_from((self.name(), result)) + } + + pub fn str_value(&self, idx: usize) -> DaftResult { + let val = self.get(idx); + match val { + None => Ok("None".to_string()), + Some(v) => Ok(format!("{v:?}")), + } + } +} + #[cfg(feature = "python")] impl crate::datatypes::PythonArray { #[inline] diff --git a/src/datatypes/dtype.rs b/src/datatypes/dtype.rs index 1199184e18..e0b4b29127 100644 --- a/src/datatypes/dtype.rs +++ b/src/datatypes/dtype.rs @@ -72,6 +72,8 @@ pub enum DataType { List(Box), /// A nested [`DataType`] with a given number of [`Field`]s. Struct(Vec), + /// Extension type. + Extension(String, Box, Option), // Stop ArrowTypes DaftType(Box), Python, @@ -117,6 +119,11 @@ impl DataType { .collect::>>()?; ArrowType::Struct(fields) }), + DataType::Extension(name, dtype, metadata) => Ok(ArrowType::Extension( + name.clone(), + Box::new(dtype.to_arrow()?), + metadata.clone(), + )), _ => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), @@ -128,6 +135,7 @@ impl DataType { match self { Date => Int32, Duration(_) | Timestamp(..) | Time(_) => Int64, + Extension(_, inner, _) => *inner.clone(), _ => self.clone(), } } @@ -151,23 +159,41 @@ impl DataType { // DataType::Float16 | DataType::Float32 | DataType::Float64 => true, + DataType::Extension(_, inner, _) => inner.is_numeric(), _ => false } } #[inline] pub fn is_temporal(&self) -> bool { - matches!(self, DataType::Date) + match self { + DataType::Date => true, + DataType::Extension(_, inner, _) => inner.is_temporal(), + _ => false, + } } #[inline] pub fn is_null(&self) -> bool { - matches!(self, DataType::Null) + match self { + DataType::Null => true, + DataType::Extension(_, inner, _) => inner.is_null(), + _ => false, + } + } + + #[inline] + pub fn is_extension(&self) -> bool { + matches!(self, DataType::Extension(..)) } #[inline] pub fn is_python(&self) -> bool { - matches!(self, DataType::Python) + match self { + DataType::Python => true, + DataType::Extension(_, inner, _) => inner.is_python(), + _ => false, + } } #[inline] @@ -230,6 +256,11 @@ impl From<&ArrowType> for DataType { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); DataType::Struct(fields) } + ArrowType::Extension(name, dtype, metadata) => DataType::Extension( + name.clone(), + Box::new(dtype.as_ref().into()), + metadata.clone(), + ), _ => panic!("DataType :{item:?} is not supported"), } } diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index 5a7137f812..a6c3de9fbd 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -6,10 +6,13 @@ use crate::{datatypes::dtype::DataType, error::DaftResult}; use serde::{Deserialize, Serialize}; +pub type Metadata = std::collections::BTreeMap; + #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)] pub struct Field { pub name: String, pub dtype: DataType, + pub metadata: Metadata, } impl Field { @@ -17,18 +20,33 @@ impl Field { Field { name: name.into(), dtype, + metadata: Default::default(), } } + + pub fn with_metadata(self, metadata: Metadata) -> Self { + Self { + name: self.name, + dtype: self.dtype, + metadata, + } + } + pub fn to_arrow(&self) -> DaftResult { - Ok(ArrowField::new( - self.name.clone(), - self.dtype.to_arrow()?, - true, - )) + Ok( + ArrowField::new(self.name.clone(), self.dtype.to_arrow()?, true) + .with_metadata(self.metadata.clone()), + ) } + pub fn rename>(&self, name: S) -> Self { - Field::new(name, self.dtype.clone()) + Self { + name: name.into(), + dtype: self.dtype.clone(), + metadata: self.metadata.clone(), + } } + pub fn to_list_field(&self) -> DaftResult { if self.dtype.is_python() { return Ok(self.clone()); @@ -37,6 +55,7 @@ impl Field { Ok(Field { name: self.name.clone(), dtype: list_dtype, + metadata: self.metadata.clone(), }) } } @@ -46,6 +65,7 @@ impl From<&ArrowField> for Field { Field { name: af.name.clone(), dtype: af.data_type().into(), + metadata: af.metadata.clone(), } } } diff --git a/src/datatypes/matching.rs b/src/datatypes/matching.rs index 7b6df6f64d..808c9bf69c 100644 --- a/src/datatypes/matching.rs +++ b/src/datatypes/matching.rs @@ -29,6 +29,7 @@ macro_rules! with_match_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, _ => panic!("{:?} not implemented for with_match_daft_types", $key_type) @@ -63,6 +64,7 @@ macro_rules! with_match_physical_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, _ => panic!("{:?} not implemented for with_match_physical_daft_types", $key_type) @@ -97,6 +99,7 @@ macro_rules! with_match_arrow_daft_types {( List(_) => __with_ty__! { ListType }, FixedSizeList(..) => __with_ty__! { FixedSizeListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, Utf8 => __with_ty__! { Utf8Type }, _ => panic!("{:?} not implemented", $key_type) } diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index 0493f46dd3..24b8a775c7 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -75,6 +75,7 @@ impl_daft_arrow_datatype!(Utf8Type, Utf8); impl_daft_arrow_datatype!(FixedSizeListType, Unknown); impl_daft_arrow_datatype!(ListType, Unknown); impl_daft_arrow_datatype!(StructType, Unknown); +impl_daft_arrow_datatype!(ExtensionType, Unknown); impl_daft_non_arrow_datatype!(PythonType, Python); pub trait NumericNative: @@ -217,4 +218,5 @@ pub type Utf8Array = DataArray; pub type FixedSizeListArray = DataArray; pub type ListArray = DataArray; pub type StructArray = DataArray; +pub type ExtensionArray = DataArray; pub type PythonArray = DataArray; diff --git a/src/dsl/functions/temporal/day.rs b/src/dsl/functions/temporal/day.rs index 1c96a8e622..4927d17d57 100644 --- a/src/dsl/functions/temporal/day.rs +++ b/src/dsl/functions/temporal/day.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for DayEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to day to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/day_of_week.rs b/src/dsl/functions/temporal/day_of_week.rs index 3630c54659..d943e8e90b 100644 --- a/src/dsl/functions/temporal/day_of_week.rs +++ b/src/dsl/functions/temporal/day_of_week.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for DayOfWeekEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to day to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/month.rs b/src/dsl/functions/temporal/month.rs index e9247ab1d4..3d03523d88 100644 --- a/src/dsl/functions/temporal/month.rs +++ b/src/dsl/functions/temporal/month.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for MonthEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to month to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/year.rs b/src/dsl/functions/temporal/year.rs index 9aa5579399..96713118d4 100644 --- a/src/dsl/functions/temporal/year.rs +++ b/src/dsl/functions/temporal/year.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for YearEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::Int32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::Int32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to year to be temporal, got {}", field.dtype diff --git a/src/python/datatype.rs b/src/python/datatype.rs index 61cb52d40a..7fb3bf255d 100644 --- a/src/python/datatype.rs +++ b/src/python/datatype.rs @@ -140,6 +140,20 @@ impl PyDataType { .into()) } + #[staticmethod] + pub fn extension( + name: &str, + storage_data_type: Self, + metadata: Option<&str>, + ) -> PyResult { + Ok(DataType::Extension( + name.to_string(), + Box::new(storage_data_type.dtype), + metadata.map(|s| s.to_string()), + ) + .into()) + } + #[staticmethod] pub fn python() -> PyResult { Ok(DataType::Python.into()) diff --git a/src/series/mod.rs b/src/series/mod.rs index 1196a03e18..17000831cd 100644 --- a/src/series/mod.rs +++ b/src/series/mod.rs @@ -44,7 +44,7 @@ impl Series { pub fn as_physical(&self) -> DaftResult { let physical_dtype = self.data_type().to_physical(); - if &physical_dtype == self.data_type() { + if &physical_dtype == self.data_type() || self.data_type().is_extension() { Ok(self.clone()) } else { self.cast(&physical_dtype) diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 8a574281d8..e70cc431c0 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -18,6 +18,8 @@ from daft.dataframe import DataFrame from daft.datatype import DataType +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + class MyObj: pass @@ -151,9 +153,10 @@ def test_create_dataframe_arrow(valid_data: list[dict[str, float]], multiple) -> assert df.to_arrow() == expected -def test_create_dataframe_arrow_tensor(valid_data: list[dict[str, float]]) -> None: +def test_create_dataframe_arrow_tensor_ray(valid_data: list[dict[str, float]]) -> None: pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} - pydict["obj"] = ArrowTensorArray.from_numpy(np.ones((len(valid_data), 2, 2))) + ata = ArrowTensorArray.from_numpy(np.ones((len(valid_data), 2, 2))) + pydict["obj"] = ata t = pa.Table.from_pydict(pydict) df = daft.from_arrow(t) assert set(df.column_names) == set(t.column_names) @@ -165,6 +168,48 @@ def test_create_dataframe_arrow_tensor(valid_data: list[dict[str, float]]) -> No assert df.to_arrow() == expected +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, float]]) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + dtype = pa.fixed_shape_tensor(pa.int64(), (2, 2)) + storage = pa.array([list(range(4 * i, 4 * (i + 1))) for i in range(len(valid_data))], pa.list_(pa.int64(), 4)) + ata = pa.ExtensionArray.from_storage(dtype, storage) + pydict["obj"] = ata + t = pa.Table.from_pydict(pydict) + df = daft.from_arrow(t) + assert set(df.column_names) == set(t.column_names) + # Type not natively supported, so should have Python object dtype. + assert df.schema()["obj"].dtype == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(dtype.storage_type), '{"shape":[2,2]}' + ) + casted_field = t.schema.field("variety").with_type(pa.large_string()) + expected = t.cast(t.schema.set(t.schema.get_field_index("variety"), casted_field)) + # Check roundtrip. + assert df.to_arrow() == expected + + +class UuidType(pa.PyExtensionType): + def __init__(self): + pa.PyExtensionType.__init__(self, pa.binary(5)) + + def __reduce__(self): + return UuidType, () + + +def test_create_dataframe_arrow_ext_type_raises(valid_data: list[dict[str, float]]) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + uuid_type = UuidType() + storage_array = pa.array([f"foo-{i}" for i in range(len(valid_data))], pa.binary(5)) + arr = pa.ExtensionArray.from_storage(uuid_type, storage_array) + pydict["obj"] = arr + t = pa.Table.from_pydict(pydict) + with pytest.raises(ValueError): + daft.from_arrow(t) + + def test_create_dataframe_arrow_unsupported_dtype(valid_data: list[dict[str, float]]) -> None: pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} pydict["obj"] = [datetime.datetime.now() for _ in range(len(valid_data))] diff --git a/tests/series/test_concat.py b/tests/series/test_concat.py index 3eac05c416..13026048f9 100644 --- a/tests/series/test_concat.py +++ b/tests/series/test_concat.py @@ -2,12 +2,16 @@ import itertools +import numpy as np import pyarrow as pa import pytest +from ray.data.extensions import ArrowTensorArray from daft import DataType, Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + class MockObject: def __init__(self, test_val): @@ -85,6 +89,57 @@ def test_series_concat_struct_array(chunks) -> None: counter += 1 +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_tensor_array_ray(chunks) -> None: + element_shape = (2, 2) + num_elements_per_tensor = np.prod(element_shape) + chunk_size = 3 + chunk_shape = (chunk_size,) + element_shape + chunks = [ + np.arange(i * chunk_size * num_elements_per_tensor, (i + 1) * chunk_size * num_elements_per_tensor).reshape( + chunk_shape + ) + for i in range(chunks) + ] + series = [Series.from_arrow(ArrowTensorArray.from_numpy(chunk)) for chunk in chunks] + + concated = Series.concat(series) + + assert concated.datatype() == DataType.python() + expected = [chunk[i] for chunk in chunks for i in range(len(chunk))] + np.testing.assert_equal(concated.to_pylist(), expected) + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_tensor_array_canonical(chunks) -> None: + element_shape = (2, 2) + num_elements_per_tensor = np.prod(element_shape) + chunk_size = 3 + chunk_shape = (chunk_size,) + element_shape + chunks = [ + np.arange(i * chunk_size * num_elements_per_tensor, (i + 1) * chunk_size * num_elements_per_tensor).reshape( + chunk_shape + ) + for i in range(chunks) + ] + ext_arrays = [pa.FixedShapeTensorArray.from_numpy_ndarray(chunk) for chunk in chunks] + series = [Series.from_arrow(ext_array) for ext_array in ext_arrays] + + concated = Series.concat(series) + + assert concated.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(ext_arrays[0].type.storage_type), '{"shape":[2,2]}' + ) + expected = [chunk[i] for chunk in chunks for i in range(len(chunk))] + concated_arrow = concated.to_arrow() + assert isinstance(concated_arrow.type, pa.FixedShapeTensorType) + np.testing.assert_equal(concated_arrow.to_numpy_ndarray(), expected) + + @pytest.mark.parametrize("chunks", [1, 2, 3, 10]) def test_series_concat_pyobj(chunks) -> None: series = [] diff --git a/tests/series/test_filter.py b/tests/series/test_filter.py index 22877af565..94af15ba57 100644 --- a/tests/series/test_filter.py +++ b/tests/series/test_filter.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest @@ -7,6 +8,8 @@ from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_filter(dtype) -> None: @@ -110,6 +113,25 @@ def test_series_filter_on_struct_array() -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +def test_series_filter_on_canonical_tensor_extension_array() -> None: + arr = np.arange(20).reshape((5, 2, 2)) + data = pa.FixedShapeTensorArray.from_numpy_ndarray(arr) + + s = Series.from_arrow(data) + pymask = [False, True, True, None, False] + mask = Series.from_pylist(pymask) + + result = s.filter(mask) + + assert s.datatype() == result.datatype() + expected = [val for val, keep in zip(s.to_pylist(), pymask) if keep] + assert result.to_pylist() == expected + + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_filter_broadcast(dtype) -> None: data = pa.array([1, 2, 3, None, 5, None]) diff --git a/tests/series/test_if_else.py b/tests/series/test_if_else.py index 552b960430..11f9114738 100644 --- a/tests/series/test_if_else.py +++ b/tests/series/test_if_else.py @@ -1,11 +1,14 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest from daft import Series from daft.datatype import DataType +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("if_true_value", [1, None]) @pytest.mark.parametrize("if_false_value", [0, None]) @@ -293,6 +296,59 @@ def test_series_if_else_struct(if_true, if_false, expected) -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.parametrize( + ["if_true", "if_false", "expected"], + [ + # Same length, same type + ( + np.arange(16).reshape((4, 2, 2)), + np.arange(16, 32).reshape((4, 2, 2)), + np.array( + [[[0, 1], [2, 3]], [[20, 21], [22, 23]], [[np.nan, np.nan], [np.nan, np.nan]], [[12, 13], [14, 15]]] + ), + ), + # Broadcast left + ( + np.arange(4).reshape((1, 2, 2)), + np.arange(16, 32).reshape((4, 2, 2)), + np.array([[[0, 1], [2, 3]], [[20, 21], [22, 23]], [[np.nan, np.nan], [np.nan, np.nan]], [[0, 1], [2, 3]]]), + ), + # Broadcast right + ( + np.arange(16).reshape((4, 2, 2)), + np.arange(16, 20).reshape((1, 2, 2)), + np.array( + [[[0, 1], [2, 3]], [[16, 17], [18, 19]], [[np.nan, np.nan], [np.nan, np.nan]], [[12, 13], [14, 15]]] + ), + ), + # Broadcast both + ( + np.arange(4).reshape((1, 2, 2)), + np.arange(16, 20).reshape((1, 2, 2)), + np.array([[[0, 1], [2, 3]], [[16, 17], [18, 19]], [[np.nan, np.nan], [np.nan, np.nan]], [[0, 1], [2, 3]]]), + ), + ], +) +def test_series_if_else_extension_type(if_true, if_false, expected) -> None: + if_true_arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(if_true) + if_false_arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(if_false) + if_true_series = Series.from_arrow(if_true_arrow) + if_false_series = Series.from_arrow(if_false_arrow) + predicate_series = Series.from_arrow(pa.array([True, False, None, True])) + + result = predicate_series.if_else(if_true_series, if_false_series) + + assert result.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(if_true_arrow.type.storage_type), '{"shape":[2,2]}' + ) + result_arrow = result.to_arrow() + np.testing.assert_equal(result_arrow.to_numpy_ndarray(), expected) + + @pytest.mark.parametrize( "if_true_length", [1, 3], diff --git a/tests/series/test_size_bytes.py b/tests/series/test_size_bytes.py index 93f4703a50..52e7090716 100644 --- a/tests/series/test_size_bytes.py +++ b/tests/series/test_size_bytes.py @@ -3,6 +3,7 @@ import itertools import math +import numpy as np import pyarrow as pa import pytest @@ -10,7 +11,8 @@ from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES -PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) +PYARROW_GE_7_0_0 = ARROW_VERSION >= (7, 0, 0) def get_total_buffer_size(arr: pa.Array) -> int: @@ -177,3 +179,30 @@ def test_series_struct_size_bytes(size, with_nulls) -> None: ) else: assert s.size_bytes() == get_total_buffer_size(data) + conversion_to_large_string_bytes + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.parametrize("dtype, size", itertools.product(ARROW_INT_TYPES + ARROW_FLOAT_TYPES, [0, 1, 2, 8, 9, 16])) +@pytest.mark.parametrize("with_nulls", [True, False]) +def test_series_extension_type_size_bytes(dtype, size, with_nulls) -> None: + tensor_type = pa.fixed_shape_tensor(pa.int64(), (2, 2)) + if size == 0: + storage = pa.array([], pa.list_(pa.int64(), 4)) + data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage) + elif with_nulls: + pydata = np.arange(4 * size).reshape((size, 4)).tolist()[:-1] + [None] + storage = pa.array(pydata, pa.list_(pa.int64(), 4)) + data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage) + else: + arr = np.arange(4 * size).reshape((size, 2, 2)) + data = pa.FixedShapeTensorArray.from_numpy_ndarray(arr) + + s = Series.from_arrow(data) + + assert s.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(data.type.storage_type), '{"shape":[2,2]}' + ) + assert s.size_bytes() == get_total_buffer_size(data) diff --git a/tests/series/test_take.py b/tests/series/test_take.py index 2bb9f1baae..ef67a39fa9 100644 --- a/tests/series/test_take.py +++ b/tests/series/test_take.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest @@ -7,6 +8,8 @@ from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_take(dtype) -> None: @@ -115,6 +118,33 @@ def test_series_struct_take() -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +def test_series_extension_type_take() -> None: + pydata = np.arange(24).reshape((6, 4)).tolist() + pydata[2] = None + storage = pa.array(pydata, pa.list_(pa.int64(), 4)) + tensor_type = pa.fixed_shape_tensor(pa.int64(), (2, 2)) + data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage) + + s = Series.from_arrow(data) + assert s.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(tensor_type.storage_type), '{"shape":[2,2]}' + ) + pyidx = [2, 0, None, 5] + idx = Series.from_pylist(pyidx) + + result = s.take(idx) + assert result.datatype() == s.datatype() + assert len(result) == 4 + + original_data = s.to_pylist() + expected = [original_data[i] if i is not None else None for i in pyidx] + assert result.to_pylist() == expected + + def test_series_deeply_nested_take() -> None: # Test take on a Series with a deeply nested type: struct of list of struct of list of strings. data = pa.array([{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}]) diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index ac34da13e0..ac47e5149d 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -14,6 +14,8 @@ from daft.series import Series from daft.table import Table +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + PYTHON_TYPE_ARRAYS = { "int": [1, 2], "float": [1.0, 2.0], @@ -26,7 +28,7 @@ "empty_struct": [{}, {}], "null": [None, None], # The following types are not natively supported and will be cast to Python object types. - "tensor": list(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + "tensor": list(np.arange(8).reshape(2, 2, 2)), "timestamp": [datetime.datetime.now(), datetime.datetime.now()], } @@ -43,6 +45,7 @@ "empty_struct": DataType.struct({"": DataType.null()}), "null": DataType.null(), # The following types are not natively supported and will be cast to Python object types. + # TODO(Clark): Change the tensor inferred type to be the canonical fixed-shape tensor extension type. "tensor": DataType.python(), "timestamp": DataType.python(), } @@ -116,6 +119,12 @@ "timestamp": pa.timestamp("us"), } +if ARROW_VERSION >= (12, 0, 0): + ARROW_TYPE_ARRAYS["ext_type"] = ( + pa.FixedShapeTensorArray.from_numpy_ndarray(np.array(PYTHON_TYPE_ARRAYS["tensor"])), + ) + ARROW_ROUNDTRIP_TYPES["ext_type"] = (pa.fixed_shape_tensor(pa.int64(), (2, 2)),) + def test_from_pydict_roundtrip() -> None: table = Table.from_pydict(PYTHON_TYPE_ARRAYS)