Skip to content

Commit

Permalink
Add extension type test coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed May 13, 2023
1 parent 00c7a0b commit 59b18a7
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 19 deletions.
2 changes: 0 additions & 2 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
metadata = arrow_type.__arrow_ext_serialize__().decode()
except AttributeError:
metadata = None
if metadata == "":
metadata = None
return cls.extension(
name,
cls.from_arrow_type(arrow_type.storage_type),
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import pyarrow as pa
import pytest


class UuidType(pa.ExtensionType):
Expand All @@ -18,6 +19,14 @@ def __arrow_ext_deserialize__(self, storage_type, serialized):
return UuidType()


@pytest.fixture
def uuid_ext_type() -> UuidType:
    """Yield a freshly registered ``UuidType`` extension type.

    ``pa.register_extension_type`` mutates global pyarrow state, so the
    fixture unregisters the type on teardown to keep tests isolated from
    one another.
    """
    ext_type = UuidType()
    pa.register_extension_type(ext_type)
    yield ext_type
    # Teardown: remove the global registration so other tests can register
    # their own instance (pyarrow raises on duplicate registration).
    pa.unregister_extension_type(ext_type.NAME)


def assert_df_equals(
daft_df: pd.DataFrame,
pd_df: pd.DataFrame,
Expand Down
10 changes: 5 additions & 5 deletions tests/dataframe/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,16 @@ def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, floa
assert df.to_arrow() == expected


def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]]) -> None:
def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]], uuid_ext_type: UuidType) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
dtype = UuidType()
pa.register_extension_type(dtype)
storage = pa.array([f"{i}".encode() for i in range(len(valid_data))])
pydict["obj"] = pa.ExtensionArray.from_storage(dtype, storage)
pydict["obj"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
t = pa.Table.from_pydict(pydict)
df = daft.from_arrow(t)
assert set(df.column_names) == set(t.column_names)
assert df.schema()["obj"].dtype == DataType.extension(dtype.NAME, DataType.from_arrow_type(dtype.storage_type), "")
assert df.schema()["obj"].dtype == DataType.extension(
uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), ""
)
casted_field = t.schema.field("variety").with_type(pa.large_string())
expected = t.cast(t.schema.set(t.schema.get_field_index("variety"), casted_field))
# Check roundtrip.
Expand Down
21 changes: 21 additions & 0 deletions tests/series/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ray.data.extensions import ArrowTensorArray

from daft import DataType, Series
from tests.conftest import *
from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES

ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric())
Expand Down Expand Up @@ -140,6 +141,26 @@ def test_series_concat_tensor_array_canonical(chunks) -> None:
np.testing.assert_equal(concated_arrow.to_numpy_ndarray(), expected)


@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
def test_series_concat_extension_type(uuid_ext_type, chunks) -> None:
    """Concatenating extension-typed Series preserves the extension dtype and data."""
    chunk_size = 3
    ext_arrays = []
    for chunk_idx in range(chunks):
        start = chunk_idx * chunk_size
        storage = pa.array([f"{i}".encode() for i in range(start, start + chunk_size)])
        ext_arrays.append(pa.ExtensionArray.from_storage(uuid_ext_type, storage))
    series_chunks = [Series.from_arrow(ext_array) for ext_array in ext_arrays]

    result = Series.concat(series_chunks)

    expected_dtype = DataType.extension(
        uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), ""
    )
    assert result.datatype() == expected_dtype
    result_arrow = result.to_arrow()
    # Round-trip back to Arrow should recover the registered extension type.
    assert isinstance(result_arrow.type, UuidType)
    assert result_arrow.type == uuid_ext_type
    assert result_arrow == pa.concat_arrays(ext_arrays)


@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
def test_series_concat_pyobj(chunks) -> None:
series = []
Expand Down
15 changes: 15 additions & 0 deletions tests/series/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def test_series_filter_on_struct_array() -> None:
assert result.to_pylist() == expected


def test_series_filter_on_extension_array(uuid_ext_type) -> None:
    """Filtering an extension-typed Series keeps its dtype and drops masked rows."""
    storage = pa.array([f"{i}".encode() for i in range(5)])
    ext_array = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
    series = Series.from_arrow(ext_array)

    keep_flags = [False, True, True, None, False]
    filtered = series.filter(Series.from_pylist(keep_flags))

    # The extension dtype survives the filter.
    assert series.datatype() == filtered.datatype()
    # A null mask entry behaves like False: the row is dropped.
    expected = [value for value, flag in zip(series.to_pylist(), keep_flags) if flag]
    assert filtered.to_pylist() == expected


@pytest.mark.skipif(
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
Expand Down
46 changes: 46 additions & 0 deletions tests/series/test_if_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,52 @@ def test_series_if_else_struct(if_true, if_false, expected) -> None:
assert result.to_pylist() == expected


@pytest.mark.parametrize(
    ["if_true_storage", "if_false_storage", "expected_storage"],
    [
        # Same length, same type
        (
            pa.array([f"{i}".encode() for i in range(4)]),
            pa.array([f"{i}".encode() for i in range(4, 8)]),
            pa.array([b"0", b"5", None, b"3"]),
        ),
        # Broadcast left
        (
            pa.array([b"0"]),
            pa.array([f"{i}".encode() for i in range(4, 8)]),
            pa.array([b"0", b"5", None, b"0"]),
        ),
        # Broadcast right
        (
            pa.array([f"{i}".encode() for i in range(4)]),
            pa.array([b"4"]),
            pa.array([b"0", b"4", None, b"3"]),
        ),
        # Broadcast both
        (
            pa.array([b"0"]),
            pa.array([b"4"]),
            pa.array([b"0", b"4", None, b"0"]),
        ),
    ],
)
def test_series_if_else_extension_type(uuid_ext_type, if_true_storage, if_false_storage, expected_storage) -> None:
    """if_else on extension-typed Series selects per-row by predicate and keeps the dtype.

    Null predicate entries produce null output rows; length-1 inputs broadcast.
    """

    def as_ext_series(storage: pa.Array) -> Series:
        # Wrap raw binary storage in the UUID extension type before ingest.
        return Series.from_arrow(pa.ExtensionArray.from_storage(uuid_ext_type, storage))

    true_branch = as_ext_series(if_true_storage)
    false_branch = as_ext_series(if_false_storage)
    predicate = Series.from_arrow(pa.array([True, False, None, True]))

    result = predicate.if_else(true_branch, false_branch)

    expected_dtype = DataType.extension(
        uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), ""
    )
    assert result.datatype() == expected_dtype
    expected_arrow = pa.ExtensionArray.from_storage(uuid_ext_type, expected_storage)
    assert result.to_arrow() == expected_arrow


@pytest.mark.skipif(
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
Expand Down
22 changes: 22 additions & 0 deletions tests/series/test_size_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,28 @@ def test_series_struct_size_bytes(size, with_nulls) -> None:
assert s.size_bytes() == get_total_buffer_size(data) + conversion_to_large_string_bytes


@pytest.mark.parametrize("size", [1, 2, 8, 9, 16])
@pytest.mark.parametrize("with_nulls", [True, False])
def test_series_extension_type_size_bytes(uuid_ext_type, size, with_nulls) -> None:
    """size_bytes() of an extension Series matches its storage cast to large_binary."""
    values = [f"{i}".encode() for i in range(size)]

    # TODO(Clark): Change to size > 0 condition when pyarrow extension arrays support generic construction on null arrays.
    if with_nulls and size > 1:
        values = values[:-1] + [None]
    storage = pa.array(values)
    ext_array = pa.ExtensionArray.from_storage(uuid_ext_type, storage)

    series = Series.from_arrow(ext_array)

    reported_size = series.size_bytes()

    expected_dtype = DataType.extension(
        uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), ""
    )
    assert series.datatype() == expected_dtype
    # Daft casts the binary storage to large_binary internally, so compare
    # against the buffer size of the cast storage array.
    expected_storage = storage.cast(pa.large_binary())
    assert reported_size == get_total_buffer_size(expected_storage)


@pytest.mark.skipif(
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
Expand Down
21 changes: 21 additions & 0 deletions tests/series/test_take.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,27 @@ def test_series_struct_take() -> None:
assert result.to_pylist() == expected


def test_series_extension_type_take(uuid_ext_type) -> None:
    """take() on an extension Series gathers rows by index; null indices yield nulls."""
    values = [f"{i}".encode() for i in range(6)]
    values[2] = None
    ext_array = pa.ExtensionArray.from_storage(uuid_ext_type, pa.array(values))

    series = Series.from_arrow(ext_array)
    expected_dtype = DataType.extension(
        uuid_ext_type.NAME, DataType.from_arrow_type(uuid_ext_type.storage_type), ""
    )
    assert series.datatype() == expected_dtype

    indices = [2, 0, None, 5]
    taken = series.take(Series.from_pylist(indices))

    # dtype and output length track the index series, not the input series.
    assert taken.datatype() == series.datatype()
    assert len(taken) == 4

    expected = [None if i is None else values[i] for i in indices]
    assert taken.to_pylist() == expected


@pytest.mark.skipif(
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
Expand Down
65 changes: 53 additions & 12 deletions tests/table/test_from_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,19 @@
}

if ARROW_VERSION >= (12, 0, 0):
ARROW_TYPE_ARRAYS["ext_type"] = pa.FixedShapeTensorArray.from_numpy_ndarray(np.array(PYTHON_TYPE_ARRAYS["tensor"]))
ARROW_ROUNDTRIP_TYPES["ext_type"] = pa.fixed_shape_tensor(pa.int64(), (2, 2))
ARROW_ROUNDTRIP_TYPES["canonical_tensor"] = pa.fixed_shape_tensor(pa.int64(), (2, 2))
ARROW_TYPE_ARRAYS["canonical_tensor"] = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.array(PYTHON_TYPE_ARRAYS["tensor"])
)


def _with_uuid_ext_type(uuid_ext_type) -> tuple[dict, dict]:
    """Return copies of the module-level type/array fixtures with a UUID extension column.

    The new "ext_type" column reuses the existing binary column's data as the
    extension array's storage, so the originals are never mutated.
    """
    roundtrip_types = dict(ARROW_ROUNDTRIP_TYPES)
    type_arrays = dict(ARROW_TYPE_ARRAYS)
    roundtrip_types["ext_type"] = uuid_ext_type
    type_arrays["ext_type"] = pa.ExtensionArray.from_storage(
        uuid_ext_type, ARROW_TYPE_ARRAYS["binary"]
    )
    return roundtrip_types, type_arrays


def test_from_pydict_roundtrip() -> None:
Expand All @@ -141,24 +152,26 @@ def test_from_pydict_roundtrip() -> None:
assert table.to_arrow() == expected_table


def test_from_pydict_arrow_roundtrip() -> None:
table = Table.from_pydict(ARROW_TYPE_ARRAYS)
def test_from_pydict_arrow_roundtrip(uuid_ext_type) -> None:
arrow_roundtrip_types, arrow_type_arrays = _with_uuid_ext_type(uuid_ext_type)
table = Table.from_pydict(arrow_type_arrays)
assert len(table) == 2
assert set(table.column_names()) == set(ARROW_TYPE_ARRAYS.keys())
assert set(table.column_names()) == set(arrow_type_arrays.keys())
for field in table.schema():
assert field.dtype == DataType.from_arrow_type(ARROW_TYPE_ARRAYS[field.name].type)
expected_table = pa.table(ARROW_TYPE_ARRAYS).cast(pa.schema(ARROW_ROUNDTRIP_TYPES))
assert field.dtype == DataType.from_arrow_type(arrow_type_arrays[field.name].type)
expected_table = pa.table(arrow_type_arrays).cast(pa.schema(arrow_roundtrip_types))
assert table.to_arrow() == expected_table


def test_from_arrow_roundtrip() -> None:
pa_table = pa.table(ARROW_TYPE_ARRAYS)
def test_from_arrow_roundtrip(uuid_ext_type) -> None:
arrow_roundtrip_types, arrow_type_arrays = _with_uuid_ext_type(uuid_ext_type)
pa_table = pa.table(arrow_type_arrays)
table = Table.from_arrow(pa_table)
assert len(table) == 2
assert set(table.column_names()) == set(ARROW_TYPE_ARRAYS.keys())
assert set(table.column_names()) == set(arrow_type_arrays.keys())
for field in table.schema():
assert field.dtype == DataType.from_arrow_type(ARROW_TYPE_ARRAYS[field.name].type)
expected_table = pa.table(ARROW_TYPE_ARRAYS).cast(pa.schema(ARROW_ROUNDTRIP_TYPES))
assert field.dtype == DataType.from_arrow_type(arrow_type_arrays[field.name].type)
expected_table = pa.table(arrow_type_arrays).cast(pa.schema(arrow_roundtrip_types))
assert table.to_arrow() == expected_table


Expand Down Expand Up @@ -231,6 +244,20 @@ def test_from_pydict_arrow_struct_array() -> None:
assert daft_table.to_arrow()["a"].combine_chunks() == expected


def test_from_pydict_arrow_extension_array(uuid_ext_type) -> None:
    """Table.from_pydict round-trips an Arrow extension array, nulls included."""
    values = [f"{i}".encode() for i in range(6)]
    values[2] = None
    ext_array = pa.ExtensionArray.from_storage(uuid_ext_type, pa.array(values))

    table = Table.from_pydict({"a": ext_array})

    assert "a" in table.column_names()
    # Although Daft will internally represent the binary storage array as a large_binary array,
    # it should be cast back to the ingress extension type.
    roundtripped = table.to_arrow()["a"].combine_chunks()
    assert roundtripped.type == uuid_ext_type
    assert roundtripped == ext_array


def test_from_pydict_arrow_deeply_nested() -> None:
# Test a struct of lists of struct of lists of strings.
data = [{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}]
Expand Down Expand Up @@ -385,6 +412,20 @@ def test_from_arrow_struct_array() -> None:
assert daft_table.to_arrow()["a"].combine_chunks() == expected


def test_from_arrow_extension_array(uuid_ext_type) -> None:
    """Table.from_arrow round-trips an Arrow extension array, nulls included."""
    values = [f"{i}".encode() for i in range(6)]
    values[2] = None
    ext_array = pa.ExtensionArray.from_storage(uuid_ext_type, pa.array(values))

    table = Table.from_arrow(pa.table({"a": ext_array}))

    assert "a" in table.column_names()
    # Although Daft will internally represent the binary storage array as a large_binary array,
    # it should be cast back to the ingress extension type.
    roundtripped = table.to_arrow()["a"].combine_chunks()
    assert roundtripped.type == uuid_ext_type
    assert roundtripped == ext_array


def test_from_arrow_deeply_nested() -> None:
# Test a struct of lists of struct of lists of strings.
data = [{"a": [{"b": ["foo", "bar"]}]}, {"a": [{"b": ["baz", "quux"]}]}]
Expand Down

0 comments on commit 59b18a7

Please sign in to comment.