Skip to content

Commit

Permalink
[Extension Types] Add support for cross-lang extension types. (#899)
Browse files Browse the repository at this point in the history
This PR adds support for cross-language extension types, i.e. extension
types that use a language-agnostic serialization method.

## TODOs

- [x] Add non-tensor test coverage for Arrow < 12.0.0.
  • Loading branch information
clarkzinzow authored May 18, 2023
1 parent 63b21be commit 19e3881
Show file tree
Hide file tree
Showing 35 changed files with 992 additions and 270 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ prettytable-rs = "^0.10"
rand = "^0.8"

[dependencies.arrow2]
branch = "clark/expand-casting-support"
features = ["compute", "io_ipc"]
git = "https://github.com/Eventual-Inc/arrow2"
package = "arrow2"
version = "0.17.1"

[dependencies.bincode]
Expand All @@ -15,6 +18,9 @@ version = "1.3.3"
features = ["serde"]
version = "1.9.2"

[dependencies.lazy_static]
version = "1.4.0"

[dependencies.num-traits]
version = "0.2"

Expand Down
1 change: 0 additions & 1 deletion daft/arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def ensure_array(arr: pa.Array) -> pa.Array:


class _FixSliceOffsets:

# TODO(Clark): For pyarrow < 12.0.0, struct array slice offsets are dropped
# when converting to record batches. We work around this below by flattening
# the field arrays for all struct arrays, which propagates said offsets to
Expand Down
57 changes: 56 additions & 1 deletion daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,33 @@

import pyarrow as pa

from daft.context import get_context
from daft.daft import PyDataType

_RAY_DATA_EXTENSIONS_AVAILABLE = True
_TENSOR_EXTENSION_TYPES = []
try:
import ray
except ImportError:
_RAY_DATA_EXTENSIONS_AVAILABLE = False
else:
_RAY_VERSION = tuple(int(s) for s in ray.__version__.split("."))
try:
# Variable-shaped tensor column support was added in Ray 2.1.0.
if _RAY_VERSION >= (2, 2, 0):
from ray.data.extensions import (
ArrowTensorType,
ArrowVariableShapedTensorType,
)

_TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType]
else:
from ray.data.extensions import ArrowTensorType

_TENSOR_EXTENSION_TYPES = [ArrowTensorType]
except ImportError:
_RAY_DATA_EXTENSIONS_AVAILABLE = False


class DataType:
_dtype: PyDataType
Expand Down Expand Up @@ -96,6 +121,10 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType:
def struct(cls, fields: dict[str, DataType]) -> DataType:
return cls._from_pydatatype(PyDataType.struct({name: datatype._dtype for name, datatype in fields.items()}))

@classmethod
def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = None) -> DataType:
    """Creates an extension DataType named `name`, backed by `storage_dtype`,
    with optional serialized `metadata`."""
    pydt = PyDataType.extension(name, storage_dtype._dtype, metadata)
    return cls._from_pydatatype(pydt)

@classmethod
def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
if pa.types.is_int8(arrow_type):
Expand Down Expand Up @@ -140,9 +169,35 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
assert isinstance(arrow_type, pa.StructType)
fields = [arrow_type[i] for i in range(arrow_type.num_fields)]
return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields})
elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)):
# TODO(Clark): Add a native cross-lang extension type representation for Ray's tensor extension types.
return cls.python()
elif isinstance(arrow_type, pa.PyExtensionType):
# TODO(Clark): Add a native cross-lang extension type representation for PyExtensionTypes.
raise ValueError(
"pyarrow extension types that subclass pa.PyExtensionType can't be used in Daft, since they can't be "
f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}"
)
elif isinstance(arrow_type, pa.BaseExtensionType):
if get_context().runner_config.name == "ray":
raise ValueError(
f"pyarrow extension types are not supported for the Ray runner: {arrow_type}. If you need support "
"for this, please let us know on this issue: "
"https://github.com/Eventual-Inc/Daft/issues/933"
)
name = arrow_type.extension_name
try:
metadata = arrow_type.__arrow_ext_serialize__().decode()
except AttributeError:
metadata = None
return cls.extension(
name,
cls.from_arrow_type(arrow_type.storage_type),
metadata,
)
else:
# Fall back to a Python object type.
# TODO(Clark): Add native support for remaining Arrow types and extension types.
# TODO(Clark): Add native support for remaining Arrow types.
return cls.python()

@classmethod
Expand Down
22 changes: 19 additions & 3 deletions src/array/ops/arrow2/sort/primitive/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ mod tests {
use super::*;

use arrow2::array::ord;
use arrow2::array::Array;
use arrow2::array::PrimitiveArray;
use arrow2::datatypes::DataType;

Expand All @@ -177,13 +178,28 @@ mod tests {
) where
T: NativeType + std::cmp::Ord,
{
let input = PrimitiveArray::<T>::from(data).to(data_type.clone());
let expected = PrimitiveArray::<T>::from(expected_data).to(data_type.clone());
let input = PrimitiveArray::<T>::from(data)
.to(data_type.clone())
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let expected = PrimitiveArray::<T>::from(expected_data)
.to(data_type.clone())
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let output = sort_by(&input, ord::total_cmp, &options, None);
assert_eq!(expected, output);

// with limit
let expected = PrimitiveArray::<T>::from(&expected_data[..3]).to(data_type);
let expected = PrimitiveArray::<T>::from(&expected_data[..3])
.to(data_type)
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let output = sort_by(&input, ord::total_cmp, &options, Some(3));
assert_eq!(expected, output)
}
Expand Down
21 changes: 19 additions & 2 deletions src/array/ops/broadcast.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
array::DataArray,
datatypes::{
BinaryArray, BooleanArray, DaftNumericType, DataType, FixedSizeListArray, ListArray,
NullArray, StructArray, Utf8Array,
BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeListArray,
ListArray, NullArray, StructArray, Utf8Array,
},
error::{DaftError, DaftResult},
};
Expand Down Expand Up @@ -251,6 +251,23 @@ impl Broadcastable for StructArray {
}
}

impl Broadcastable for ExtensionArray {
    /// Tiles this array's single element `num` times.
    ///
    /// Returns a `ValueError` if the array does not have exactly one element,
    /// since only unit-length arrays can be broadcast.
    fn broadcast(&self, num: usize) -> DaftResult<Self> {
        if self.len() == 1 {
            let source = self.data();
            // Grow a fresh array by copying element 0 of the source `num` times.
            let mut builder = arrow2::array::growable::make_growable(&[source], true, num);
            (0..num).for_each(|_| builder.extend(0, 0, 1));
            ExtensionArray::new(self.field.clone(), builder.as_box())
        } else {
            Err(DaftError::ValueError(format!(
                "Attempting to broadcast non-unit length Array named: {}",
                self.name()
            )))
        }
    }
}

#[cfg(feature = "python")]
impl Broadcastable for crate::datatypes::PythonArray {
fn broadcast(&self, num: usize) -> DaftResult<Self> {
Expand Down
1 change: 1 addition & 0 deletions src/array/ops/compare_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ impl_todo_daft_comparable!(BinaryArray);
impl_todo_daft_comparable!(StructArray);
impl_todo_daft_comparable!(FixedSizeListArray);
impl_todo_daft_comparable!(ListArray);
impl_todo_daft_comparable!(ExtensionArray);

#[cfg(feature = "python")]
impl_todo_daft_comparable!(PythonArray);
62 changes: 3 additions & 59 deletions src/array/ops/filter.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,17 @@
use crate::{
array::DataArray,
datatypes::{
logical::DateArray, BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray,
ListArray, NullArray, StructArray, Utf8Array,
},
datatypes::{logical::DateArray, BooleanArray, DaftArrowBackedType},
error::DaftResult,
};

use super::as_arrow::AsArrow;

impl<T> DataArray<T>
where
T: DaftNumericType,
T: DaftArrowBackedType,
{
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl Utf8Array {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl BinaryArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl BooleanArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl NullArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let set_bits = mask.len() - mask.as_arrow().values().unset_bits();
Ok(NullArray::full_null(
self.name(),
self.data_type(),
set_bits,
))
}
}

impl ListArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl FixedSizeListArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}

impl StructArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
let result = arrow2::compute::filter::filter(self.as_arrow(), mask.as_arrow())?;
let result = arrow2::compute::filter::filter(self.data(), mask.as_arrow())?;
Self::try_from((self.field.clone(), result))
}
}
Expand Down
10 changes: 2 additions & 8 deletions src/array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ where
let arrow_dtype = dtype.to_arrow();
match arrow_dtype {
Ok(arrow_dtype) => DataArray::<T>::new(
Arc::new(Field {
name: name.to_string(),
dtype: dtype.clone(),
}),
Arc::new(Field::new(name.to_string(), dtype.clone())),
arrow2::array::new_null_array(arrow_dtype, length),
)
.unwrap(),
Expand All @@ -54,10 +51,7 @@ where
let arrow_dtype = dtype.to_arrow();
match arrow_dtype {
Ok(arrow_dtype) => DataArray::<T>::new(
Arc::new(Field {
name: name.to_string(),
dtype: dtype.clone(),
}),
Arc::new(Field::new(name.to_string(), dtype.clone())),
arrow2::array::new_empty_array(arrow_dtype),
)
.unwrap(),
Expand Down
Loading

0 comments on commit 19e3881

Please sign in to comment.