diff --git a/poetry.lock b/poetry.lock index 6e4f55f39a..cc8b4271e6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3327,6 +3327,21 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyiceberg-core" +version = "0.4.0" +description = "" +optional = true +python-versions = "*" +files = [ + {file = "pyiceberg_core-0.4.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5aec569271c96e18428d542f9b7007117a7232c06017f95cb239d42e952ad3b4"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e74773e58efa4df83aba6f6265cdd41e446fa66fa4e343ca86395fed9f209ae"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7675d21a54bf3753c740d8df78ad7efe33f438096844e479d4f3493f84830925"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7058ad935a40b1838e4cdc5febd768878c1a51f83dca005d5a52a7fa280a2489"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-win_amd64.whl", hash = "sha256:a83eb4c2307ae3dd321a9360828fb043a4add2cc9797bef0bafa20894488fb07"}, + {file = "pyiceberg_core-0.4.0.tar.gz", hash = "sha256:d2e6138707868477b806ed354aee9c476e437913a331cb9ad9ad46b4054cd11f"}, +] + [[package]] name = "pyjwt" version = "2.10.1" @@ -4742,6 +4757,7 @@ glue = ["boto3", "mypy-boto3-glue"] hive = ["thrift"] pandas = ["pandas", "pyarrow"] pyarrow = ["pyarrow"] +pyiceberg-core = ["pyiceberg-core"] ray = ["pandas", "pyarrow", "ray", "ray"] rest-sigv4 = ["boto3"] s3fs = ["s3fs"] @@ -4753,4 +4769,4 @@ zstandard = ["zstandard"] [metadata] lock-version = "2.0" python-versions = "^3.9, !=3.9.7" -content-hash = "2084f03c93f2d1085a5671a171c6cbeb96d9688079270ceca38b0854fe9e0520" +content-hash = "5d0dd91ca2837bd93fe8a2d17b504f992d0c3095a278de43982e89a65c67ee66" diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 84e1c942d3..12f25ed7ac 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -85,6 +85,8 @@ if TYPE_CHECKING: import pyarrow as pa + ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray) + S = TypeVar("S") T = TypeVar("T") @@ -193,6 +195,24 @@ def supports_pyarrow_transform(self) -> bool: @abstractmethod def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... + def _pyiceberg_transform_wrapper( + self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any + ) -> Callable[["ArrayLike"], "ArrayLike"]: + import pyarrow as pa + + def _transform(array: "ArrayLike") -> "ArrayLike": + if isinstance(array, pa.Array): + return transform_func(array, *args) + elif isinstance(array, pa.ChunkedArray): + result_chunks = [] + for arr in array.iterchunks(): + result_chunks.append(transform_func(arr, *args)) + return pa.chunked_array(result_chunks) + else: + raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}") + + return _transform + class BucketTransform(Transform[S, int]): """Base Transform class to transform a value into a bucket partition value. 
@@ -309,7 +329,13 @@ def __repr__(self) -> str: return f"BucketTransform(num_buckets={self._num_buckets})" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + from pyiceberg_core import transform as pyiceberg_core_transform + + return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets) + + @property + def supports_pyarrow_transform(self) -> bool: + return True class TimeResolution(IntEnum): @@ -827,7 +853,13 @@ def __repr__(self) -> str: return f"TruncateTransform(width={self._width})" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + from pyiceberg_core import transform as pyiceberg_core_transform + + return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width) + + @property + def supports_pyarrow_transform(self) -> bool: + return True @singledispatch diff --git a/pyproject.toml b/pyproject.toml index a2737c3f92..715388c290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true } sqlalchemy = { version = "^2.0.18", optional = true } getdaft = { version = ">=0.2.12", optional = true } cachetools = "^5.5.0" +pyiceberg-core = { version = "^0.4.0", optional = true } [tool.poetry.group.dev.dependencies] pytest = "7.4.4" @@ -827,6 +828,10 @@ ignore_missing_imports = true module = "daft.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "pyiceberg_core.*" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "pyparsing.*" ignore_missing_imports = true @@ -887,6 +892,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"] sql-sqlite = ["sqlalchemy"] gcsfs = ["gcsfs"] rest-sigv4 = ["boto3"] +pyiceberg-core = ["pyiceberg-core"] [tool.pytest.ini_options] markers = [ diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index b92c338931..3eb3bd68a8 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -719,50 +719,104 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non @pytest.mark.parametrize( "spec", [ - # mixed with non-identity is not supported - ( - PartitionSpec( - PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"), - PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"), - ) - ), - # none of non-identity is supported - (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))), - (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))), - (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))), - (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))), - (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))), 
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), - (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), ], ) -def test_unsupported_transform( - spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +@pytest.mark.parametrize("format_version", [1, 2]) +def test_truncate_transform( + spec: PartitionSpec, + spark: SparkSession, + session_catalog: Catalog, + arrow_table_with_null: pa.Table, + format_version: int, ) -> None: - identifier = "default.unsupported_transform" + identifier = "default.truncate_transform" try: session_catalog.drop_table(identifier=identifier) except NoSuchTableError: pass - tbl = session_catalog.create_table( + tbl = _create_table( + session_catalog=session_catalog, identifier=identifier, - schema=TABLE_SCHEMA, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], partition_spec=spec, - properties={"format-version": "1"}, ) - with pytest.raises( - ValueError, - match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *", - ): - tbl.append(arrow_table_with_null) + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == 3 + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == 3 + + +@pytest.mark.integration +@pytest.mark.parametrize( + "spec, expected_rows", + [ + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3), + (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_bucket_transform( + spark: SparkSession, + session_catalog: Catalog, + arrow_table_with_null: pa.Table, + spec: PartitionSpec, + expected_rows: int, + format_version: int, +) -> None: + identifier = "default.bucket_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = _create_table( + 
session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=spec, + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == expected_rows + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == expected_rows @pytest.mark.integration diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 7ebab87e3a..2fa459527e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -18,10 +18,11 @@ # pylint: disable=eval-used,protected-access,redefined-outer-name from datetime import date from decimal import Decimal -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import Any, Callable, Optional, Union from uuid import UUID import mmh3 as mmh3 +import pyarrow as pa import pytest from pydantic import ( BeforeValidator, @@ -116,9 +117,6 @@ timestamptz_to_micros, ) -if TYPE_CHECKING: - import pyarrow as pa - @pytest.mark.parametrize( "test_input,test_type,expected", @@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms( else: with pytest.raises(ValueError): transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) + + +@pytest.mark.parametrize( + "source_type, input_arr, expected, num_buckets", + [ + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ( + IntegerType(), + pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]), + pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]), + 10, + ), + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ], +) +def test_bucket_pyarrow_transforms( + source_type: PrimitiveType, + input_arr: Union[pa.Array, pa.ChunkedArray], + expected: Union[pa.Array, pa.ChunkedArray], + num_buckets: int, +) -> None: + transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets) + assert expected == transform.pyarrow_transform(source_type)(input_arr) + + +@pytest.mark.parametrize( + "source_type, input_arr, expected, width", + [ + (StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3), + (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10), + ], +) +def test_truncate_pyarrow_transforms( + source_type: PrimitiveType, + input_arr: Union[pa.Array, pa.ChunkedArray], + expected: Union[pa.Array, pa.ChunkedArray], + width: int, +) -> None: + transform: Transform[Any, Any] = TruncateTransform(width=width) + assert expected == transform.pyarrow_transform(source_type)(input_arr)
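
Not part of the diff above: a minimal usage sketch of the new PyArrow transform path, assuming the optional pyiceberg-core extra introduced by this change is installed. The expected values mirror the parametrized cases in tests/test_transforms.py.

    import pyarrow as pa

    from pyiceberg.transforms import BucketTransform, TruncateTransform
    from pyiceberg.types import IntegerType, StringType

    # Bucket an integer array into 10 buckets; the returned callable accepts
    # both pa.Array and pa.ChunkedArray (chunked input yields chunked output).
    bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
    assert bucket(pa.array([1, 2])) == pa.array([6, 2], type=pa.int32())

    # Truncate a string array to width 3.
    truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
    assert truncate(pa.array(["hello", "iceberg"])) == pa.array(["hel", "ice"])

Without pyiceberg-core installed, the `from pyiceberg_core import transform` import inside pyarrow_transform will fail, so these transforms remain unavailable for writes unless the extra is selected.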