Skip to content

Commit

Permalink
introduce bucket transform
Browse files Browse the repository at this point in the history
  • Loading branch information
sungwy committed Nov 20, 2024
1 parent 93ebd39 commit dd888ec
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
22 changes: 21 additions & 1 deletion pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,27 @@ def __repr__(self) -> str:
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Build a callable that applies this bucket transform to PyArrow data.

    The bucketing itself is delegated to ``pyiceberg_core.transform.bucket``,
    which is called with ``self._num_buckets``.

    Args:
        source: Iceberg type of the source column. It is not consulted by
            this implementation.

    Returns:
        A function that accepts a ``pa.Array`` or ``pa.ChunkedArray`` and
        returns the bucketed values in the same kind of container. The
        function raises ``ValueError`` for any other input type.
    """
    # Resolved at call time rather than at module import time.
    import pyarrow as pa
    from pyiceberg_core import transform as pyiceberg_core_transform

    ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)

    def bucket(array: ArrayLike) -> ArrayLike:
        if isinstance(array, pa.Array):
            return pyiceberg_core_transform.bucket(array, self._num_buckets)
        if isinstance(array, pa.ChunkedArray):
            result_chunks = [
                pyiceberg_core_transform.bucket(chunk, self._num_buckets) for chunk in array.iterchunks()
            ]
            if not result_chunks:
                # pa.chunked_array cannot infer a type from zero chunks, so
                # spell out the int32 type that bucket ordinals use.
                return pa.chunked_array([], type=pa.int32())
            return pa.chunked_array(result_chunks)
        raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

    return bucket

@property
def supports_pyarrow_transform(self) -> bool:
    """Report that this transform provides a PyArrow implementation.

    Returns:
        Always ``True``.
    """
    return True


class TimeResolution(IntEnum):
Expand Down
29 changes: 25 additions & 4 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from uuid import UUID

import mmh3 as mmh3
import pyarrow as pa
import pytest
from pydantic import (
BeforeValidator,
Expand Down Expand Up @@ -112,9 +113,6 @@
timestamptz_to_micros,
)

if TYPE_CHECKING:
import pyarrow as pa


@pytest.mark.parametrize(
"test_input,test_type,expected",
Expand Down Expand Up @@ -1840,3 +1838,26 @@ def test_ymd_pyarrow_transforms(
else:
with pytest.raises(ValueError):
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])


@pytest.mark.parametrize(
    "source_type, input_arr, expected, num_buckets",
    [
        # Plain array input.
        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
        # Chunked input: every chunk is bucketed and the result stays chunked.
        (
            IntegerType(),
            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
            10,
        ),
    ],
)
def test_bucket_pyarrow_transforms(
    source_type: PrimitiveType,
    input_arr: Union[pa.Array, pa.ChunkedArray],
    expected: Union[pa.Array, pa.ChunkedArray],
    num_buckets: int,
) -> None:
    """Check BucketTransform.pyarrow_transform against known bucket ordinals."""
    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
    result = transform.pyarrow_transform(source_type)(input_arr)
    assert result.equals(expected), f"expected {expected}, got {result}"

0 comments on commit dd888ec

Please sign in to comment.