
[Extension Types] Add support for cross-lang extension types. #899

Merged: 9 commits, May 18, 2023
4 changes: 2 additions & 2 deletions Cargo.lock


6 changes: 6 additions & 0 deletions Cargo.toml
@@ -5,7 +5,10 @@ prettytable-rs = "^0.10"
rand = "^0.8"

[dependencies.arrow2]
branch = "clark/expand-casting-support"
Contributor:
Should we merge it into main in our fork before merging this PR?

Contributor Author (@clarkzinzow), May 12, 2023:

The main con with working off of main is that we'd be making the divergence from upstream Arrow2 implicit, and if we keep merging upstream Arrow2 main into our fork's main, our differing commits might be lost in the git history. Instead, I think it would be better to anchor to a branch where all of the diverging commits are at the head and try to keep the branching small/short-lived.

The flow could be:

  1. Create a branch on Arrow2 fork containing requisite changes to Arrow2.
  2. Update Daft to point to that branch.
  3. Submit a PR from Arrow2 fork branch to upstream Arrow2.
  4. If the PR is updated during the review process, we can update the locked Daft dependency with a cargo update PR.
  5. When the PR is merged and is included in an Arrow2 release, switch Daft's Arrow2 dependency to point to that crates.io release.
  6. If we need another Arrow2 change stacked on the existing Arrow2 branch/PR, we could create a nightly snapshot branch containing both changes, similar to what Polars does: https://github.com/pola-rs/polars/blob/528590cfa57e48f5bd902ad027f5e35318644110/Cargo.toml#L47

I think this should help keep the difference with upstream Arrow2 explicit, which should be nicer? Let me know what you think about that flow.

features = ["compute", "io_ipc"]
git = "https://github.com/Eventual-Inc/arrow2"
package = "arrow2"
version = "0.17.1"

[dependencies.bincode]
@@ -15,6 +18,9 @@ version = "1.3.3"
features = ["serde"]
version = "1.9.2"

[dependencies.lazy_static]
version = "1.4.0"

[dependencies.num-traits]
version = "0.2"

1 change: 0 additions & 1 deletion daft/arrow_utils.py
@@ -63,7 +63,6 @@ def ensure_array(arr: pa.Array) -> pa.Array:


class _FixSliceOffsets:
    # TODO(Clark): For pyarrow < 12.0.0, struct array slice offsets are dropped
    # when converting to record batches. We work around this below by flattening
    # the field arrays for all struct arrays, which propagates said offsets to
57 changes: 56 additions & 1 deletion daft/datatype.py
@@ -4,8 +4,33 @@

import pyarrow as pa

from daft.context import get_context
from daft.daft import PyDataType

_RAY_DATA_EXTENSIONS_AVAILABLE = True
_TENSOR_EXTENSION_TYPES = []
try:
    import ray
except ImportError:
    _RAY_DATA_EXTENSIONS_AVAILABLE = False
else:
    _RAY_VERSION = tuple(int(s) for s in ray.__version__.split("."))
    try:
        # Variable-shaped tensor column support was added in Ray 2.1.0,
        # but we only use it on Ray 2.2.0+.
        if _RAY_VERSION >= (2, 2, 0):
            from ray.data.extensions import (
                ArrowTensorType,
                ArrowVariableShapedTensorType,
            )

            _TENSOR_EXTENSION_TYPES = [ArrowTensorType, ArrowVariableShapedTensorType]
        else:
            from ray.data.extensions import ArrowTensorType

            _TENSOR_EXTENSION_TYPES = [ArrowTensorType]
    except ImportError:
        _RAY_DATA_EXTENSIONS_AVAILABLE = False


class DataType:
    _dtype: PyDataType

@@ -96,6 +121,10 @@ def fixed_size_list(cls, name: str, dtype: DataType, size: int) -> DataType:
    def struct(cls, fields: dict[str, DataType]) -> DataType:
        return cls._from_pydatatype(PyDataType.struct({name: datatype._dtype for name, datatype in fields.items()}))

    @classmethod
    def extension(cls, name: str, storage_dtype: DataType, metadata: str | None = None) -> DataType:
        return cls._from_pydatatype(PyDataType.extension(name, storage_dtype._dtype, metadata))

    @classmethod
    def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
        if pa.types.is_int8(arrow_type):

@@ -140,9 +169,35 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
            assert isinstance(arrow_type, pa.StructType)
            fields = [arrow_type[i] for i in range(arrow_type.num_fields)]
            return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields})
        elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)):
            # TODO(Clark): Add a native cross-lang extension type representation for Ray's tensor extension types.
            return cls.python()
        elif isinstance(arrow_type, pa.PyExtensionType):
            # TODO(Clark): Add a native cross-lang extension type representation for PyExtensionTypes.
            raise ValueError(
                "pyarrow extension types that subclass pa.PyExtensionType can't be used in Daft, since they can't be "
                f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}"
            )
        elif isinstance(arrow_type, pa.BaseExtensionType):
            if get_context().runner_config.name == "ray":
                raise ValueError(
                    f"pyarrow extension types are not supported for the Ray runner: {arrow_type}. If you need support "
                    "for this, please let us know on this issue: "
                    "https://github.com/Eventual-Inc/Daft/issues/933"
                )
            name = arrow_type.extension_name
            try:
                metadata = arrow_type.__arrow_ext_serialize__().decode()
            except AttributeError:
                metadata = None
            return cls.extension(
                name,
                cls.from_arrow_type(arrow_type.storage_type),
                metadata,
            )
        else:
            # Fall back to a Python object type.
            # TODO(Clark): Add native support for remaining Arrow types.
            return cls.python()

    @classmethod
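For context, arrow2 models extension types with a dedicated `DataType::Extension(name, storage, metadata)` variant, which is presumably what the new `DataType.extension` factory above maps onto on the Rust side. A minimal sketch against upstream arrow2 0.17; the extension name and metadata here are made up for illustration:

```rust
use arrow2::array::{Array, PrimitiveArray};
use arrow2::datatypes::DataType;

fn main() {
    // Hypothetical logical type "temperature.celsius", physically stored as Int64.
    let ext_type = DataType::Extension(
        "temperature.celsius".to_string(),    // extension name
        Box::new(DataType::Int64),            // storage type
        Some(r#"{"unit":"C"}"#.to_string()),  // optional serialized metadata
    );

    // The physical layout is unchanged; only the logical data type differs,
    // which is what lets the type travel across Arrow implementations.
    let array = PrimitiveArray::<i64>::from_vec(vec![20, 21, 19]).to(ext_type);
    assert_eq!(array.len(), 3);
    println!("{:?}", array.data_type());
}
```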
22 changes: 19 additions & 3 deletions src/array/ops/arrow2/sort/primitive/sort.rs
@@ -166,6 +166,7 @@ mod tests {
    use super::*;

    use arrow2::array::ord;
    use arrow2::array::Array;
    use arrow2::array::PrimitiveArray;
    use arrow2::datatypes::DataType;

@@ -177,13 +178,28 @@
    ) where
        T: NativeType + std::cmp::Ord,
    {
        let input = PrimitiveArray::<T>::from(data)
            .to(data_type.clone())
            .as_any()
            .downcast_ref::<PrimitiveArray<T>>()
            .unwrap()
            .clone();
        let expected = PrimitiveArray::<T>::from(expected_data)
            .to(data_type.clone())
            .as_any()
            .downcast_ref::<PrimitiveArray<T>>()
            .unwrap()
            .clone();
        let output = sort_by(&input, ord::total_cmp, &options, None);
        assert_eq!(expected, output);

        // with limit
        let expected = PrimitiveArray::<T>::from(&expected_data[..3])
            .to(data_type)
            .as_any()
            .downcast_ref::<PrimitiveArray<T>>()
            .unwrap()
            .clone();
        let output = sort_by(&input, ord::total_cmp, &options, Some(3));
        assert_eq!(expected, output)
    }
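The repeated `.as_any().downcast_ref::<PrimitiveArray<T>>()` chains above are arrow2's standard idiom for recovering a concrete array from a type-erased `dyn Array`, presumably needed here because the fork's expanded casting routes `to` through a boxed array. A standalone sketch of the pattern against upstream arrow2:

```rust
use arrow2::array::{Array, Int64Array};

fn main() {
    // Erase the concrete type, as many arrow2 kernels do...
    let erased: Box<dyn Array> = Box::new(Int64Array::from_slice([3, 1, 2]));

    // ...then recover it: as_any() exposes &dyn Any, and downcast_ref()
    // succeeds only if the physical type actually matches.
    let concrete = erased
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("physical type is Int64")
        .clone();
    assert_eq!(concrete.values().as_slice(), &[3, 1, 2]);
}
```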
21 changes: 19 additions & 2 deletions src/array/ops/broadcast.rs
@@ -1,8 +1,8 @@
use crate::{
    array::DataArray,
    datatypes::{
        BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeListArray,
        ListArray, NullArray, StructArray, Utf8Array,
    },
    error::{DaftError, DaftResult},
};
@@ -251,6 +251,23 @@ impl Broadcastable for StructArray {
    }
}

impl Broadcastable for ExtensionArray {
    fn broadcast(&self, num: usize) -> DaftResult<Self> {
        if self.len() != 1 {
            return Err(DaftError::ValueError(format!(
                "Attempting to broadcast non-unit length Array named: {}",
                self.name()
            )));
        }
        let array = self.data();
        let mut growable = arrow2::array::growable::make_growable(&[array], true, num);
        for _ in 0..num {
            growable.extend(0, 0, 1);
        }
        ExtensionArray::new(self.field.clone(), growable.as_box())
    }
}

#[cfg(feature = "python")]
impl Broadcastable for crate::datatypes::PythonArray {
    fn broadcast(&self, num: usize) -> DaftResult<Self> {
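Since `ExtensionArray` has no dedicated arrow2 broadcast kernel, the impl above leans on arrow2's growable API, which can repeat rows of any array behind `&dyn Array`. A minimal sketch of the same pattern against upstream arrow2, using a plain primitive array; the `broadcast_unit` helper is illustrative, not Daft API:

```rust
use arrow2::array::growable::{make_growable, Growable};
use arrow2::array::{Array, Int64Array};

// Repeat the single row of a unit-length array `num` times.
fn broadcast_unit(array: &dyn Array, num: usize) -> Box<dyn Array> {
    assert_eq!(array.len(), 1, "expected a unit-length array");
    let mut growable = make_growable(&[array], true, num);
    for _ in 0..num {
        // Copy element 0 of input array 0, one slot per iteration.
        growable.extend(0, 0, 1);
    }
    growable.as_box()
}

fn main() {
    let out = broadcast_unit(&Int64Array::from_slice([42]), 4);
    assert_eq!(out.len(), 4);
}
```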
1 change: 1 addition & 0 deletions src/array/ops/compare_agg.rs
@@ -337,6 +337,7 @@ impl_todo_daft_comparable!(BinaryArray);
impl_todo_daft_comparable!(StructArray);
impl_todo_daft_comparable!(FixedSizeListArray);
impl_todo_daft_comparable!(ListArray);
impl_todo_daft_comparable!(ExtensionArray);

#[cfg(feature = "python")]
impl_todo_daft_comparable!(PythonArray);
62 changes: 3 additions & 59 deletions src/array/ops/filter.rs
@@ -1,73 +1,17 @@
use crate::{
    array::DataArray,
    datatypes::{logical::DateArray, BooleanArray, DaftArrowBackedType},
    error::DaftResult,
};

use super::as_arrow::AsArrow;

impl<T> DataArray<T>
where
    T: DaftArrowBackedType,
{
    pub fn filter(&self, mask: &BooleanArray) -> DaftResult<Self> {
        let result = arrow2::compute::filter::filter(self.data(), mask.as_arrow())?;
        Self::try_from((self.field.clone(), result))
    }
}
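The removed per-type impls were identical shells around the same arrow2 kernel, which operates on `&dyn Array` and is what makes the single generic impl possible. A minimal sketch of that kernel in isolation, assuming arrow2 with the `compute` feature enabled (as in the Cargo.toml above):

```rust
use arrow2::array::{Array, BooleanArray, Int64Array};
use arrow2::compute::filter::filter;
use arrow2::error::Result;

fn main() -> Result<()> {
    let values = Int64Array::from_slice([1, 2, 3, 4]);
    let mask = BooleanArray::from_slice([true, false, true, false]);

    // The kernel is type-erased: it takes &dyn Array and returns Box<dyn Array>,
    // so one call site serves every arrow-backed array type.
    let filtered = filter(&values, &mask)?;
    assert_eq!(filtered.len(), 2);
    Ok(())
}
```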
10 changes: 2 additions & 8 deletions src/array/ops/full.rs
@@ -29,10 +29,7 @@ where
        let arrow_dtype = dtype.to_arrow();
        match arrow_dtype {
            Ok(arrow_dtype) => DataArray::<T>::new(
                Arc::new(Field::new(name.to_string(), dtype.clone())),
                arrow2::array::new_null_array(arrow_dtype, length),
            )
            .unwrap(),
@@ -54,10 +51,7 @@
        let arrow_dtype = dtype.to_arrow();
        match arrow_dtype {
            Ok(arrow_dtype) => DataArray::<T>::new(
                Arc::new(Field::new(name.to_string(), dtype.clone())),
                arrow2::array::new_empty_array(arrow_dtype),
            )
            .unwrap(),