Skip to content

Commit

Permalink
Remove support for pyarrow extension types on the Ray runner.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed May 17, 2023
1 parent 2728bcf commit 88ddc40
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
7 changes: 7 additions & 0 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pyarrow as pa

from daft.context import get_context
from daft.daft import PyDataType

_RAY_DATA_EXTENSIONS_AVAILABLE = True
Expand Down Expand Up @@ -178,6 +179,12 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}"
)
elif isinstance(arrow_type, pa.BaseExtensionType):
if get_context().runner_config.name == "ray":
raise ValueError(
f"pyarrow extension types are not supported for the Ray runner: {arrow_type}. If you need support "
"for this, please let us know on this issue: "
"https://github.com/Eventual-Inc/Daft/issues/933"
)
name = arrow_type.extension_name
try:
metadata = arrow_type.__arrow_ext_serialize__().decode()
Expand Down
25 changes: 25 additions & 0 deletions tests/dataframe/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import daft
from daft.api_annotations import APITypeError
from daft.context import get_context
from daft.dataframe import DataFrame
from daft.datatype import DataType
from tests.conftest import UuidType
Expand Down Expand Up @@ -173,6 +174,10 @@ def test_create_dataframe_arrow_tensor_ray(valid_data: list[dict[str, float]]) -
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, float]]) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
dtype = pa.fixed_shape_tensor(pa.int64(), (2, 2))
Expand All @@ -191,6 +196,10 @@ def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, floa
assert df.to_arrow() == expected


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]], uuid_ext_type: UuidType) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
storage = pa.array([f"{i}".encode() for i in range(len(valid_data))])
Expand All @@ -207,6 +216,22 @@ def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]
assert df.to_arrow() == expected


# TODO(Clark): Remove this test once pyarrow extension types are supported for Ray clusters.
@pytest.mark.skipif(
    get_context().runner_config.name != "ray",
    reason="This test requires the Ray runner.",
)
def test_create_dataframe_arrow_extension_type_fails_for_ray(
    valid_data: list[dict[str, float]], uuid_ext_type: UuidType
) -> None:
    """Creating a DataFrame with a pyarrow extension-typed column must raise on the Ray runner.

    Builds a table with one column backed by the ``uuid_ext_type`` extension type and
    asserts that round-tripping it through daft raises the dedicated unsupported-type error.
    """
    pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
    storage = pa.array([f"{i}".encode() for i in range(len(valid_data))])
    pydict["obj"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
    t = pa.Table.from_pydict(pydict)
    # Pin the specific error message so an unrelated ValueError can't make this test pass.
    with pytest.raises(ValueError, match="extension types are not supported for the Ray runner"):
        daft.from_arrow(t).to_arrow()


class PyExtType(pa.PyExtensionType):
def __init__(self):
pa.PyExtensionType.__init__(self, pa.binary())
Expand Down

0 comments on commit 88ddc40

Please sign in to comment.