support truncate transforms
sungwy committed Dec 24, 2024
1 parent a4137e0 commit 05c440f
Showing 3 changed files with 128 additions and 44 deletions.
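Before the file-by-file diff, here is a minimal usage sketch (not part of the commit) of the write path this change enables: appending an Arrow table to a table partitioned by a truncate transform. The catalog name, schema, field IDs, and table identifier below are illustrative assumptions.

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import TruncateTransform
from pyiceberg.types import NestedField, StringType

# Assumes a catalog named "default" is configured (e.g. in .pyiceberg.yaml); purely illustrative.
catalog = load_catalog("default")

schema = Schema(NestedField(field_id=1, name="name", field_type=StringType(), required=False))
# Partition on the first 3 characters of "name" via the truncate transform.
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(3), name="name_trunc"))

tbl = catalog.create_table("default.truncate_demo", schema=schema, partition_spec=spec)

# Before this commit the Arrow write path rejected truncate-partitioned tables
# ("Not all partition types are supported for writes..."); with it, the append should succeed.
tbl.append(pa.table({"name": pa.array(["hello", "iceberg", None], type=pa.string())}))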
44 changes: 28 additions & 16 deletions pyiceberg/transforms.py
@@ -85,6 +85,8 @@
if TYPE_CHECKING:
import pyarrow as pa

ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)

S = TypeVar("S")
T = TypeVar("T")

@@ -193,6 +195,24 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

def _pyiceberg_transform_wrapper(
self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
) -> Callable[["ArrayLike"], "ArrayLike"]:
import pyarrow as pa

def _transform(array: "ArrayLike") -> "ArrayLike":
if isinstance(array, pa.Array):
return transform_func(array, *args)
elif isinstance(array, pa.ChunkedArray):
result_chunks = []
for arr in array.iterchunks():
result_chunks.append(transform_func(arr, *args))
return pa.chunked_array(result_chunks)
else:
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

return _transform


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -309,23 +329,9 @@ def __repr__(self) -> str:
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
from pyiceberg_core import transform as pyiceberg_core_transform

ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)

def bucket(array: ArrayLike) -> ArrayLike:
if isinstance(array, pa.Array):
return pyiceberg_core_transform.bucket(array, self._num_buckets)
elif isinstance(array, pa.ChunkedArray):
result_chunks = []
for arr in array.iterchunks():
result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
return pa.chunked_array(result_chunks)
else:
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

return bucket
return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
@@ -847,7 +853,13 @@ def __repr__(self) -> str:
return f"TruncateTransform(width={self._width})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
return True


@singledispatch
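Not part of the diff: a short sketch of how the refactored pyarrow_transform path can be exercised directly, assuming pyiceberg_core is installed. It shows the two dispatch branches of _pyiceberg_transform_wrapper, one for pa.Array and one for pa.ChunkedArray; the expected outputs in the comments follow the unit tests added further down.

import pyarrow as pa

from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

# pa.Array input: the wrapper forwards the whole array to pyiceberg_core in a single call.
truncate = TruncateTransform(3).pyarrow_transform(StringType())
print(truncate(pa.array(["hello", "iceberg"])))  # expected: ["hel", "ice"]

# pa.ChunkedArray input: each chunk is transformed separately and the results are
# reassembled with pa.chunked_array(), preserving the chunked layout.
bucket = BucketTransform(8).pyarrow_transform(IntegerType())
chunks = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
print(bucket(chunks))  # a pa.ChunkedArray of bucket ordinals in [0, 8)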
111 changes: 83 additions & 28 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -719,50 +719,105 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> None:
@pytest.mark.parametrize(
"spec",
[
# mixed with non-identity is not supported
(
PartitionSpec(
PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
)
),
# none of non-identity is supported
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
],
)
def test_unsupported_transform(
spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
@pytest.mark.parametrize("format_version", [1, 2])
def test_truncate_transform(
spec: PartitionSpec,
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
format_version: int,
) -> None:
identifier = "default.unsupported_transform"
identifier = "default.truncate_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
schema=TABLE_SCHEMA,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
properties={"format-version": "1"},
)

with pytest.raises(
ValueError,
match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
):
tbl.append(arrow_table_with_null)
assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == 3
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == 3


@pytest.mark.integration
@pytest.mark.parametrize(
"spec, expected_rows",
[
# none of non-identity is supported
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_bucket_transform(
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
spec: PartitionSpec,
expected_rows: int,
format_version: int,
) -> None:
identifier = "default.bucket_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
)

assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == expected_rows
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == expected_rows


@pytest.mark.integration
17 changes: 17 additions & 0 deletions tests/test_transforms.py
@@ -1584,3 +1584,20 @@ def test_bucket_pyarrow_transforms(
) -> None:
transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
assert expected == transform.pyarrow_transform(source_type)(input_arr)


@pytest.mark.parametrize(
"source_type, input_arr, expected, width",
[
(StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3),
(IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
],
)
def test_truncate_pyarrow_transforms(
source_type: PrimitiveType,
input_arr: Union[pa.Array, pa.ChunkedArray],
expected: Union[pa.Array, pa.ChunkedArray],
width: int,
) -> None:
transform: Transform[Any, Any] = TruncateTransform(width=width)
assert expected == transform.pyarrow_transform(source_type)(input_arr)
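The expected arrays in this new test follow Iceberg's truncate specification: a string keeps its first width characters, and an integer v maps to v - (v mod W) with a non-negative modulo. A small reference sketch of that arithmetic (the production kernel is delegated to pyiceberg_core / iceberg-rust, not reimplemented in Python):

# Reference-only illustration of the truncate semantics checked above.
def truncate_int(v: int, width: int) -> int:
    # Iceberg defines v - (v mod W) where the modulo is always non-negative;
    # Python's % already behaves that way for a positive width.
    return v - (v % width)

def truncate_str(s: str, width: int) -> str:
    return s[:width]  # keep the first `width` characters

assert truncate_str("hello", 3) == "hel" and truncate_str("iceberg", 3) == "ice"
assert truncate_int(1, 10) == 0 and truncate_int(-1, 10) == -10  # matches the expected arrays above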
