From dd888ec04a12b56505fe122acdf693a1d9c2a31f Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Wed, 20 Nov 2024 02:35:53 +0000
Subject: [PATCH 1/7] introduce bucket transform

---
 pyiceberg/transforms.py  | 22 +++++++++++++++++++++-
 tests/test_transforms.py | 29 +++++++++++++++++++++++++----
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 1056fa525b..e89142d780 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -309,7 +309,27 @@ def __repr__(self) -> str:
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        import pyarrow as pa
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
+        def bucket(array: ArrayLike) -> ArrayLike:
+            if isinstance(array, pa.Array):
+                return pyiceberg_core_transform.bucket(array, self._num_buckets)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return bucket
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 class TimeResolution(IntEnum):
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 3a9ffd6009..7de44de2c8 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -17,10 +17,11 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 from uuid import UUID
 
 import mmh3 as mmh3
+import pyarrow as pa
 import pytest
 from pydantic import (
     BeforeValidator,
@@ -112,9 +113,6 @@
     timestamptz_to_micros,
 )
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",
@@ -1840,3 +1838,26 @@ def test_ymd_pyarrow_transforms(
     else:
         with pytest.raises(ValueError):
             transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, num_buckets",
+    [
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+        (
+            IntegerType(),
+            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
+            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
+            10,
+        ),
+        (LongType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+    ],
+)
+def test_bucket_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    num_buckets: int,
+) -> None:
+    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)

From bd80f39b3b426aefdd2a2866f11a7f86a5dcb6a6 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Tue, 24 Dec 2024 18:13:52 +0000
Subject: [PATCH 2/7] include pyiceberg-core

---
 poetry.lock    | 20 ++++++++++++++++++--
 pyproject.toml |  6 ++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 048578f3aa..80741e078c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
 
 [[package]]
 name = "adlfs"
@@ -3288,6 +3288,21 @@ files = [
 [package.extras]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pyiceberg-core"
+version = "0.4.0"
+description = ""
+optional = true
+python-versions = "*"
+files = [
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5aec569271c96e18428d542f9b7007117a7232c06017f95cb239d42e952ad3b4"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e74773e58efa4df83aba6f6265cdd41e446fa66fa4e343ca86395fed9f209ae"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7675d21a54bf3753c740d8df78ad7efe33f438096844e479d4f3493f84830925"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7058ad935a40b1838e4cdc5febd768878c1a51f83dca005d5a52a7fa280a2489"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-win_amd64.whl", hash = "sha256:a83eb4c2307ae3dd321a9360828fb043a4add2cc9797bef0bafa20894488fb07"},
+    {file = "pyiceberg_core-0.4.0.tar.gz", hash = "sha256:d2e6138707868477b806ed354aee9c476e437913a331cb9ad9ad46b4054cd11f"},
+]
+
 [[package]]
 name = "pyjwt"
 version = "2.9.0"
@@ -4658,6 +4673,7 @@ glue = ["boto3", "mypy-boto3-glue"]
 hive = ["thrift"]
 pandas = ["pandas", "pyarrow"]
 pyarrow = ["pyarrow"]
+pyiceberg-core = ["pyiceberg-core"]
 ray = ["pandas", "pyarrow", "ray", "ray"]
 s3fs = ["s3fs"]
 snappy = ["python-snappy"]
@@ -4668,4 +4684,4 @@ zstandard = ["zstandard"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9, <3.13, !=3.9.7"
-content-hash = "faf7cc64ff950544f90d04eea2d54bfcc118799f2c376aa43149a1f91637033a"
+content-hash = "ba83ed937caf64aee5f9f449b3630965d8c805f75487321ead7cbed40df2c91a"
diff --git a/pyproject.toml b/pyproject.toml
index 09461ccd2f..97aca6c128 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
 sqlalchemy = { version = "^2.0.18", optional = true }
 getdaft = { version = ">=0.2.12", optional = true }
 cachetools = "^5.5.0"
+pyiceberg-core = { version = "^0.4.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "7.4.4"
@@ -827,6 +828,10 @@ ignore_missing_imports = true
 module = "daft.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "pyiceberg_core.*"
+ignore_missing_imports = true
+
 [[tool.mypy.overrides]]
 module = "pyparsing.*"
 ignore_missing_imports = true
@@ -886,6 +891,7 @@ zstandard = ["zstandard"]
 sql-postgres = ["sqlalchemy", "psycopg2-binary"]
 sql-sqlite = ["sqlalchemy"]
 gcsfs = ["gcsfs"]
+pyiceberg-core = ["pyiceberg-core"]
 
 [tool.pytest.ini_options]
 markers = [

From 27ade9a0d7c80d11b0105640598800d57e8147eb Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Wed, 20 Nov 2024 02:35:53 +0000
Subject: [PATCH 3/7] introduce bucket transform

---
 pyiceberg/transforms.py  | 22 +++++++++++++++++++++-
 tests/test_transforms.py | 29 +++++++++++++++++++++++++----
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 84e1c942d3..44d32c9449 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -309,7 +309,27 @@ def __repr__(self) -> str:
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        import pyarrow as pa
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
+        def bucket(array: ArrayLike) -> ArrayLike:
+            if isinstance(array, pa.Array):
+                return pyiceberg_core_transform.bucket(array, self._num_buckets)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return bucket
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 class TimeResolution(IntEnum):
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 7ebab87e3a..50ed775272 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -18,10 +18,11 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 from uuid import UUID
 
 import mmh3 as mmh3
+import pyarrow as pa
 import pytest
 from pydantic import (
     BeforeValidator,
@@ -116,9 +117,6 @@
     timestamptz_to_micros,
 )
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",
@@ -1563,3 +1561,26 @@ def test_ymd_pyarrow_transforms(
     else:
         with pytest.raises(ValueError):
             transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, num_buckets",
+    [
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+        (
+            IntegerType(),
+            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
+            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
+            10,
+        ),
+        (LongType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+    ],
+)
+def test_bucket_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    num_buckets: int,
+) -> None:
+    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)

From fcd654c03cfe23ddcb16ba18b79e5ae931803ee0 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Tue, 24 Dec 2024 18:13:52 +0000
Subject: [PATCH 4/7] include pyiceberg-core

---
 poetry.lock    | 16 ++++++++++++++++
 pyproject.toml |  6 ++++++
 2 files changed, 22 insertions(+)

diff --git a/poetry.lock b/poetry.lock
index 6e4f55f39a..ba3bbddfb5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3327,6 +3327,21 @@ files = [
 [package.extras]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pyiceberg-core"
+version = "0.4.0"
+description = ""
+optional = true
+python-versions = "*"
+files = [
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5aec569271c96e18428d542f9b7007117a7232c06017f95cb239d42e952ad3b4"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e74773e58efa4df83aba6f6265cdd41e446fa66fa4e343ca86395fed9f209ae"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7675d21a54bf3753c740d8df78ad7efe33f438096844e479d4f3493f84830925"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7058ad935a40b1838e4cdc5febd768878c1a51f83dca005d5a52a7fa280a2489"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-win_amd64.whl", hash = "sha256:a83eb4c2307ae3dd321a9360828fb043a4add2cc9797bef0bafa20894488fb07"},
+    {file = "pyiceberg_core-0.4.0.tar.gz", hash = "sha256:d2e6138707868477b806ed354aee9c476e437913a331cb9ad9ad46b4054cd11f"},
+]
+
 [[package]]
 name = "pyjwt"
 version = "2.10.1"
@@ -4742,6 +4757,7 @@ glue = ["boto3", "mypy-boto3-glue"]
 hive = ["thrift"]
 pandas = ["pandas", "pyarrow"]
 pyarrow = ["pyarrow"]
+pyiceberg-core = ["pyiceberg-core"]
 ray = ["pandas", "pyarrow", "ray", "ray"]
 rest-sigv4 = ["boto3"]
 s3fs = ["s3fs"]
diff --git a/pyproject.toml b/pyproject.toml
index a2737c3f92..715388c290 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
 sqlalchemy = { version = "^2.0.18", optional = true }
 getdaft = { version = ">=0.2.12", optional = true }
 cachetools = "^5.5.0"
+pyiceberg-core = { version = "^0.4.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "7.4.4"
@@ -827,6 +828,10 @@ ignore_missing_imports = true
 module = "daft.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "pyiceberg_core.*"
+ignore_missing_imports = true
+
 [[tool.mypy.overrides]]
 module = "pyparsing.*"
 ignore_missing_imports = true
@@ -887,6 +892,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
 sql-sqlite = ["sqlalchemy"]
 gcsfs = ["gcsfs"]
 rest-sigv4 = ["boto3"]
+pyiceberg-core = ["pyiceberg-core"]
 
 [tool.pytest.ini_options]
 markers = [

From a0a9c589e2c55ca3f7360b769ff30b6bc33eb2d4 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Tue, 24 Dec 2024 18:33:32 +0000
Subject: [PATCH 5/7] resolve poetry conflict

---
 poetry.lock | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index ba3bbddfb5..7f5e788e3d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "adlfs"
@@ -4769,4 +4769,4 @@ zstandard = ["zstandard"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9, !=3.9.7"
-content-hash = "2084f03c93f2d1085a5671a171c6cbeb96d9688079270ceca38b0854fe9e0520"
+content-hash = "5d0dd91ca2837bd93fe8a2d17b504f992d0c3095a278de43982e89a65c67ee66"

From 05c440f545568ba575f4ac15b379c9e3e023bb56 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Tue, 24 Dec 2024 20:45:18 +0000
Subject: [PATCH 6/7] support truncate transforms

---
 pyiceberg/transforms.py                    |  44 ++++---
 .../test_writes/test_partitioned_writes.py | 111 +++++++++++++-----
 tests/test_transforms.py                   |  17 +++
 3 files changed, 128 insertions(+), 44 deletions(-)

diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 44d32c9449..12f25ed7ac 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -85,6 +85,8 @@
 if TYPE_CHECKING:
     import pyarrow as pa
 
+    ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -193,6 +195,24 @@ def supports_pyarrow_transform(self) -> bool:
     @abstractmethod
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
 
+    def _pyiceberg_transform_wrapper(
+        self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
+    ) -> Callable[["ArrayLike"], "ArrayLike"]:
+        import pyarrow as pa
+
+        def _transform(array: "ArrayLike") -> "ArrayLike":
+            if isinstance(array, pa.Array):
+                return transform_func(array, *args)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(transform_func(arr, *args))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return _transform
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -309,23 +329,9 @@ def __repr__(self) -> str:
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        import pyarrow as pa
         from pyiceberg_core import transform as pyiceberg_core_transform
 
-        ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
-
-        def bucket(array: ArrayLike) -> ArrayLike:
-            if isinstance(array, pa.Array):
-                return pyiceberg_core_transform.bucket(array, self._num_buckets)
-            elif isinstance(array, pa.ChunkedArray):
-                result_chunks = []
-                for arr in array.iterchunks():
-                    result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
-                return pa.chunked_array(result_chunks)
-            else:
-                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
-
-        return bucket
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
 
     @property
     def supports_pyarrow_transform(self) -> bool:
@@ -847,7 +853,13 @@ def __repr__(self) -> str:
         return f"TruncateTransform(width={self._width})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 @singledispatch
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index b92c338931..16b668fd85 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -719,50 +719,105 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
 @pytest.mark.parametrize(
     "spec",
     [
-        # mixed with non-identity is not supported
-        (
-            PartitionSpec(
-                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
-                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
-            )
-        ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
         (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
         (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
         (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )
-def test_unsupported_transform(
-    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
 ) -> None:
-    identifier = "default.unsupported_transform"
+    identifier = "default.truncate_transform"
 
     try:
         session_catalog.drop_table(identifier=identifier)
     except NoSuchTableError:
         pass
 
-    tbl = session_catalog.create_table(
+    tbl = _create_table(
+        session_catalog=session_catalog,
         identifier=identifier,
-        schema=TABLE_SCHEMA,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
         partition_spec=spec,
-        properties={"format-version": "1"},
     )
 
-    with pytest.raises(
-        ValueError,
-        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
-    ):
-        tbl.append(arrow_table_with_null)
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
+        # none of non-identity is supported
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
 
 
 @pytest.mark.integration
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 50ed775272..2fa459527e 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -1584,3 +1584,20 @@
 ) -> None:
     transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
     assert expected == transform.pyarrow_transform(source_type)(input_arr)
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, width",
+    [
+        (StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3),
+        (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
+    ],
+)
+def test_truncate_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    width: int,
+) -> None:
+    transform: Transform[Any, Any] = TruncateTransform(width=width)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)

From 7079265176893450331cd719c421935b61b28fe3 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Tue, 24 Dec 2024 19:21:34 -0500
Subject: [PATCH 7/7] Remove stale comment

---
 tests/integration/test_writes/test_partitioned_writes.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 16b668fd85..3eb3bd68a8 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -768,7 +768,6 @@
 @pytest.mark.parametrize(
     "spec, expected_rows",
     [
-        # none of non-identity is supported
         (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
         (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
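
A minimal usage sketch of the API this series adds, for reviewers (not part of the patches; assumes the optional pyiceberg-core extra from PATCH 2/7 is installed). The expected outputs in the comments are the values asserted by test_bucket_pyarrow_transforms and test_truncate_pyarrow_transforms above:

    import pyarrow as pa

    from pyiceberg.transforms import BucketTransform, TruncateTransform
    from pyiceberg.types import IntegerType, StringType

    # Bucket ids come from the Rust-backed pyiceberg_core.transform.bucket;
    # _pyiceberg_transform_wrapper applies it per chunk for ChunkedArray input.
    bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
    print(bucket(pa.array([1, 2])))  # [6, 2]
    print(bucket(pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])))  # [[6, 2], [5, 0]]

    # Truncate delegates to pyiceberg_core.transform.truncate with the configured width.
    truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
    print(truncate(pa.array(["hello", "iceberg"])))  # ["hel", "ice"]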