diff --git a/Cargo.lock b/Cargo.lock index 59793c84fc..4dfd2920d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -42,8 +42,7 @@ dependencies = [ [[package]] name = "arrow2" version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0f73029049896b3d70ba17756afef171ceef3569016cfa9dbca58d29e0e16f9" +source = "git+https://github.com/Eventual-Inc/arrow2?branch=clark/expand-casting-support#2ace8097342d5634746915b094c5a3cdf53f75b9" dependencies = [ "ahash", "arrow-format", @@ -166,6 +165,7 @@ dependencies = [ "dyn-clone", "fnv", "indexmap", + "lazy_static", "num-traits", "prettytable-rs", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index cb5b0e1fd0..377002d9cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,10 @@ prettytable-rs = "^0.10" rand = "^0.8" [dependencies.arrow2] +branch = "clark/expand-casting-support" features = ["compute", "io_ipc"] +git = "https://github.com/Eventual-Inc/arrow2" +package = "arrow2" version = "0.17.1" [dependencies.bincode] @@ -15,6 +18,9 @@ version = "1.3.3" features = ["serde"] version = "1.9.2" +[dependencies.lazy_static] +version = "1.4.0" + [dependencies.num-traits] version = "0.2" diff --git a/daft/arrow_utils.py b/daft/arrow_utils.py index 7f2e8ca40f..31c9a948a9 100644 --- a/daft/arrow_utils.py +++ b/daft/arrow_utils.py @@ -63,7 +63,6 @@ def ensure_array(arr: pa.Array) -> pa.Array: class _FixSliceOffsets: - # TODO(Clark): For pyarrow < 12.0.0, struct array slice offsets are dropped # when converting to record batches. We work around this below by flattening # the field arrays for all struct arrays, which propagates said offsets to diff --git a/daft/datatype.py b/daft/datatype.py index 0488551ab8..5dae6d7316 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -4,8 +4,33 @@ import pyarrow as pa +from daft.context import get_context from daft.daft import PyDataType +_RAY_DATA_EXTENSIONS_AVAILABLE = True +_TENSOR_EXTENSION_TYPES = [] +try: + import ray +except ImportError: + _RAY_DATA_EXTENSIONS_AVAILABLE = False +else: + _RAY_VERSION = tuple(int(s) for s in ray.__version__.split(".")) + try: + # Variable-shaped tensor column support was added in Ray 2.1.0. 
+ if _RAY_VERSION >= (2, 2, 0): + from ray.data.extensions import ( + ArrowTensorType, + ArrowVariableShapedTensorType, + ) + + _TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType] + else: + from ray.data.extensions import ArrowTensorType + + _TENSOR_EXTENSION_TYPES = [ArrowTensorType] + except ImportError: + _RAY_DATA_EXTENSIONS_AVAILABLE = False + class DataType: _dtype: PyDataType @@ -96,6 +121,10 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType: def struct(cls, fields: dict[str, DataType]) -> DataType: return cls._from_pydatatype(PyDataType.struct({name: datatype._dtype for name, datatype in fields.items()})) + @classmethod + def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = None) -> DataType: + return cls._from_pydatatype(PyDataType.extension(name, storage_dtype._dtype, metadata)) + @classmethod def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: if pa.types.is_int8(arrow_type): @@ -140,9 +169,35 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: assert isinstance(arrow_type, pa.StructType) fields = [arrow_type[i] for i in range(arrow_type.num_fields)] return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields}) + elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)): + # TODO(Clark): Add a native cross-lang extension type representation for Ray's tensor extension types. + return cls.python() + elif isinstance(arrow_type, pa.PyExtensionType): + # TODO(Clark): Add a native cross-lang extension type representation for PyExtensionTypes. + raise ValueError( + "pyarrow extension types that subclass pa.PyExtensionType can't be used in Daft, since they can't be " + f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}" + ) + elif isinstance(arrow_type, pa.BaseExtensionType): + if get_context().runner_config.name == "ray": + raise ValueError( + f"pyarrow extension types are not supported for the Ray runner: {arrow_type}. If you need support " + "for this, please let us know on this issue: " + "https://github.com/Eventual-Inc/Daft/issues/933" + ) + name = arrow_type.extension_name + try: + metadata = arrow_type.__arrow_ext_serialize__().decode() + except AttributeError: + metadata = None + return cls.extension( + name, + cls.from_arrow_type(arrow_type.storage_type), + metadata, + ) else: # Fall back to a Python object type. - # TODO(Clark): Add native support for remaining Arrow types and extension types. + # TODO(Clark): Add native support for remaining Arrow types. 
return cls.python() @classmethod diff --git a/src/array/ops/arrow2/sort/primitive/sort.rs b/src/array/ops/arrow2/sort/primitive/sort.rs index 8fd6c06c80..36ba05d7a7 100644 --- a/src/array/ops/arrow2/sort/primitive/sort.rs +++ b/src/array/ops/arrow2/sort/primitive/sort.rs @@ -166,6 +166,7 @@ mod tests { use super::*; use arrow2::array::ord; + use arrow2::array::Array; use arrow2::array::PrimitiveArray; use arrow2::datatypes::DataType; @@ -177,13 +178,28 @@ mod tests { ) where T: NativeType + std::cmp::Ord, { - let input = PrimitiveArray::::from(data).to(data_type.clone()); - let expected = PrimitiveArray::::from(expected_data).to(data_type.clone()); + let input = PrimitiveArray::::from(data) + .to(data_type.clone()) + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let expected = PrimitiveArray::::from(expected_data) + .to(data_type.clone()) + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); let output = sort_by(&input, ord::total_cmp, &options, None); assert_eq!(expected, output); // with limit - let expected = PrimitiveArray::::from(&expected_data[..3]).to(data_type); + let expected = PrimitiveArray::::from(&expected_data[..3]) + .to(data_type) + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); let output = sort_by(&input, ord::total_cmp, &options, Some(3)); assert_eq!(expected, output) } diff --git a/src/array/ops/broadcast.rs b/src/array/ops/broadcast.rs index dd73082f34..1d9b166e04 100644 --- a/src/array/ops/broadcast.rs +++ b/src/array/ops/broadcast.rs @@ -1,8 +1,8 @@ use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftNumericType, DataType, FixedSizeListArray, ListArray, - NullArray, StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeListArray, + ListArray, NullArray, StructArray, Utf8Array, }, error::{DaftError, DaftResult}, }; @@ -251,6 +251,23 @@ impl Broadcastable for StructArray { } } +impl Broadcastable for ExtensionArray { + fn broadcast(&self, num: usize) -> DaftResult { + if self.len() != 1 { + return Err(DaftError::ValueError(format!( + "Attempting to broadcast non-unit length Array named: {}", + self.name() + ))); + } + let array = self.data(); + let mut growable = arrow2::array::growable::make_growable(&[array], true, num); + for _ in 0..num { + growable.extend(0, 0, 1); + } + ExtensionArray::new(self.field.clone(), growable.as_box()) + } +} + #[cfg(feature = "python")] impl Broadcastable for crate::datatypes::PythonArray { fn broadcast(&self, num: usize) -> DaftResult { diff --git a/src/array/ops/compare_agg.rs b/src/array/ops/compare_agg.rs index 30f345d565..6716c12862 100644 --- a/src/array/ops/compare_agg.rs +++ b/src/array/ops/compare_agg.rs @@ -337,6 +337,7 @@ impl_todo_daft_comparable!(BinaryArray); impl_todo_daft_comparable!(StructArray); impl_todo_daft_comparable!(FixedSizeListArray); impl_todo_daft_comparable!(ListArray); +impl_todo_daft_comparable!(ExtensionArray); #[cfg(feature = "python")] impl_todo_daft_comparable!(PythonArray); diff --git a/src/array/ops/filter.rs b/src/array/ops/filter.rs index 411e8a4747..df2516c360 100644 --- a/src/array/ops/filter.rs +++ b/src/array/ops/filter.rs @@ -1,9 +1,6 @@ use crate::{ array::DataArray, - datatypes::{ - logical::DateArray, BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, - ListArray, NullArray, StructArray, Utf8Array, - }, + datatypes::{logical::DateArray, BooleanArray, DaftArrowBackedType}, error::DaftResult, }; @@ -11,63 +8,10 @@ use super::as_arrow::AsArrow; impl DataArray where - T: 
DaftNumericType, + T: DaftArrowBackedType, { pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl Utf8Array { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl BinaryArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl BooleanArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl NullArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let set_bits = mask.len() - mask.as_arrow().values().unset_bits(); - Ok(NullArray::full_null( - self.name(), - self.data_type(), - set_bits, - )) - } -} - -impl ListArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl FixedSizeListArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; - Self::try_from((self.field.clone(), result)) - } -} - -impl StructArray { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { - let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?; + let result = arrow2::compute::filter::filter(self.data(), mask.as_arrow())?; Self::try_from((self.field.clone(), result)) } } diff --git a/src/array/ops/full.rs b/src/array/ops/full.rs index a142441fa1..ff0b59af16 100644 --- a/src/array/ops/full.rs +++ b/src/array/ops/full.rs @@ -29,10 +29,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { Ok(arrow_dtype) => DataArray::::new( - Arc::new(Field { - name: name.to_string(), - dtype: dtype.clone(), - }), + Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_null_array(arrow_dtype, length), ) .unwrap(), @@ -54,10 +51,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { Ok(arrow_dtype) => DataArray::::new( - Arc::new(Field { - name: name.to_string(), - dtype: dtype.clone(), - }), + Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_empty_array(arrow_dtype), ) .unwrap(), diff --git a/src/array/ops/if_else.rs b/src/array/ops/if_else.rs index 3125f51bb7..10227d56b5 100644 --- a/src/array/ops/if_else.rs +++ b/src/array/ops/if_else.rs @@ -1,8 +1,8 @@ use crate::array::DataArray; use crate::datatypes::logical::DateArray; use crate::datatypes::{ - BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, Field, FixedSizeListArray, - ListArray, NullArray, StructArray, Utf8Array, + BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, ExtensionArray, Field, + FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array, }; use crate::error::{DaftError, DaftResult}; use crate::utils::arrow::arrow_bitmap_and_helper; @@ -229,68 +229,60 @@ impl PythonArray { } } -fn from_arrow_if_then_else( - predicate: &BooleanArray, - if_true: &DataArray, - if_false: &DataArray, -) -> DaftResult> -where - DataArray: - AsArrow + for<'a> TryFrom<(Arc, Box), Error = 
DaftError>, - as AsArrow>::Output: arrow2::array::Array, -{ - let result = arrow2::compute::if_then_else::if_then_else( - predicate.as_arrow(), - if_true.as_arrow(), - if_false.as_arrow(), - )?; - DataArray::try_from((if_true.field.clone(), result)) -} - fn nested_if_then_else( predicate: &BooleanArray, if_true: &DataArray, if_false: &DataArray, ) -> DaftResult> where - DataArray: AsArrow - + Broadcastable + DataArray: Broadcastable + for<'a> TryFrom<(Arc, Box), Error = DaftError>, - as AsArrow>::Output: arrow2::array::Array, { // TODO(Clark): Support streaming broadcasting, i.e. broadcasting without inflating scalars to full array length. - match (predicate.len(), if_true.len(), if_false.len()) { + let result = match (predicate.len(), if_true.len(), if_false.len()) { (predicate_len, if_true_len, if_false_len) - if predicate_len == if_true_len && if_true_len == if_false_len => from_arrow_if_then_else(predicate, if_true, if_false), - (1, if_true_len, 1) => from_arrow_if_then_else( - &predicate.broadcast(if_true_len)?, - if_true, - &if_false.broadcast(if_true_len)?, - ), - (1, 1, if_false_len) => from_arrow_if_then_else( - &predicate.broadcast(if_false_len)?, - &if_true.broadcast(if_false_len)?, - if_false, - ), - (predicate_len, 1, 1) => from_arrow_if_then_else( - predicate, - &if_true.broadcast(predicate_len)?, - &if_false.broadcast(predicate_len)?, - ), - (predicate_len, if_true_len, 1) if predicate_len == if_true_len => from_arrow_if_then_else( - predicate, - if_true, - &if_false.broadcast(predicate_len)?, - ), - (predicate_len, 1, if_false_len) if predicate_len == if_false_len => from_arrow_if_then_else( - predicate, - &if_true.broadcast(predicate_len)?, - if_false, - ), + if predicate_len == if_true_len && if_true_len == if_false_len => + { + arrow2::compute::if_then_else::if_then_else( + predicate.as_arrow(), + if_true.data(), + if_false.data(), + )? + } + (1, if_true_len, 1) => arrow2::compute::if_then_else::if_then_else( + predicate.broadcast(if_true_len)?.as_arrow(), + if_true.data(), + if_false.broadcast(if_true_len)?.data(), + )?, + (1, 1, if_false_len) => arrow2::compute::if_then_else::if_then_else( + predicate.broadcast(if_false_len)?.as_arrow(), + if_true.broadcast(if_false_len)?.data(), + if_false.data(), + )?, + (predicate_len, 1, 1) => arrow2::compute::if_then_else::if_then_else( + predicate.as_arrow(), + if_true.broadcast(predicate_len)?.data(), + if_false.broadcast(predicate_len)?.data(), + )?, + (predicate_len, if_true_len, 1) if predicate_len == if_true_len => { + arrow2::compute::if_then_else::if_then_else( + predicate.as_arrow(), + if_true.data(), + if_false.broadcast(predicate_len)?.data(), + )? + } + (predicate_len, 1, if_false_len) if predicate_len == if_false_len => { + arrow2::compute::if_then_else::if_then_else( + predicate.as_arrow(), + if_true.broadcast(predicate_len)?.data(), + if_false.data(), + )? 
+ } (p, s, o) => { - Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}"))) + return Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}"))); } - } + }; + DataArray::try_from((if_true.field.clone(), result)) } impl ListArray { @@ -319,6 +311,16 @@ impl StructArray { } } +impl ExtensionArray { + pub fn if_else( + &self, + other: &ExtensionArray, + predicate: &BooleanArray, + ) -> DaftResult { + nested_if_then_else(predicate, self, other) + } +} + impl DateArray { pub fn if_else(&self, other: &DateArray, predicate: &BooleanArray) -> DaftResult { let new_array = self.physical.if_else(&other.physical, predicate)?; diff --git a/src/array/ops/sort.rs b/src/array/ops/sort.rs index 62156dcb76..2248f34e33 100644 --- a/src/array/ops/sort.rs +++ b/src/array/ops/sort.rs @@ -2,8 +2,8 @@ use crate::{ array::DataArray, datatypes::{ logical::DateArray, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, - FixedSizeListArray, Float32Array, Float64Array, ListArray, NullArray, StructArray, - Utf8Array, + ExtensionArray, FixedSizeListArray, Float32Array, Float64Array, ListArray, NullArray, + StructArray, Utf8Array, }, error::DaftResult, kernels::search_sorted::{build_compare_with_nulls, cmp_float}, @@ -577,6 +577,12 @@ impl StructArray { } } +impl ExtensionArray { + pub fn sort(&self, _descending: bool) -> DaftResult { + todo!("impl sort for ExtensionArray") + } +} + #[cfg(feature = "python")] impl PythonArray { pub fn sort(&self, _descending: bool) -> DaftResult { diff --git a/src/array/ops/take.rs b/src/array/ops/take.rs index 1a5d3a0f23..5945f6d66d 100644 --- a/src/array/ops/take.rs +++ b/src/array/ops/take.rs @@ -2,7 +2,7 @@ use crate::{ array::DataArray, datatypes::{ logical::DateArray, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, - FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array, + ExtensionArray, FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array, }, error::DaftResult, }; @@ -291,6 +291,41 @@ impl StructArray { } } +impl ExtensionArray { + #[inline] + pub fn get(&self, idx: usize) -> Option> { + if idx >= self.len() { + panic!("Out of bounds: {} vs len: {}", idx, self.len()) + } + let is_valid = self + .data + .validity() + .map_or(true, |validity| validity.get_bit(idx)); + if is_valid { + Some(arrow2::scalar::new_scalar(self.data(), idx)) + } else { + None + } + } + + pub fn take(&self, idx: &DataArray) -> DaftResult + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let result = arrow2::compute::take::take(self.data(), idx.as_arrow())?; + Self::try_from((self.field.clone(), result)) + } + + pub fn str_value(&self, idx: usize) -> DaftResult { + let val = self.get(idx); + match val { + None => Ok("None".to_string()), + Some(v) => Ok(format!("{v:?}")), + } + } +} + #[cfg(feature = "python")] impl crate::datatypes::PythonArray { #[inline] diff --git a/src/array/pseudo_arrow/mod.rs b/src/array/pseudo_arrow/mod.rs index cf84631be1..15211ec527 100644 --- a/src/array/pseudo_arrow/mod.rs +++ b/src/array/pseudo_arrow/mod.rs @@ -313,4 +313,8 @@ impl Array for PseudoArrowArray { .map(|x| x.unset_bits()) .unwrap_or(0) } + + fn change_type(&mut self, _: arrow2::datatypes::DataType) { + unimplemented!("PseudoArray doesn't hold a data type and therefore does not support the change_type API.") + } } diff --git a/src/datatypes/dtype.rs b/src/datatypes/dtype.rs index 
c553e3747d..74b4be17d4 100644 --- a/src/datatypes/dtype.rs +++ b/src/datatypes/dtype.rs @@ -72,6 +72,8 @@ pub enum DataType { List(Box), /// A nested [`DataType`] with a given number of [`Field`]s. Struct(Vec), + /// Extension type. + Extension(String, Box, Option), // Stop ArrowTypes DaftType(Box), Python, @@ -117,6 +119,11 @@ impl DataType { .collect::>>()?; ArrowType::Struct(fields) }), + DataType::Extension(name, dtype, metadata) => Ok(ArrowType::Extension( + name.clone(), + Box::new(dtype.to_arrow()?), + metadata.clone(), + )), _ => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), @@ -128,12 +135,15 @@ impl DataType { match self { Date => Int32, Duration(_) | Timestamp(..) | Time(_) => Int64, - List(field) => List(Box::new(Field::new( - field.name.clone(), - field.dtype.to_physical(), - ))), + List(field) => List(Box::new( + Field::new(field.name.clone(), field.dtype.to_physical()) + .with_metadata(field.metadata.clone()), + )), FixedSizeList(field, size) => FixedSizeList( - Box::new(Field::new(field.name.clone(), field.dtype.to_physical())), + Box::new( + Field::new(field.name.clone(), field.dtype.to_physical()) + .with_metadata(field.metadata.clone()), + ), *size, ), _ => self.clone(), @@ -159,23 +169,41 @@ impl DataType { // DataType::Float16 | DataType::Float32 | DataType::Float64 => true, + DataType::Extension(_, inner, _) => inner.is_numeric(), _ => false } } #[inline] pub fn is_temporal(&self) -> bool { - matches!(self, DataType::Date) + match self { + DataType::Date => true, + DataType::Extension(_, inner, _) => inner.is_temporal(), + _ => false, + } } #[inline] pub fn is_null(&self) -> bool { - matches!(self, DataType::Null) + match self { + DataType::Null => true, + DataType::Extension(_, inner, _) => inner.is_null(), + _ => false, + } + } + + #[inline] + pub fn is_extension(&self) -> bool { + matches!(self, DataType::Extension(..)) } #[inline] pub fn is_python(&self) -> bool { - matches!(self, DataType::Python) + match self { + DataType::Python => true, + DataType::Extension(_, inner, _) => inner.is_python(), + _ => false, + } } #[inline] @@ -257,6 +285,11 @@ impl From<&ArrowType> for DataType { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); DataType::Struct(fields) } + ArrowType::Extension(name, dtype, metadata) => DataType::Extension( + name.clone(), + Box::new(dtype.as_ref().into()), + metadata.clone(), + ), _ => panic!("DataType :{item:?} is not supported"), } } diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index 5a7137f812..130d09ce25 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -1,4 +1,5 @@ use std::fmt::{Display, Formatter, Result}; +use std::sync::Arc; use arrow2::datatypes::Field as ArrowField; @@ -6,46 +7,66 @@ use crate::{datatypes::dtype::DataType, error::DaftResult}; use serde::{Deserialize, Serialize}; +pub type Metadata = std::collections::BTreeMap; + #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)] pub struct Field { pub name: String, pub dtype: DataType, + pub metadata: Arc, } impl Field { pub fn new>(name: S, dtype: DataType) -> Self { - Field { + Self { name: name.into(), dtype, + metadata: Default::default(), } } + + pub fn with_metadata>>(self, metadata: M) -> Self { + Self { + name: self.name, + dtype: self.dtype, + metadata: metadata.into(), + } + } + pub fn to_arrow(&self) -> DaftResult { - Ok(ArrowField::new( - self.name.clone(), - self.dtype.to_arrow()?, - true, - )) + Ok( + ArrowField::new(self.name.clone(), self.dtype.to_arrow()?, true) + 
.with_metadata(self.metadata.as_ref().clone()), + ) } + pub fn rename>(&self, name: S) -> Self { - Field::new(name, self.dtype.clone()) + Self { + name: name.into(), + dtype: self.dtype.clone(), + metadata: self.metadata.clone(), + } } + pub fn to_list_field(&self) -> DaftResult { if self.dtype.is_python() { return Ok(self.clone()); } let list_dtype = DataType::List(Box::new(self.clone())); - Ok(Field { + Ok(Self { name: self.name.clone(), dtype: list_dtype, + metadata: self.metadata.clone(), }) } } impl From<&ArrowField> for Field { fn from(af: &ArrowField) -> Self { - Field { + Self { name: af.name.clone(), dtype: af.data_type().into(), + metadata: af.metadata.clone().into(), } } } diff --git a/src/datatypes/matching.rs b/src/datatypes/matching.rs index 9c27ce3cc4..a1872629c7 100644 --- a/src/datatypes/matching.rs +++ b/src/datatypes/matching.rs @@ -29,6 +29,7 @@ macro_rules! with_match_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, _ => panic!("{:?} not implemented for with_match_daft_types", $key_type) @@ -63,6 +64,7 @@ macro_rules! with_match_physical_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, _ => panic!("{:?} not implemented for with_match_physical_daft_types", $key_type) @@ -97,6 +99,7 @@ macro_rules! with_match_arrow_daft_types {( List(_) => __with_ty__! { ListType }, FixedSizeList(..) => __with_ty__! { FixedSizeListType }, Struct(_) => __with_ty__! { StructType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, Utf8 => __with_ty__! 
{ Utf8Type }, _ => panic!("{:?} not implemented", $key_type) } diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index a91f73c879..71d8e51353 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -97,6 +97,7 @@ impl_daft_arrow_datatype!(Utf8Type, Utf8); impl_daft_arrow_datatype!(FixedSizeListType, Unknown); impl_daft_arrow_datatype!(ListType, Unknown); impl_daft_arrow_datatype!(StructType, Unknown); +impl_daft_arrow_datatype!(ExtensionType, Unknown); #[cfg(feature = "python")] impl_daft_non_arrow_datatype!(PythonType, Python); @@ -243,6 +244,7 @@ pub type Utf8Array = DataArray; pub type FixedSizeListArray = DataArray; pub type ListArray = DataArray; pub type StructArray = DataArray; +pub type ExtensionArray = DataArray; #[cfg(feature = "python")] pub type PythonArray = DataArray; diff --git a/src/dsl/functions/temporal/day.rs b/src/dsl/functions/temporal/day.rs index 1c96a8e622..4927d17d57 100644 --- a/src/dsl/functions/temporal/day.rs +++ b/src/dsl/functions/temporal/day.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for DayEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to day to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/day_of_week.rs b/src/dsl/functions/temporal/day_of_week.rs index 3630c54659..d943e8e90b 100644 --- a/src/dsl/functions/temporal/day_of_week.rs +++ b/src/dsl/functions/temporal/day_of_week.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for DayOfWeekEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to day to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/month.rs b/src/dsl/functions/temporal/month.rs index e9247ab1d4..3d03523d88 100644 --- a/src/dsl/functions/temporal/month.rs +++ b/src/dsl/functions/temporal/month.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for MonthEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::UInt32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::UInt32)) + } Ok(field) => Err(DaftError::TypeError(format!( "Expected input to month to be temporal, got {}", field.dtype diff --git a/src/dsl/functions/temporal/year.rs b/src/dsl/functions/temporal/year.rs index 9aa5579399..96713118d4 100644 --- a/src/dsl/functions/temporal/year.rs +++ b/src/dsl/functions/temporal/year.rs @@ -18,10 +18,9 @@ impl FunctionEvaluator for YearEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult { match inputs { [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => Ok(Field { - name: field.name, - dtype: DataType::Int32, - }), + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::Int32)) + } 
Ok(field) => Err(DaftError::TypeError(format!( "Expected input to year to be temporal, got {}", field.dtype diff --git a/src/ffi.rs b/src/ffi.rs index 930fcab003..0ac6878e44 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -9,8 +9,11 @@ use pyo3::types::PyList; use pyo3::{PyAny, PyObject, PyResult, Python}; use crate::{ - error::DaftResult, schema::SchemaRef, series::Series, table::Table, - utils::arrow::cast_array_if_needed, + error::DaftResult, + schema::SchemaRef, + series::Series, + table::Table, + utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed}, }; pub type ArrayRef = Box; @@ -71,7 +74,7 @@ pub fn record_batches_to_table( .into_iter() .enumerate() .map(|(i, c)| { - let c = cast_array_if_needed(c); + let c = cast_array_for_daft_if_needed(c); Series::try_from((names.get(i).unwrap().as_str(), c)) }) .collect::>>()?; @@ -110,6 +113,7 @@ pub fn table_to_record_batch(table: &Table, py: Python, pyarrow: &PyModule) -> P for i in 0..table.num_columns() { let s = table.get_column_by_index(i)?; let arrow_array = s.to_arrow(); + let arrow_array = cast_array_from_daft_if_needed(arrow_array.to_boxed()); let py_array = to_py_array(arrow_array, py, pyarrow)?; arrays.push(py_array); names.push(s.name().to_string()); diff --git a/src/lib.rs b/src/lib.rs index 20df02fbf9..c4c82f7c15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,7 @@ #![feature(hash_raw_entry)] +#[macro_use] +extern crate lazy_static; + mod array; mod datatypes; mod dsl; diff --git a/src/python/datatype.rs b/src/python/datatype.rs index 7bb3650c2a..10ef6faf68 100644 --- a/src/python/datatype.rs +++ b/src/python/datatype.rs @@ -140,6 +140,20 @@ impl PyDataType { .into()) } + #[staticmethod] + pub fn extension( + name: &str, + storage_data_type: Self, + metadata: Option<&str>, + ) -> PyResult { + Ok(DataType::Extension( + name.to_string(), + Box::new(storage_data_type.dtype), + metadata.map(|s| s.to_string()), + ) + .into()) + } + #[staticmethod] pub fn python() -> PyResult { Ok(DataType::Python.into()) diff --git a/src/python/series.rs b/src/python/series.rs index 3a08f4e322..590a419dd8 100644 --- a/src/python/series.rs +++ b/src/python/series.rs @@ -7,7 +7,7 @@ use crate::{ datatypes::{DataType, Field, PythonType, UInt64Type}, ffi, series::{self, IntoSeries, Series}, - utils::arrow::cast_array_if_needed, + utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed}, }; use super::datatype::PyDataType; @@ -24,7 +24,7 @@ impl PySeries { #[staticmethod] pub fn from_arrow(name: &str, pyarrow_array: &PyAny) -> PyResult { let arrow_array = ffi::array_to_rust(pyarrow_array)?; - let arrow_array = cast_array_if_needed(arrow_array.to_boxed()); + let arrow_array = cast_array_for_daft_if_needed(arrow_array.to_boxed()); let series = series::Series::try_from((name, arrow_array))?; Ok(series.into()) } @@ -51,6 +51,7 @@ impl PySeries { pub fn to_arrow(&self) -> PyResult { let arrow_array = self.series.to_arrow(); + let arrow_array = cast_array_from_daft_if_needed(arrow_array); Python::with_gil(|py| { let pyarrow = py.import("pyarrow")?; ffi::to_py_array(arrow_array, py, pyarrow) diff --git a/src/series/array_impl/data_array.rs b/src/series/array_impl/data_array.rs index 4ca8fe7512..eab60ed7d4 100644 --- a/src/series/array_impl/data_array.rs +++ b/src/series/array_impl/data_array.rs @@ -12,9 +12,9 @@ use crate::datatypes::PythonArray; use crate::series::Field; use crate::{ datatypes::{ - BinaryArray, BooleanArray, FixedSizeListArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, 
Int8Array, ListArray, NullArray, StructArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, Utf8Array, + BinaryArray, BooleanArray, ExtensionArray, FixedSizeListArray, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, ListArray, NullArray, StructArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, error::DaftResult, series::series_like::SeriesLike, @@ -273,5 +273,6 @@ impl_series_like_for_data_array!(Utf8Array); impl_series_like_for_data_array!(FixedSizeListArray); impl_series_like_for_data_array!(ListArray); impl_series_like_for_data_array!(StructArray); +impl_series_like_for_data_array!(ExtensionArray); #[cfg(feature = "python")] impl_series_like_for_data_array!(PythonArray); diff --git a/src/utils/arrow.rs b/src/utils/arrow.rs index 122ac36021..1f19e1d80b 100644 --- a/src/utils/arrow.rs +++ b/src/utils/arrow.rs @@ -1,105 +1,138 @@ +use std::collections::HashMap; +use std::sync::Mutex; + use arrow2::compute::cast; -pub fn cast_array_if_needed( - arrow_array: Box, -) -> Box { - match arrow_array.data_type() { - arrow2::datatypes::DataType::Utf8 => { - cast::utf8_to_large_utf8(arrow_array.as_any().downcast_ref().unwrap()).boxed() - } - arrow2::datatypes::DataType::Binary => cast::binary_to_large_binary( - arrow_array.as_any().downcast_ref().unwrap(), - arrow2::datatypes::DataType::LargeBinary, - ) - .boxed(), +// TODO(Clark): Refactor to GILOnceCell in order to avoid deadlock between the below mutex and the Python GIL. +lazy_static! { + static ref REGISTRY: Mutex> = + Mutex::new(HashMap::new()); +} + +fn coerce_to_daft_compatible_type( + dtype: &arrow2::datatypes::DataType, +) -> Option { + match dtype { + arrow2::datatypes::DataType::Utf8 => Some(arrow2::datatypes::DataType::LargeUtf8), + arrow2::datatypes::DataType::Binary => Some(arrow2::datatypes::DataType::LargeBinary), arrow2::datatypes::DataType::List(field) => { - let array = arrow_array - .as_any() - .downcast_ref::>() - .unwrap(); - let new_values = cast_array_if_needed(array.values().clone()); - let offsets = array.offsets().into(); - arrow2::array::ListArray::::new( - arrow2::datatypes::DataType::LargeList(Box::new(arrow2::datatypes::Field::new( - field.name.clone(), - new_values.data_type().clone(), - field.is_nullable, - ))), - offsets, - new_values, - arrow_array.validity().cloned(), - ) - .boxed() + let new_field = match coerce_to_daft_compatible_type(field.data_type()) { + Some(new_inner_dtype) => Box::new( + arrow2::datatypes::Field::new( + field.name.clone(), + new_inner_dtype, + field.is_nullable, + ) + .with_metadata(field.metadata.clone()), + ), + None => field.clone(), + }; + Some(arrow2::datatypes::DataType::LargeList(new_field)) } arrow2::datatypes::DataType::LargeList(field) => { - // Types nested within LargeList may need casting. 
- let array = arrow_array - .as_any() - .downcast_ref::>() - .unwrap(); - let new_values = cast_array_if_needed(array.values().clone()); - if new_values.data_type() == array.values().data_type() { - return arrow_array; - } - arrow2::array::ListArray::::new( - arrow2::datatypes::DataType::LargeList(Box::new(arrow2::datatypes::Field::new( + let new_inner_dtype = coerce_to_daft_compatible_type(field.data_type())?; + Some(arrow2::datatypes::DataType::LargeList(Box::new( + arrow2::datatypes::Field::new( field.name.clone(), - new_values.data_type().clone(), + new_inner_dtype, field.is_nullable, - ))), - array.offsets().clone(), - new_values, - arrow_array.validity().cloned(), - ) - .boxed() + ) + .with_metadata(field.metadata.clone()), + ))) } arrow2::datatypes::DataType::FixedSizeList(field, size) => { - // Types nested within FixedSizeList may need casting. - let array = arrow_array - .as_any() - .downcast_ref::() - .unwrap(); - let new_values = cast_array_if_needed(array.values().clone()); - if new_values.data_type() == array.values().data_type() { - return arrow_array; - } - arrow2::array::FixedSizeListArray::new( - arrow2::datatypes::DataType::FixedSizeList( - Box::new(arrow2::datatypes::Field::new( + let new_inner_dtype = coerce_to_daft_compatible_type(field.data_type())?; + Some(arrow2::datatypes::DataType::FixedSizeList( + Box::new( + arrow2::datatypes::Field::new( field.name.clone(), - new_values.data_type().clone(), + new_inner_dtype, field.is_nullable, - )), - *size, + ) + .with_metadata(field.metadata.clone()), ), - new_values, - arrow_array.validity().cloned(), - ) - .boxed() + *size, + )) } arrow2::datatypes::DataType::Struct(fields) => { - let new_arrays = arrow_array - .as_any() - .downcast_ref::() - .unwrap() - .values() - .iter() - .map(|field_arr| cast_array_if_needed(field_arr.clone())) - .collect::>>(); let new_fields = fields .iter() - .zip(new_arrays.iter().map(|arr| arr.data_type().clone())) - .map(|(field, dtype)| { - arrow2::datatypes::Field::new(field.name.clone(), dtype, field.is_nullable) - }) - .collect(); - Box::new(arrow2::array::StructArray::new( - arrow2::datatypes::DataType::Struct(new_fields), - new_arrays, - arrow_array.validity().cloned(), + .map( + |field| match coerce_to_daft_compatible_type(field.data_type()) { + Some(new_inner_dtype) => arrow2::datatypes::Field::new( + field.name.clone(), + new_inner_dtype, + field.is_nullable, + ) + .with_metadata(field.metadata.clone()), + None => field.clone(), + }, + ) + .collect::>(); + if &new_fields == fields { + None + } else { + Some(arrow2::datatypes::DataType::Struct(new_fields)) + } + } + arrow2::datatypes::DataType::Extension(name, inner, metadata) => { + let new_inner_dtype = coerce_to_daft_compatible_type(inner.as_ref())?; + REGISTRY.lock().unwrap().insert(name.clone(), dtype.clone()); + Some(arrow2::datatypes::DataType::Extension( + name.clone(), + Box::new(new_inner_dtype), + metadata.clone(), )) } - _ => arrow_array, + _ => None, + } +} + +pub fn cast_array_for_daft_if_needed( + arrow_array: Box, +) -> Box { + match coerce_to_daft_compatible_type(arrow_array.data_type()) { + Some(coerced_dtype) => cast::cast( + arrow_array.as_ref(), + &coerced_dtype, + cast::CastOptions { + wrapped: true, + partial: false, + }, + ) + .unwrap(), + None => arrow_array, + } +} + +fn coerce_from_daft_compatible_type( + dtype: &arrow2::datatypes::DataType, +) -> Option { + match dtype { + arrow2::datatypes::DataType::Extension(name, _, _) + if REGISTRY.lock().unwrap().contains_key(name) => + { + let entry = 
REGISTRY.lock().unwrap(); + Some(entry.get(name).unwrap().clone()) + } + _ => None, + } +} + +pub fn cast_array_from_daft_if_needed( + arrow_array: Box, +) -> Box { + match coerce_from_daft_compatible_type(arrow_array.data_type()) { + Some(coerced_dtype) => cast::cast( + arrow_array.as_ref(), + &coerced_dtype, + cast::CastOptions { + wrapped: true, + partial: false, + }, + ) + .unwrap(), + None => arrow_array, } } diff --git a/tests/conftest.py b/tests/conftest.py index 217c9c459d..000bf8f622 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,30 @@ from __future__ import annotations import pandas as pd +import pyarrow as pa +import pytest + + +class UuidType(pa.ExtensionType): + NAME = "daft.uuid" + + def __init__(self): + pa.ExtensionType.__init__(self, pa.binary(), self.NAME) + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + return UuidType() + + +@pytest.fixture +def uuid_ext_type() -> UuidType: + ext_type = UuidType() + pa.register_extension_type(ext_type) + yield ext_type + pa.unregister_extension_type(ext_type.NAME) def assert_df_equals( diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 8a574281d8..ff09560e39 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -15,8 +15,12 @@ import daft from daft.api_annotations import APITypeError +from daft.context import get_context from daft.dataframe import DataFrame from daft.datatype import DataType +from tests.conftest import UuidType + +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) class MyObj: @@ -151,9 +155,10 @@ def test_create_dataframe_arrow(valid_data: list[dict[str, float]], multiple) -> assert df.to_arrow() == expected -def test_create_dataframe_arrow_tensor(valid_data: list[dict[str, float]]) -> None: +def test_create_dataframe_arrow_tensor_ray(valid_data: list[dict[str, float]]) -> None: pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} - pydict["obj"] = ArrowTensorArray.from_numpy(np.ones((len(valid_data), 2, 2))) + ata = ArrowTensorArray.from_numpy(np.ones((len(valid_data), 2, 2))) + pydict["obj"] = ata t = pa.Table.from_pydict(pydict) df = daft.from_arrow(t) assert set(df.column_names) == set(t.column_names) @@ -165,6 +170,87 @@ def test_create_dataframe_arrow_tensor(valid_data: list[dict[str, float]]) -> No assert df.to_arrow() == expected +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="Pickling canonical tensor extension type is not supported by pyarrow", +) +def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, float]]) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + dtype = pa.fixed_shape_tensor(pa.int64(), (2, 2)) + storage = pa.array([list(range(4 * i, 4 * (i + 1))) for i in range(len(valid_data))], pa.list_(pa.int64(), 4)) + ata = pa.ExtensionArray.from_storage(dtype, storage) + pydict["obj"] = ata + t = pa.Table.from_pydict(pydict) + df = daft.from_arrow(t) + assert set(df.column_names) == set(t.column_names) + assert df.schema()["obj"].dtype == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(dtype.storage_type), '{"shape":[2,2]}' + ) + casted_field = t.schema.field("variety").with_type(pa.large_string()) + expected = 
t.cast(t.schema.set(t.schema.get_field_index("variety"), casted_field)) + # Check roundtrip. + assert df.to_arrow() == expected + + +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]], uuid_ext_type: UuidType) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + storage = pa.array([f"{i}".encode() for i in range(len(valid_data))]) + pydict["obj"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage) + t = pa.Table.from_pydict(pydict) + df = daft.from_arrow(t) + assert set(df.column_names) == set(t.column_names) + assert df.schema()["obj"].dtype == DataType.extension( + uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), "" + ) + casted_field = t.schema.field("variety").with_type(pa.large_string()) + expected = t.cast(t.schema.set(t.schema.get_field_index("variety"), casted_field)) + # Check roundtrip. + assert df.to_arrow() == expected + + +# TODO(Clark): Remove this test once pyarrow extension types are supported for Ray clusters. +@pytest.mark.skipif( + get_context().runner_config.name != "ray", + reason="This test requires the Ray runner.", +) +def test_create_dataframe_arrow_extension_type_fails_for_ray( + valid_data: list[dict[str, float]], uuid_ext_type: UuidType +) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + storage = pa.array([f"{i}".encode() for i in range(len(valid_data))]) + pydict["obj"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage) + t = pa.Table.from_pydict(pydict) + with pytest.raises(ValueError): + daft.from_arrow(t).to_arrow() + + +class PyExtType(pa.PyExtensionType): + def __init__(self): + pa.PyExtensionType.__init__(self, pa.binary()) + + def __reduce__(self): + return PyExtType, () + + +def test_create_dataframe_arrow_py_ext_type_raises(valid_data: list[dict[str, float]]) -> None: + pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} + uuid_type = PyExtType() + storage_array = pa.array([f"foo-{i}".encode() for i in range(len(valid_data))], pa.binary()) + arr = pa.ExtensionArray.from_storage(uuid_type, storage_array) + pydict["obj"] = arr + t = pa.Table.from_pydict(pydict) + with pytest.raises(ValueError): + daft.from_arrow(t) + + def test_create_dataframe_arrow_unsupported_dtype(valid_data: list[dict[str, float]]) -> None: pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()} pydict["obj"] = [datetime.datetime.now() for _ in range(len(valid_data))] diff --git a/tests/series/test_concat.py b/tests/series/test_concat.py index 3eac05c416..b66e40c656 100644 --- a/tests/series/test_concat.py +++ b/tests/series/test_concat.py @@ -2,12 +2,18 @@ import itertools +import numpy as np import pyarrow as pa import pytest +from ray.data.extensions import ArrowTensorArray from daft import DataType, Series +from daft.context import get_context +from tests.conftest import * from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + class MockObject: def __init__(self, test_val): @@ -85,6 +91,85 @@ def test_series_concat_struct_array(chunks) -> None: counter += 1 +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_tensor_array_ray(chunks) -> None: + element_shape = (2, 2) + num_elements_per_tensor = np.prod(element_shape) + 
chunk_size = 3 + chunk_shape = (chunk_size,) + element_shape + chunks = [ + np.arange(i * chunk_size * num_elements_per_tensor, (i + 1) * chunk_size * num_elements_per_tensor).reshape( + chunk_shape + ) + for i in range(chunks) + ] + series = [Series.from_arrow(ArrowTensorArray.from_numpy(chunk)) for chunk in chunks] + + concated = Series.concat(series) + + assert concated.datatype() == DataType.python() + expected = [chunk[i] for chunk in chunks for i in range(len(chunk))] + np.testing.assert_equal(concated.to_pylist(), expected) + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="Pickling canonical tensor extension type is not supported by pyarrow", +) +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_tensor_array_canonical(chunks) -> None: + element_shape = (2, 2) + num_elements_per_tensor = np.prod(element_shape) + chunk_size = 3 + chunk_shape = (chunk_size,) + element_shape + chunks = [ + np.arange(i * chunk_size * num_elements_per_tensor, (i + 1) * chunk_size * num_elements_per_tensor).reshape( + chunk_shape + ) + for i in range(chunks) + ] + ext_arrays = [pa.FixedShapeTensorArray.from_numpy_ndarray(chunk) for chunk in chunks] + series = [Series.from_arrow(ext_array) for ext_array in ext_arrays] + + concated = Series.concat(series) + + assert concated.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(ext_arrays[0].type.storage_type), '{"shape":[2,2]}' + ) + expected = [chunk[i] for chunk in chunks for i in range(len(chunk))] + concated_arrow = concated.to_arrow() + assert isinstance(concated_arrow.type, pa.FixedShapeTensorType) + np.testing.assert_equal(concated_arrow.to_numpy_ndarray(), expected) + + +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_extension_type(uuid_ext_type, chunks) -> None: + chunk_size = 3 + storage_arrays = [ + pa.array([f"{i}".encode() for i in range(j * chunk_size, (j + 1) * chunk_size)]) for j in range(chunks) + ] + ext_arrays = [pa.ExtensionArray.from_storage(uuid_ext_type, storage) for storage in storage_arrays] + series = [Series.from_arrow(ext_array) for ext_array in ext_arrays] + + concated = Series.concat(series) + + assert concated.datatype() == DataType.extension( + uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), "" + ) + concated_arrow = concated.to_arrow() + assert isinstance(concated_arrow.type, UuidType) + assert concated_arrow.type == uuid_ext_type + assert concated_arrow == pa.concat_arrays(ext_arrays) + + @pytest.mark.parametrize("chunks", [1, 2, 3, 10]) def test_series_concat_pyobj(chunks) -> None: series = [] diff --git a/tests/series/test_filter.py b/tests/series/test_filter.py index 22877af565..238a16923a 100644 --- a/tests/series/test_filter.py +++ b/tests/series/test_filter.py @@ -1,12 +1,16 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest +from daft.context import get_context from daft.datatype import DataType from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + 
ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_filter(dtype) -> None: @@ -110,6 +114,48 @@ def test_series_filter_on_struct_array() -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +def test_series_filter_on_extension_array(uuid_ext_type) -> None: + arr = pa.array(f"{i}".encode() for i in range(5)) + data = pa.ExtensionArray.from_storage(uuid_ext_type, arr) + + s = Series.from_arrow(data) + pymask = [False, True, True, None, False] + mask = Series.from_pylist(pymask) + + result = s.filter(mask) + + assert s.datatype() == result.datatype() + expected = [val for val, keep in zip(s.to_pylist(), pymask) if keep] + assert result.to_pylist() == expected + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="Pickling canonical tensor extension type is not supported by pyarrow", +) +def test_series_filter_on_canonical_tensor_extension_array() -> None: + arr = np.arange(20).reshape((5, 2, 2)) + data = pa.FixedShapeTensorArray.from_numpy_ndarray(arr) + + s = Series.from_arrow(data) + pymask = [False, True, True, None, False] + mask = Series.from_pylist(pymask) + + result = s.filter(mask) + + assert s.datatype() == result.datatype() + expected = [val for val, keep in zip(s.to_pylist(), pymask) if keep] + assert result.to_pylist() == expected + + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_filter_broadcast(dtype) -> None: data = pa.array([1, 2, 3, None, 5, None]) diff --git a/tests/series/test_if_else.py b/tests/series/test_if_else.py index 552b960430..2974428d49 100644 --- a/tests/series/test_if_else.py +++ b/tests/series/test_if_else.py @@ -1,11 +1,15 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest from daft import Series +from daft.context import get_context from daft.datatype import DataType +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("if_true_value", [1, None]) @pytest.mark.parametrize("if_false_value", [0, None]) @@ -293,6 +297,113 @@ def test_series_if_else_struct(if_true, if_false, expected) -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +@pytest.mark.parametrize( + ["if_true_storage", "if_false_storage", "expected_storage"], + [ + # Same length, same type + ( + pa.array([f"{i}".encode() for i in range(4)]), + pa.array([f"{i}".encode() for i in range(4, 8)]), + pa.array([b"0", b"5", None, b"3"]), + ), + # Broadcast left + ( + pa.array([b"0"]), + pa.array([f"{i}".encode() for i in range(4, 8)]), + pa.array([b"0", b"5", None, b"0"]), + ), + # Broadcast right + ( + pa.array([f"{i}".encode() for i in range(4)]), + pa.array([b"4"]), + pa.array([b"0", b"4", None, b"3"]), + ), + # Broadcast both + ( + pa.array([b"0"]), + pa.array([b"4"]), + pa.array([b"0", b"4", None, b"0"]), + ), + ], +) +def test_series_if_else_extension_type(uuid_ext_type, if_true_storage, if_false_storage, expected_storage) -> None: + if_true_arrow = pa.ExtensionArray.from_storage(uuid_ext_type, if_true_storage) + if_false_arrow = pa.ExtensionArray.from_storage(uuid_ext_type, if_false_storage) + 
expected_arrow = pa.ExtensionArray.from_storage(uuid_ext_type, expected_storage) + if_true_series = Series.from_arrow(if_true_arrow) + if_false_series = Series.from_arrow(if_false_arrow) + predicate_series = Series.from_arrow(pa.array([True, False, None, True])) + + result = predicate_series.if_else(if_true_series, if_false_series) + + assert result.datatype() == DataType.extension( + uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), "" + ) + result_arrow = result.to_arrow() + assert result_arrow == expected_arrow + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="Pickling canonical tensor extension type is not supported by pyarrow", +) +@pytest.mark.parametrize( + ["if_true", "if_false", "expected"], + [ + # Same length, same type + ( + np.arange(16).reshape((4, 2, 2)), + np.arange(16, 32).reshape((4, 2, 2)), + np.array( + [[[0, 1], [2, 3]], [[20, 21], [22, 23]], [[np.nan, np.nan], [np.nan, np.nan]], [[12, 13], [14, 15]]] + ), + ), + # Broadcast left + ( + np.arange(4).reshape((1, 2, 2)), + np.arange(16, 32).reshape((4, 2, 2)), + np.array([[[0, 1], [2, 3]], [[20, 21], [22, 23]], [[np.nan, np.nan], [np.nan, np.nan]], [[0, 1], [2, 3]]]), + ), + # Broadcast right + ( + np.arange(16).reshape((4, 2, 2)), + np.arange(16, 20).reshape((1, 2, 2)), + np.array( + [[[0, 1], [2, 3]], [[16, 17], [18, 19]], [[np.nan, np.nan], [np.nan, np.nan]], [[12, 13], [14, 15]]] + ), + ), + # Broadcast both + ( + np.arange(4).reshape((1, 2, 2)), + np.arange(16, 20).reshape((1, 2, 2)), + np.array([[[0, 1], [2, 3]], [[16, 17], [18, 19]], [[np.nan, np.nan], [np.nan, np.nan]], [[0, 1], [2, 3]]]), + ), + ], +) +def test_series_if_else_canonical_tensor_extension_type(if_true, if_false, expected) -> None: + if_true_arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(if_true) + if_false_arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(if_false) + if_true_series = Series.from_arrow(if_true_arrow) + if_false_series = Series.from_arrow(if_false_arrow) + predicate_series = Series.from_arrow(pa.array([True, False, None, True])) + + result = predicate_series.if_else(if_true_series, if_false_series) + + assert result.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(if_true_arrow.type.storage_type), '{"shape":[2,2]}' + ) + result_arrow = result.to_arrow() + np.testing.assert_equal(result_arrow.to_numpy_ndarray(), expected) + + @pytest.mark.parametrize( "if_true_length", [1, 3], diff --git a/tests/series/test_size_bytes.py b/tests/series/test_size_bytes.py index 93f4703a50..9b50766175 100644 --- a/tests/series/test_size_bytes.py +++ b/tests/series/test_size_bytes.py @@ -3,14 +3,17 @@ import itertools import math +import numpy as np import pyarrow as pa import pytest +from daft.context import get_context from daft.datatype import DataType from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES -PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) +PYARROW_GE_7_0_0 = ARROW_VERSION >= (7, 0, 0) def get_total_buffer_size(arr: pa.Array) -> int: @@ -177,3 +180,60 @@ def test_series_struct_size_bytes(size, with_nulls) -> None: ) else: assert s.size_bytes() == get_total_buffer_size(data) + conversion_to_large_string_bytes + + 
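(Editorial aside, outside the patch: the size assertions that follow hinge on one detail of Arrow's memory layout. Daft stores the uuid extension column's storage as `large_binary`, whose offsets buffer uses int64 instead of `binary`'s int32, which is why the expected byte count below is computed on `storage.cast(pa.large_binary())` via the test file's own `get_total_buffer_size` helper. A minimal standalone illustration of that arithmetic, using only stock pyarrow:

```python
import pyarrow as pa

storage = pa.array([b"0", b"1", b"2"])   # binary: int32 offsets buffer
large = storage.cast(pa.large_binary())  # large_binary: int64 offsets buffer

# The offsets buffer doubles in width, so the total buffer size grows.
def total_buffer_size(arr: pa.Array) -> int:
    return sum(buf.size for buf in arr.buffers() if buf is not None)

assert total_buffer_size(large) > total_buffer_size(storage)
```
)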
+@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +@pytest.mark.parametrize("size", [1, 2, 8, 9, 16]) +@pytest.mark.parametrize("with_nulls", [True, False]) +def test_series_extension_type_size_bytes(uuid_ext_type, size, with_nulls) -> None: + pydata = [f"{i}".encode() for i in range(size)] + + # TODO(Clark): Change to size > 0 condition when pyarrow extension arrays support generic construction on null arrays. + if with_nulls and size > 1: + pydata = pydata[:-1] + [None] + storage = pa.array(pydata) + data = pa.ExtensionArray.from_storage(uuid_ext_type, storage) + + s = Series.from_arrow(data) + + size_bytes = s.size_bytes() + + assert s.datatype() == DataType.extension( + uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), "" + ) + post_daft_cast_data = storage.cast(pa.large_binary()) + assert size_bytes == get_total_buffer_size(post_daft_cast_data) + + +@pytest.mark.skipif( + ARROW_VERSION < (12, 0, 0), + reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.", +) +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="Pickling canonical tensor extension type is not supported by pyarrow", +) +@pytest.mark.parametrize("dtype, size", itertools.product(ARROW_INT_TYPES + ARROW_FLOAT_TYPES, [0, 1, 2, 8, 9, 16])) +@pytest.mark.parametrize("with_nulls", [True, False]) +def test_series_canonical_tensor_extension_type_size_bytes(dtype, size, with_nulls) -> None: + tensor_type = pa.fixed_shape_tensor(pa.int64(), (2, 2)) + if size == 0: + storage = pa.array([], pa.list_(pa.int64(), 4)) + data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage) + elif with_nulls: + pydata = np.arange(4 * size).reshape((size, 4)).tolist()[:-1] + [None] + storage = pa.array(pydata, pa.list_(pa.int64(), 4)) + data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage) + else: + arr = np.arange(4 * size).reshape((size, 2, 2)) + data = pa.FixedShapeTensorArray.from_numpy_ndarray(arr) + + s = Series.from_arrow(data) + + assert s.datatype() == DataType.extension( + "arrow.fixed_shape_tensor", DataType.from_arrow_type(data.type.storage_type), '{"shape":[2,2]}' + ) + assert s.size_bytes() == get_total_buffer_size(data) diff --git a/tests/series/test_take.py b/tests/series/test_take.py index 2bb9f1baae..847db326a4 100644 --- a/tests/series/test_take.py +++ b/tests/series/test_take.py @@ -1,12 +1,16 @@ from __future__ import annotations +import numpy as np import pyarrow as pa import pytest +from daft.context import get_context from daft.datatype import DataType from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES + ARROW_STRING_TYPES) def test_series_take(dtype) -> None: @@ -115,6 +119,62 @@ def test_series_struct_take() -> None: assert result.to_pylist() == expected +@pytest.mark.skipif( + get_context().runner_config.name == "ray", + reason="pyarrow extension types aren't supported on Ray clusters.", +) +def test_series_extension_type_take(uuid_ext_type) -> None: + pydata = [f"{i}".encode() for i in range(6)] + pydata[2] = None + storage = pa.array(pydata) + data = pa.ExtensionArray.from_storage(uuid_ext_type, storage) + + s = Series.from_arrow(data) + assert s.datatype() == DataType.extension( + uuid_ext_type.NAME, 
+    )
+    pyidx = [2, 0, None, 5]
+    idx = Series.from_pylist(pyidx)
+
+    result = s.take(idx)
+    assert result.datatype() == s.datatype()
+    assert len(result) == 4
+
+    expected = [pydata[i] if i is not None else None for i in pyidx]
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.skipif(
+    ARROW_VERSION < (12, 0, 0),
+    reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
+)
+@pytest.mark.skipif(
+    get_context().runner_config.name == "ray",
+    reason="Pickling the canonical tensor extension type is not supported by pyarrow.",
+)
+def test_series_canonical_tensor_extension_type_take() -> None:
+    pydata = np.arange(24).reshape((6, 4)).tolist()
+    pydata[2] = None
+    storage = pa.array(pydata, pa.list_(pa.int64(), 4))
+    tensor_type = pa.fixed_shape_tensor(pa.int64(), (2, 2))
+    data = pa.FixedShapeTensorArray.from_storage(tensor_type, storage)
+
+    s = Series.from_arrow(data)
+    assert s.datatype() == DataType.extension(
+        "arrow.fixed_shape_tensor", DataType.from_arrow_type(tensor_type.storage_type), '{"shape":[2,2]}'
+    )
+    pyidx = [2, 0, None, 5]
+    idx = Series.from_pylist(pyidx)
+
+    result = s.take(idx)
+    assert result.datatype() == s.datatype()
+    assert len(result) == 4
+
+    original_data = s.to_pylist()
+    expected = [original_data[i] if i is not None else None for i in pyidx]
+    assert result.to_pylist() == expected
+
+
 def test_series_deeply_nested_take() -> None:
     # Test take on a Series with a deeply nested type: struct of list of struct of list of strings.
     data = pa.array([{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}])
diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py
index 652d7848b9..3b0396fffa 100644
--- a/tests/table/test_from_py.py
+++ b/tests/table/test_from_py.py
@@ -11,9 +11,12 @@
 from ray.data.extensions import ArrowTensorArray, ArrowTensorType
 
 from daft import DataType
+from daft.context import get_context
 from daft.series import Series
 from daft.table import Table
 
+ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric())
+
 PYTHON_TYPE_ARRAYS = {
     "int": [1, 2],
     "float": [1.0, 2.0],
@@ -26,7 +29,7 @@
     "empty_struct": [{}, {}],
     "null": [None, None],
     # The following types are not natively supported and will be cast to Python object types.
-    "tensor": list(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
+    "tensor": list(np.arange(8).reshape(2, 2, 2)),
     "timestamp": [datetime.datetime.now(), datetime.datetime.now()],
 }
 
@@ -43,6 +46,7 @@
     "empty_struct": DataType.struct({"": DataType.null()}),
     "null": DataType.null(),
     # The following types are not natively supported and will be cast to Python object types.
+    # TODO(Clark): Change the inferred tensor type to the canonical fixed-shape tensor extension type.
     "tensor": DataType.python(),
     "timestamp": DataType.python(),
 }
@@ -116,6 +120,24 @@
     "timestamp": pa.timestamp("us"),
 }
 
+if ARROW_VERSION >= (12, 0, 0) and get_context().runner_config.name != "ray":
+    ARROW_ROUNDTRIP_TYPES["canonical_tensor"] = pa.fixed_shape_tensor(pa.int64(), (2, 2))
+    ARROW_TYPE_ARRAYS["canonical_tensor"] = pa.FixedShapeTensorArray.from_numpy_ndarray(
+        np.array(PYTHON_TYPE_ARRAYS["tensor"])
+    )
+
+
+def _with_uuid_ext_type(uuid_ext_type) -> tuple[dict, dict]:
+    if get_context().runner_config.name == "ray":
+        # pyarrow extension types aren't supported on Ray clusters yet.
+        return ARROW_ROUNDTRIP_TYPES, ARROW_TYPE_ARRAYS
+    arrow_roundtrip_types = ARROW_ROUNDTRIP_TYPES.copy()
+    arrow_type_arrays = ARROW_TYPE_ARRAYS.copy()
+    arrow_roundtrip_types["ext_type"] = uuid_ext_type
+    storage = ARROW_TYPE_ARRAYS["binary"]
+    arrow_type_arrays["ext_type"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
+    return arrow_roundtrip_types, arrow_type_arrays
+
 
 def test_from_pydict_roundtrip() -> None:
     table = Table.from_pydict(PYTHON_TYPE_ARRAYS)
@@ -134,24 +156,26 @@ def test_from_pydict_roundtrip() -> None:
     assert table.to_arrow() == expected_table
 
 
-def test_from_pydict_arrow_roundtrip() -> None:
-    table = Table.from_pydict(ARROW_TYPE_ARRAYS)
+def test_from_pydict_arrow_roundtrip(uuid_ext_type) -> None:
+    arrow_roundtrip_types, arrow_type_arrays = _with_uuid_ext_type(uuid_ext_type)
+    table = Table.from_pydict(arrow_type_arrays)
     assert len(table) == 2
-    assert set(table.column_names()) == set(ARROW_TYPE_ARRAYS.keys())
+    assert set(table.column_names()) == set(arrow_type_arrays.keys())
     for field in table.schema():
-        assert field.dtype == DataType.from_arrow_type(ARROW_TYPE_ARRAYS[field.name].type)
-    expected_table = pa.table(ARROW_TYPE_ARRAYS).cast(pa.schema(ARROW_ROUNDTRIP_TYPES))
+        assert field.dtype == DataType.from_arrow_type(arrow_type_arrays[field.name].type)
+    expected_table = pa.table(arrow_type_arrays).cast(pa.schema(arrow_roundtrip_types))
     assert table.to_arrow() == expected_table
 
 
-def test_from_arrow_roundtrip() -> None:
-    pa_table = pa.table(ARROW_TYPE_ARRAYS)
+def test_from_arrow_roundtrip(uuid_ext_type) -> None:
+    arrow_roundtrip_types, arrow_type_arrays = _with_uuid_ext_type(uuid_ext_type)
+    pa_table = pa.table(arrow_type_arrays)
     table = Table.from_arrow(pa_table)
     assert len(table) == 2
-    assert set(table.column_names()) == set(ARROW_TYPE_ARRAYS.keys())
+    assert set(table.column_names()) == set(arrow_type_arrays.keys())
     for field in table.schema():
-        assert field.dtype == DataType.from_arrow_type(ARROW_TYPE_ARRAYS[field.name].type)
-    expected_table = pa.table(ARROW_TYPE_ARRAYS).cast(pa.schema(ARROW_ROUNDTRIP_TYPES))
+        assert field.dtype == DataType.from_arrow_type(arrow_type_arrays[field.name].type)
+    expected_table = pa.table(arrow_type_arrays).cast(pa.schema(arrow_roundtrip_types))
     assert table.to_arrow() == expected_table
 
 
@@ -224,6 +248,24 @@ def test_from_pydict_arrow_struct_array() -> None:
     assert daft_table.to_arrow()["a"].combine_chunks() == expected
 
 
+@pytest.mark.skipif(
+    get_context().runner_config.name == "ray",
+    reason="pyarrow extension types aren't supported on Ray clusters.",
+)
+def test_from_pydict_arrow_extension_array(uuid_ext_type) -> None:
+    pydata = [f"{i}".encode() for i in range(6)]
+    pydata[2] = None
+    storage = pa.array(pydata)
+    arrow_arr = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
+    daft_table = Table.from_pydict({"a": arrow_arr})
+    assert "a" in daft_table.column_names()
+    # Although Daft internally represents the binary storage array as a large_binary array,
+    # it should be cast back to the ingress extension type.
+    result = daft_table.to_arrow()["a"].combine_chunks()
+    assert result.type == uuid_ext_type
+    assert result == arrow_arr
+
+
 def test_from_pydict_arrow_deeply_nested() -> None:
     # Test a struct of lists of struct of lists of strings.
     data = [{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}]
@@ -378,6 +420,24 @@ def test_from_arrow_struct_array() -> None:
     assert daft_table.to_arrow()["a"].combine_chunks() == expected
 
 
+@pytest.mark.skipif(
+    get_context().runner_config.name == "ray",
+    reason="pyarrow extension types aren't supported on Ray clusters.",
+)
+def test_from_arrow_extension_array(uuid_ext_type) -> None:
+    pydata = [f"{i}".encode() for i in range(6)]
+    pydata[2] = None
+    storage = pa.array(pydata)
+    arrow_arr = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
+    daft_table = Table.from_arrow(pa.table({"a": arrow_arr}))
+    assert "a" in daft_table.column_names()
+    # Although Daft internally represents the binary storage array as a large_binary array,
+    # it should be cast back to the ingress extension type.
+    result = daft_table.to_arrow()["a"].combine_chunks()
+    assert result.type == uuid_ext_type
+    assert result == arrow_arr
+
+
 def test_from_arrow_deeply_nested() -> None:
     # Test a struct of lists of struct of lists of strings.
     data = [{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}]
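
Note: the new tests above take a `uuid_ext_type` pytest fixture that is defined outside this diff. A minimal sketch of what such a fixture could look like follows; the `UuidType` class name, the "daft.uuid" extension name, and the conftest.py placement are illustrative assumptions, not part of this change. The tests only rely on a `NAME` attribute, binary storage, and empty serialized metadata (hence the `""` they pass to `DataType.extension(...)`).

# conftest.py (hypothetical location) -- sketch of the assumed `uuid_ext_type` fixture.
import pyarrow as pa
import pytest


class UuidType(pa.ExtensionType):
    # Hypothetical extension name; the real fixture may use a different one.
    NAME = "daft.uuid"

    def __init__(self):
        # Binary storage, matching the pa.array([b"0", b"1", ...]) storage arrays in the tests.
        super().__init__(pa.binary(), self.NAME)

    def __arrow_ext_serialize__(self) -> bytes:
        # Empty metadata, matching the "" metadata the tests assert on.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return cls()


@pytest.fixture
def uuid_ext_type():
    ext_type = UuidType()
    # Registering the type lets pyarrow reconstruct it on IPC/table roundtrips.
    pa.register_extension_type(ext_type)
    yield ext_type
    pa.unregister_extension_type(ext_type.NAME)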
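
Similarly, the canonical tensor tests lean on the storage layout of pyarrow's fixed-shape tensor extension type: the extension array is backed by fixed-size-list storage of length prod(shape), which is why the '{"shape":[2,2]}' metadata pairs with a pa.list_(dtype, 4) storage type above. A standalone illustration (assuming pyarrow >= 12.0.0; not part of this diff):

# Standalone sketch; assumes pyarrow >= 12.0.0.
import numpy as np
import pyarrow as pa

arr = np.arange(8, dtype=np.int64).reshape((2, 2, 2))  # two 2x2 tensors
tensor_arr = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)

# The extension type carries the per-element shape...
assert tensor_arr.type == pa.fixed_shape_tensor(pa.int64(), (2, 2))
# ...while the underlying storage is a fixed-size list of prod(shape) = 4 values.
assert tensor_arr.storage.type == pa.list_(pa.int64(), 4)

# Constructing from storage (as the null-containing tests do) roundtrips back to numpy.
storage = pa.array([[0, 1, 2, 3], [4, 5, 6, 7]], pa.list_(pa.int64(), 4))
from_storage = pa.FixedShapeTensorArray.from_storage(tensor_arr.type, storage)
np.testing.assert_equal(from_storage.to_numpy_ndarray(), arr)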