feat: Support Bucket and Truncate transforms on write #1345

Status: Open. Wants to merge 8 commits into base: main.
18 changes: 17 additions & 1 deletion poetry.lock


36 changes: 34 additions & 2 deletions pyiceberg/transforms.py
@@ -85,6 +85,8 @@
if TYPE_CHECKING:
import pyarrow as pa

ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)

S = TypeVar("S")
T = TypeVar("T")

@@ -193,6 +195,24 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

def _pyiceberg_transform_wrapper(
self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
) -> Callable[["ArrayLike"], "ArrayLike"]:
import pyarrow as pa

def _transform(array: "ArrayLike") -> "ArrayLike":
if isinstance(array, pa.Array):
return transform_func(array, *args)
elif isinstance(array, pa.ChunkedArray):
result_chunks = []
for arr in array.iterchunks():
result_chunks.append(transform_func(arr, *args))
return pa.chunked_array(result_chunks)
else:
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

return _transform
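
The wrapper above centralizes the Array/ChunkedArray dispatch so each concrete transform only has to supply the per-chunk function from pyiceberg_core. Below is a standalone sketch of the same dispatch pattern (not part of the diff), using pyarrow.compute.add as a stand-in for the pyiceberg_core transform function:

import pyarrow as pa
import pyarrow.compute as pc

def make_wrapper(func, *args):
    # Return a callable that applies `func` to a pa.Array directly,
    # or chunk-by-chunk to a pa.ChunkedArray, preserving the chunking.
    def _transform(array):
        if isinstance(array, pa.Array):
            return func(array, *args)
        elif isinstance(array, pa.ChunkedArray):
            return pa.chunked_array([func(chunk, *args) for chunk in array.iterchunks()])
        raise ValueError(f"Expected pa.Array or pa.ChunkedArray, got {type(array)}")

    return _transform

add_ten = make_wrapper(pc.add, 10)
print(add_ten(pa.array([1, 2])))                                  # [11, 12]
print(add_ten(pa.chunked_array([pa.array([1]), pa.array([2])])))  # two chunks: [11], [12]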


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -309,7 +329,13 @@ def __repr__(self) -> str:
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
return True
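
For context (not part of the diff), a minimal sketch of how this new code path is exercised, assuming the optional pyiceberg-core package is installed; the expected bucket values come from the unit tests added in tests/test_transforms.py below:

import pyarrow as pa

from pyiceberg.transforms import BucketTransform
from pyiceberg.types import IntegerType

bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
print(bucket(pa.array([1, 2])))                                        # [6, 2]
print(bucket(pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])))  # chunks [6, 2] and [5, 0]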


class TimeResolution(IntEnum):
@@ -827,7 +853,13 @@ def __repr__(self) -> str:
return f"TruncateTransform(width={self._width})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
return True
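
Likewise for TruncateTransform, a minimal sketch assuming pyiceberg-core is installed; the expected values come from the unit tests added below:

import pyarrow as pa

from pyiceberg.transforms import TruncateTransform
from pyiceberg.types import IntegerType, StringType

truncate_str = TruncateTransform(width=3).pyarrow_transform(StringType())
print(truncate_str(pa.array(["hello", "iceberg"])))  # ["hel", "ice"]

truncate_int = TruncateTransform(width=10).pyarrow_transform(IntegerType())
print(truncate_int(pa.array([1, -1])))               # [0, -10]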


@singledispatch
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
sqlalchemy = { version = "^2.0.18", optional = true }
getdaft = { version = ">=0.2.12", optional = true }
cachetools = "^5.5.0"
pyiceberg-core = { version = "^0.4.0", optional = true }

[tool.poetry.group.dev.dependencies]
pytest = "7.4.4"
@@ -827,6 +828,10 @@ ignore_missing_imports = true
module = "daft.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pyiceberg_core.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pyparsing.*"
ignore_missing_imports = true
@@ -887,6 +892,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
sql-sqlite = ["sqlalchemy"]
gcsfs = ["gcsfs"]
rest-sigv4 = ["boto3"]
pyiceberg-core = ["pyiceberg-core"]
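
With the extra defined above, users would opt in to the Rust-backed transforms with something like pip install "pyiceberg[pyiceberg-core]" (extra name taken from this diff; exact packaging and naming are subject to the release).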

[tool.pytest.ini_options]
markers = [
110 changes: 82 additions & 28 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -719,50 +719,104 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> None:
@pytest.mark.parametrize(
"spec",
[
# mixed with non-identity is not supported
(
PartitionSpec(
PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
)
),
# none of non-identity is supported
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
],
)
def test_unsupported_transform(
spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
@pytest.mark.parametrize("format_version", [1, 2])
def test_truncate_transform(
spec: PartitionSpec,
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
format_version: int,
) -> None:
identifier = "default.unsupported_transform"
identifier = "default.truncate_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
schema=TABLE_SCHEMA,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
properties={"format-version": "1"},
)

with pytest.raises(
ValueError,
match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
):
tbl.append(arrow_table_with_null)
assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == 3
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == 3


@pytest.mark.integration
@pytest.mark.parametrize(
"spec, expected_rows",
[
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_bucket_transform(
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
spec: PartitionSpec,
expected_rows: int,
format_version: int,
) -> None:
identifier = "default.bucket_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
)

assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == expected_rows
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == expected_rows
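
For reference (not part of the diff), a hedged end-to-end sketch of the workflow these integration tests cover: creating a table partitioned by a bucket transform and appending PyArrow data to it. The catalog name, table name, and columns are illustrative:

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import BucketTransform
from pyiceberg.types import IntegerType, NestedField, StringType

catalog = load_catalog("default")  # assumes a configured catalog named "default"
schema = Schema(
    NestedField(field_id=1, name="id", field_type=IntegerType(), required=False),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1001, transform=BucketTransform(2), name="id_bucket"),
)

tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)
# Before this PR, appending to a bucket-partitioned table via PyArrow raised
# "Not all partition types are supported for writes..."; with it, the write succeeds.
tbl.append(pa.table({"id": pa.array([1, 2, 3], type=pa.int32()), "name": ["a", "b", "c"]}))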


@pytest.mark.integration
46 changes: 42 additions & 4 deletions tests/test_transforms.py
@@ -18,10 +18,11 @@
# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from uuid import UUID

import mmh3 as mmh3
import pyarrow as pa
import pytest
from pydantic import (
BeforeValidator,
@@ -116,9 +117,6 @@
timestamptz_to_micros,
)

if TYPE_CHECKING:
import pyarrow as pa


@pytest.mark.parametrize(
"test_input,test_type,expected",
@@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms(
else:
with pytest.raises(ValueError):
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])


@pytest.mark.parametrize(
"source_type, input_arr, expected, num_buckets",
[
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
(
IntegerType(),
pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
10,
),
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
],
)
def test_bucket_pyarrow_transforms(
source_type: PrimitiveType,
input_arr: Union[pa.Array, pa.ChunkedArray],
expected: Union[pa.Array, pa.ChunkedArray],
num_buckets: int,
[Review comment on lines +1580 to +1583]

Contributor:
nit: wydt of reordering these for readability? num_buckets, source_type and input_arr are configs of the BucketTransform; expected is the output

Collaborator (Author):
Hmm I think I feel indifferent here - there's something nice about having the input and expected arrays side by side

) -> None:
transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
assert expected == transform.pyarrow_transform(source_type)(input_arr)


@pytest.mark.parametrize(
"source_type, input_arr, expected, width",
[
(StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3),
(IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
],
)
def test_truncate_pyarrow_transforms(
source_type: PrimitiveType,
input_arr: Union[pa.Array, pa.ChunkedArray],
expected: Union[pa.Array, pa.ChunkedArray],
width: int,
) -> None:
transform: Transform[Any, Any] = TruncateTransform(width=width)
assert expected == transform.pyarrow_transform(source_type)(input_arr)