Skip to content

Commit

Permalink
Add support for cross-lang extension types.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed May 5, 2023
1 parent 4d8c4ea commit 6011c52
Show file tree
Hide file tree
Showing 25 changed files with 519 additions and 102 deletions.
1 change: 0 additions & 1 deletion daft/arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def ensure_array(arr: pa.Array) -> pa.Array:


class _FixSliceOffsets:

# TODO(Clark): For pyarrow < 12.0.0, struct array slice offsets are dropped
# when converting to record batches. We work around this below by flattening
# the field arrays for all struct arrays, which propagates said offsets to
Expand Down
50 changes: 49 additions & 1 deletion daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@

from daft.daft import PyDataType

_RAY_DATA_EXTENSIONS_AVAILABLE = True
_TENSOR_EXTENSION_TYPES = []
try:
import ray
except ImportError:
_RAY_DATA_EXTENSIONS_AVAILABLE = False
else:
_RAY_VERSION = tuple(int(s) for s in ray.__version__.split("."))
try:
# Variable-shaped tensor column support was added in Ray 2.1.0.
if _RAY_VERSION >= (2, 2, 0):
from ray.data.extensions import (
ArrowTensorType,
ArrowVariableShapedTensorType,
)

_TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType]
else:
from ray.data.extensions import ArrowTensorType

_TENSOR_EXTENSION_TYPES = [ArrowTensorType]
except ImportError:
_RAY_DATA_EXTENSIONS_AVAILABLE = False


class DataType:
_dtype: PyDataType
Expand Down Expand Up @@ -96,6 +120,10 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType:
def struct(cls, fields: dict[str, DataType]) -> DataType:
    """Create a struct DataType from a mapping of field names to their DataTypes."""
    pydt_fields = {name: dtype._dtype for name, dtype in fields.items()}
    return cls._from_pydatatype(PyDataType.struct(pydt_fields))

@classmethod
def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = None) -> DataType:
    """Create an extension DataType with the given name, storage type, and
    optional serialized metadata string."""
    pydt = PyDataType.extension(name, storage_dtype._dtype, metadata)
    return cls._from_pydatatype(pydt)

@classmethod
def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
if pa.types.is_int8(arrow_type):
Expand Down Expand Up @@ -140,9 +168,29 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
assert isinstance(arrow_type, pa.StructType)
fields = [arrow_type[i] for i in range(arrow_type.num_fields)]
return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields})
elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)):
# TODO(Clark): Add a native cross-lang extension type representation for Ray's tensor extension types.
return cls.python()
elif isinstance(arrow_type, pa.PyExtensionType):
# TODO(Clark): Add a native cross-lang extension type representation for PyExtensionTypes.
raise ValueError(
"pyarrow extension types that subclass pa.PyExtensionType can't be used in Daft, since they can't be "
f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}"
)
elif isinstance(arrow_type, pa.BaseExtensionType):
name = arrow_type.extension_name
try:
metadata = arrow_type.__arrow_ext_serialize__().decode()
except AttributeError:
metadata = None
return cls.extension(
name,
cls.from_arrow_type(arrow_type.storage_type),
metadata,
)
else:
# Fall back to a Python object type.
# TODO(Clark): Add native support for remaining Arrow types and extension types.
# TODO(Clark): Add native support for remaining Arrow types.
return cls.python()

@classmethod
Expand Down
21 changes: 19 additions & 2 deletions src/array/ops/broadcast.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
array::{BaseArray, DataArray},
datatypes::{
BinaryArray, BooleanArray, DaftNumericType, DataType, FixedSizeListArray, ListArray,
NullArray, StructArray, Utf8Array,
BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeListArray,
ListArray, NullArray, StructArray, Utf8Array,
},
error::{DaftError, DaftResult},
};
Expand Down Expand Up @@ -251,6 +251,23 @@ impl Broadcastable for StructArray {
}
}

impl Broadcastable for ExtensionArray {
    /// Repeats this unit-length extension array `num` times.
    ///
    /// Returns a `ValueError` if the array does not have exactly one element.
    fn broadcast(&self, num: usize) -> DaftResult<Self> {
        if self.len() != 1 {
            return Err(DaftError::ValueError(format!(
                "Attempting to broadcast non-unit length Array named: {}",
                self.name()
            )));
        }
        // Build the result by copying the single source element `num` times
        // into a growable arrow2 array (validity tracking enabled).
        let source = self.data();
        let mut growable = arrow2::array::growable::make_growable(&[source], true, num);
        (0..num).for_each(|_| growable.extend(0, 0, 1));
        ExtensionArray::new(self.field.clone(), growable.as_box())
    }
}

#[cfg(feature = "python")]
impl Broadcastable for crate::datatypes::PythonArray {
fn broadcast(&self, num: usize) -> DaftResult<Self> {
Expand Down
11 changes: 9 additions & 2 deletions src/array/ops/filter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
array::DataArray,
datatypes::{
BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, NullArray,
StructArray, Utf8Array,
BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray,
NullArray, StructArray, Utf8Array,
},
error::DaftResult,
};
Expand Down Expand Up @@ -74,6 +74,13 @@ impl StructArray {
}
}

impl ExtensionArray {
    /// Returns a new array keeping only the rows where `mask` is true.
    pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
        let filtered = arrow2::compute::filter::filter(self.data(), mask.downcast())?;
        Self::try_from((self.name(), filtered))
    }
}

#[cfg(feature = "python")]
impl crate::datatypes::PythonArray {
pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
Expand Down
10 changes: 2 additions & 8 deletions src/array/ops/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ where
let arrow_dtype = dtype.to_arrow();
match arrow_dtype {
Ok(arrow_dtype) => DataArray::<T>::new(
Arc::new(Field {
name: name.to_string(),
dtype: dtype.clone(),
}),
Arc::new(Field::new(name.to_string(), dtype.clone())),
arrow2::array::new_null_array(arrow_dtype, length),
)
.unwrap(),
Expand All @@ -29,10 +26,7 @@ where
let arrow_dtype = dtype.to_arrow();
match arrow_dtype {
Ok(arrow_dtype) => DataArray::<T>::new(
Arc::new(Field {
name: name.to_string(),
dtype: dtype.clone(),
}),
Arc::new(Field::new(name.to_string(), dtype.clone())),
arrow2::array::new_empty_array(arrow_dtype),
)
.unwrap(),
Expand Down
101 changes: 53 additions & 48 deletions src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::array::{BaseArray, DataArray};
use crate::datatypes::{
BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, NullArray,
PythonArray, StructArray, Utf8Array,
BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray,
NullArray, PythonArray, StructArray, Utf8Array,
};
use crate::error::{DaftError, DaftResult};
use crate::utils::arrow::arrow_bitmap_and_helper;
Expand Down Expand Up @@ -224,62 +224,57 @@ impl PythonArray {
}
}

/// Element-wise select between `if_true` and `if_false` using `predicate`,
/// delegating to arrow2's `if_then_else` kernel on the downcast arrays.
///
/// All three inputs are assumed to be the same length (no broadcasting here);
/// the result is rebuilt into `T` under `if_true`'s name.
fn from_arrow_if_then_else<T>(predicate: &BooleanArray, if_true: &T, if_false: &T) -> DaftResult<T>
where
    T: BaseArray
        + Downcastable
        + for<'a> TryFrom<(&'a str, Box<dyn arrow2::array::Array>), Error = DaftError>,
    <T as Downcastable>::Output: arrow2::array::Array,
{
    let result = arrow2::compute::if_then_else::if_then_else(
        predicate.downcast(),
        if_true.downcast(),
        if_false.downcast(),
    )?;
    T::try_from((if_true.name(), result))
}

fn nested_if_then_else<T>(predicate: &BooleanArray, if_true: &T, if_false: &T) -> DaftResult<T>
where
T: BaseArray
+ Broadcastable
+ Downcastable
+ for<'a> TryFrom<(&'a str, Box<dyn arrow2::array::Array>), Error = DaftError>,
<T as Downcastable>::Output: arrow2::array::Array,
{
// TODO(Clark): Support streaming broadcasting, i.e. broadcasting without inflating scalars to full array length.
match (predicate.len(), if_true.len(), if_false.len()) {
let result = match (predicate.len(), if_true.len(), if_false.len()) {
(predicate_len, if_true_len, if_false_len)
if predicate_len == if_true_len && if_true_len == if_false_len => from_arrow_if_then_else(predicate, if_true, if_false),
(1, if_true_len, 1) => from_arrow_if_then_else(
&predicate.broadcast(if_true_len)?,
if_true,
&if_false.broadcast(if_true_len)?,
),
(1, 1, if_false_len) => from_arrow_if_then_else(
&predicate.broadcast(if_false_len)?,
&if_true.broadcast(if_false_len)?,
if_false,
),
(predicate_len, 1, 1) => from_arrow_if_then_else(
predicate,
&if_true.broadcast(predicate_len)?,
&if_false.broadcast(predicate_len)?,
),
(predicate_len, if_true_len, 1) if predicate_len == if_true_len => from_arrow_if_then_else(
predicate,
if_true,
&if_false.broadcast(predicate_len)?,
),
(predicate_len, 1, if_false_len) if predicate_len == if_false_len => from_arrow_if_then_else(
predicate,
&if_true.broadcast(predicate_len)?,
if_false,
),
if predicate_len == if_true_len && if_true_len == if_false_len =>
{
arrow2::compute::if_then_else::if_then_else(
predicate.downcast(),
if_true.data(),
if_false.data(),
)?
}
(1, if_true_len, 1) => arrow2::compute::if_then_else::if_then_else(
predicate.broadcast(if_true_len)?.downcast(),
if_true.data(),
if_false.broadcast(if_true_len)?.data(),
)?,
(1, 1, if_false_len) => arrow2::compute::if_then_else::if_then_else(
predicate.broadcast(if_false_len)?.downcast(),
if_true.broadcast(if_false_len)?.data(),
if_false.data(),
)?,
(predicate_len, 1, 1) => arrow2::compute::if_then_else::if_then_else(
predicate.downcast(),
if_true.broadcast(predicate_len)?.data(),
if_false.broadcast(predicate_len)?.data(),
)?,
(predicate_len, if_true_len, 1) if predicate_len == if_true_len => {
arrow2::compute::if_then_else::if_then_else(
predicate.downcast(),
if_true.data(),
if_false.broadcast(predicate_len)?.data(),
)?
}
(predicate_len, 1, if_false_len) if predicate_len == if_false_len => {
arrow2::compute::if_then_else::if_then_else(
predicate.downcast(),
if_true.broadcast(predicate_len)?.data(),
if_false.data(),
)?
}
(p, s, o) => {
Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}")))
return Err(DaftError::ValueError(format!("Cannot run if_else against arrays with non-broadcastable lengths: if_true={s}, if_false={o}, predicate={p}")));
}
}
};
T::try_from((if_true.name(), result))
}

impl ListArray {
Expand Down Expand Up @@ -307,3 +302,13 @@ impl StructArray {
nested_if_then_else(predicate, self, other)
}
}

impl ExtensionArray {
    /// Selects values from `self` where `predicate` is true and from `other`
    /// where it is false, broadcasting unit-length inputs as needed via
    /// `nested_if_then_else`.
    pub fn if_else(
        &self,
        other: &ExtensionArray,
        predicate: &BooleanArray,
    ) -> DaftResult<ExtensionArray> {
        nested_if_then_else(predicate, self, other)
    }
}
16 changes: 8 additions & 8 deletions src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ impl ListArray {
};

with_match_arrow_daft_types!(child_data_type,|$T| {
let new_data_arr = DataArray::<$T>::new(Arc::new(Field {
name: self.field.name.clone(),
dtype: child_data_type.clone(),
}), growable.as_box())?;
let new_data_arr = DataArray::<$T>::new(Arc::new(Field::new(
self.field.name.clone(),
child_data_type.clone(),
)), growable.as_box())?;
Ok(new_data_arr.into_series())
})
}
Expand Down Expand Up @@ -120,10 +120,10 @@ impl FixedSizeListArray {
};

with_match_arrow_daft_types!(child_data_type,|$T| {
let new_data_arr = DataArray::<$T>::new(Arc::new(Field {
name: self.field.name.clone(),
dtype: child_data_type.clone(),
}), growable.as_box())?;
let new_data_arr = DataArray::<$T>::new(Arc::new(Field::new(
self.field.name.clone(),
child_data_type.clone(),
)), growable.as_box())?;
Ok(new_data_arr.into_series())
})
}
Expand Down
39 changes: 37 additions & 2 deletions src/array/ops/take.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
array::{BaseArray, DataArray},
datatypes::{
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, FixedSizeListArray, ListArray,
NullArray, StructArray, Utf8Array,
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray,
FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array,
},
error::DaftResult,
};
Expand Down Expand Up @@ -291,6 +291,41 @@ impl StructArray {
}
}

impl ExtensionArray {
    /// Returns the scalar at `idx`, or `None` if that slot is null.
    ///
    /// # Panics
    /// Panics when `idx` is out of bounds.
    #[inline]
    pub fn get(&self, idx: usize) -> Option<Box<dyn arrow2::scalar::Scalar>> {
        if idx >= self.len() {
            panic!("Out of bounds: {} vs len: {}", idx, self.len())
        }
        // A missing validity bitmap means every slot is valid.
        match self.data.validity() {
            Some(validity) if !validity.get_bit(idx) => None,
            _ => Some(arrow2::scalar::new_scalar(self.data(), idx)),
        }
    }

    /// Gathers the elements at the positions given by `idx` into a new array.
    pub fn take<I>(&self, idx: &DataArray<I>) -> DaftResult<Self>
    where
        I: DaftIntegerType,
        <I as DaftNumericType>::Native: arrow2::types::Index,
    {
        let indices = idx.downcast();
        let taken = arrow2::compute::take::take(self.data(), indices)?;
        Self::try_from((self.name(), taken))
    }

    /// Renders the value at `idx` as a display string ("None" for nulls).
    pub fn str_value(&self, idx: usize) -> DaftResult<String> {
        Ok(self
            .get(idx)
            .map_or_else(|| "None".to_string(), |v| format!("{v:?}")))
    }
}

#[cfg(feature = "python")]
impl crate::datatypes::PythonArray {
#[inline]
Expand Down
Loading

0 comments on commit 6011c52

Please sign in to comment.