Support Location Providers (#1452)
* Skeletal implementation

* First attempt at hashing locations

* Relocate to table submodule; code and comment improvements

* Add unit tests

* Remove entropy check

* Nit: Prefer `self.table_properties`

* Remove special character testing

* Add integration tests for writes

* Move all `LocationProviders`-related code into locations.py

* Nit: tiny for loop refactor

* Fix typo

* Object storage as default location provider

* Update tests/integration/test_writes/test_partitioned_writes.py

Co-authored-by: Kevin Liu <[email protected]>

* Test entropy in test_object_storage_injects_entropy

* Refactor integration tests to use properties and omit when default once

* Use a different table property for custom location provision

* write.location-provider.py-impl -> write.py-location-provider.impl

* Make lint

* Move location provider loading into `write_file` for back-compat

* Make object storage no longer the default

* Add test case for partitioned paths disabled but with no partition special case

* Moved constants within ObjectStoreLocationProvider

---------

Co-authored-by: Sreesh Maheshwar <[email protected]>
Co-authored-by: Kevin Liu <[email protected]>
3 people authored Jan 10, 2025
1 parent 691740d commit c68b9b1
Showing 6 changed files with 355 additions and 8 deletions.
7 changes: 6 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -136,6 +136,7 @@
visit,
visit_with_partner,
)
from pyiceberg.table.locations import load_location_provider
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
@@ -2305,6 +2306,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
)
location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties)

def write_parquet(task: WriteTask) -> DataFile:
table_schema = table_metadata.schema()
@@ -2327,7 +2329,10 @@ def write_parquet(task: WriteTask) -> DataFile:
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
file_path = f"{table_metadata.location}/data/{task.generate_data_file_path('parquet')}"
file_path = location_provider.new_data_location(
data_file_name=task.generate_data_file_filename("parquet"),
partition_key=task.partition_key,
)
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
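
For context, a rough comparison (not part of the commit) of what the two providers defined in `pyiceberg/table/locations.py` below return for the same file name when `write_file` asks for a data location; the table location, file name, and hash digits are made up:

```python
# Illustrative sketch only: compare the two providers added by this commit.
from pyiceberg.table.locations import ObjectStoreLocationProvider, SimpleLocationProvider

location = "s3://bucket/ns/tbl"  # hypothetical table location
file_name = "00000-0-example.parquet"  # hypothetical data file name

print(SimpleLocationProvider(location, {}).new_data_location(file_name))
# s3://bucket/ns/tbl/data/00000-0-example.parquet

print(ObjectStoreLocationProvider(location, {}).new_data_location(file_name))
# e.g. s3://bucket/ns/tbl/data/0110/1010/1011/00110101/00000-0-example.parquet
#      (a 20-bit murmur3 hash of the file name, split into binary directories)
```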
15 changes: 8 additions & 7 deletions pyiceberg/table/__init__.py
@@ -187,6 +187,14 @@ class TableProperties:
WRITE_PARTITION_SUMMARY_LIMIT = "write.summary.partition-limit"
WRITE_PARTITION_SUMMARY_LIMIT_DEFAULT = 0

WRITE_PY_LOCATION_PROVIDER_IMPL = "write.py-location-provider.impl"

OBJECT_STORE_ENABLED = "write.object-storage.enabled"
OBJECT_STORE_ENABLED_DEFAULT = False

WRITE_OBJECT_STORE_PARTITIONED_PATHS = "write.object-storage.partitioned-paths"
WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT = True

DELETE_MODE = "write.delete.mode"
DELETE_MODE_COPY_ON_WRITE = "copy-on-write"
DELETE_MODE_MERGE_ON_READ = "merge-on-read"
@@ -1613,13 +1621,6 @@ def generate_data_file_filename(self, extension: str) -> str:
# https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
return f"00000-{self.task_id}-{self.write_uuid}.{extension}"

def generate_data_file_path(self, extension: str) -> str:
if self.partition_key:
file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
return file_path
else:
return self.generate_data_file_filename(extension)


@dataclass(frozen=True)
class AddFileTask:
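
As a usage sketch of the new table properties (the catalog name, table identifier, and schema below are hypothetical and assume a configured catalog):

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

catalog = load_catalog("default")  # assumes a catalog named "default" is configured
tbl = catalog.create_table(
    "default.events",  # hypothetical table identifier
    schema=Schema(NestedField(field_id=1, name="name", field_type=StringType(), required=False)),
    properties={
        # Opt in to the hash-based object storage layout (not the default):
        "write.object-storage.enabled": "true",
        # Partition values stay in data paths unless this is set to "false":
        "write.object-storage.partitioned-paths": "false",
    },
)
```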
145 changes: 145 additions & 0 deletions pyiceberg/table/locations.py
@@ -0,0 +1,145 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib
import logging
from abc import ABC, abstractmethod
from typing import Optional

import mmh3

from pyiceberg.partitioning import PartitionKey
from pyiceberg.table import TableProperties
from pyiceberg.typedef import Properties
from pyiceberg.utils.properties import property_as_bool

logger = logging.getLogger(__name__)


class LocationProvider(ABC):
"""A base class for location providers, that provide data file locations for write tasks."""

table_location: str
table_properties: Properties

def __init__(self, table_location: str, table_properties: Properties):
self.table_location = table_location
self.table_properties = table_properties

@abstractmethod
def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
"""Return a fully-qualified data file location for the given filename.
Args:
data_file_name (str): The name of the data file.
partition_key (Optional[PartitionKey]): The data file's partition key. If None, the data is not partitioned.
Returns:
str: A fully-qualified location URI for the data file.
"""


class SimpleLocationProvider(LocationProvider):
def __init__(self, table_location: str, table_properties: Properties):
super().__init__(table_location, table_properties)

def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
prefix = f"{self.table_location}/data"
return f"{prefix}/{partition_key.to_path()}/{data_file_name}" if partition_key else f"{prefix}/{data_file_name}"


class ObjectStoreLocationProvider(LocationProvider):
HASH_BINARY_STRING_BITS = 20
ENTROPY_DIR_LENGTH = 4
ENTROPY_DIR_DEPTH = 3

_include_partition_paths: bool

def __init__(self, table_location: str, table_properties: Properties):
super().__init__(table_location, table_properties)
self._include_partition_paths = property_as_bool(
self.table_properties,
TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS,
TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT,
)

def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
if self._include_partition_paths and partition_key:
return self.new_data_location(f"{partition_key.to_path()}/{data_file_name}")

prefix = f"{self.table_location}/data"
hashed_path = self._compute_hash(data_file_name)

return (
f"{prefix}/{hashed_path}/{data_file_name}"
if self._include_partition_paths
else f"{prefix}/{hashed_path}-{data_file_name}"
)

@staticmethod
def _compute_hash(data_file_name: str) -> str:
# Bitwise AND to combat sign-extension; bitwise OR to preserve leading zeroes that `bin` would otherwise strip.
top_mask = 1 << ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS
hash_code = mmh3.hash(data_file_name) & (top_mask - 1) | top_mask
return ObjectStoreLocationProvider._dirs_from_hash(bin(hash_code)[-ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS :])

@staticmethod
def _dirs_from_hash(file_hash: str) -> str:
"""Divides hash into directories for optimized orphan removal operation using ENTROPY_DIR_DEPTH and ENTROPY_DIR_LENGTH."""
total_entropy_length = ObjectStoreLocationProvider.ENTROPY_DIR_DEPTH * ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH

hash_with_dirs = []
for i in range(0, total_entropy_length, ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH):
hash_with_dirs.append(file_hash[i : i + ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH])

if len(file_hash) > total_entropy_length:
hash_with_dirs.append(file_hash[total_entropy_length:])

return "/".join(hash_with_dirs)


def _import_location_provider(
location_provider_impl: str, table_location: str, table_properties: Properties
) -> Optional[LocationProvider]:
try:
path_parts = location_provider_impl.split(".")
if len(path_parts) < 2:
raise ValueError(
f"{TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL} should be full path (module.CustomLocationProvider), got: {location_provider_impl}"
)
module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1]
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_(table_location, table_properties)
except ModuleNotFoundError:
logger.warning("Could not initialize LocationProvider: %s", location_provider_impl)
return None


def load_location_provider(table_location: str, table_properties: Properties) -> LocationProvider:
table_location = table_location.rstrip("/")

if location_provider_impl := table_properties.get(TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL):
if location_provider := _import_location_provider(location_provider_impl, table_location, table_properties):
logger.info("Loaded LocationProvider: %s", location_provider_impl)
return location_provider
else:
raise ValueError(f"Could not initialize LocationProvider: {location_provider_impl}")

if property_as_bool(table_properties, TableProperties.OBJECT_STORE_ENABLED, TableProperties.OBJECT_STORE_ENABLED_DEFAULT):
return ObjectStoreLocationProvider(table_location, table_properties)
else:
return SimpleLocationProvider(table_location, table_properties)
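
A minimal sketch (not from this commit) of a custom provider that `write.py-location-provider.impl` could point at; the class name and path prefix are hypothetical, while the `LocationProvider` interface and the loading behaviour are as defined above:

```python
from typing import Optional

from pyiceberg.partitioning import PartitionKey
from pyiceberg.table.locations import LocationProvider


class PrefixedLocationProvider(LocationProvider):
    """Hypothetical provider that nests data files under a custom prefix."""

    def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
        suffix = f"{partition_key.to_path()}/{data_file_name}" if partition_key else data_file_name
        return f"{self.table_location}/custom-data/{suffix}"


provider = PrefixedLocationProvider("s3://bucket/ns/tbl", {})
print(provider.new_data_location("00000-0-example.parquet"))
# s3://bucket/ns/tbl/custom-data/00000-0-example.parquet
```

Setting `write.py-location-provider.impl` to this class's import path (for example `my_module.PrefixedLocationProvider`) would make `load_location_provider` instantiate and return it; an unimportable path raises `ValueError`, per the code above.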
39 changes: 39 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -28,6 +28,7 @@
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
@@ -280,6 +281,44 @@ def test_query_filter_v1_v2_append_null(
assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}"


@pytest.mark.integration
@pytest.mark.parametrize(
"part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"]
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_object_storage_location_provider_excludes_partition_path(
session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int
) -> None:
nested_field = TABLE_SCHEMA.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col)
)

tbl = _create_table(
session_catalog=session_catalog,
identifier=f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}",
# write.object-storage.partitioned-paths defaults to True
properties={"format-version": str(format_version), TableProperties.OBJECT_STORE_ENABLED: True},
data=[arrow_table_with_null],
partition_spec=partition_spec,
)

original_paths = tbl.inspect.data_files().to_pydict()["file_path"]
assert len(original_paths) == 3

# Update props to exclude partitioned paths and append data
with tbl.transaction() as tx:
tx.set_properties({TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS: False})
tbl.append(arrow_table_with_null)

added_paths = set(tbl.inspect.data_files().to_pydict()["file_path"]) - set(original_paths)
assert len(added_paths) == 3

# All paths before the props update should contain the partition, while all paths after should not
assert all(f"{part_col}=" in path for path in original_paths)
assert all(f"{part_col}=" not in path for path in added_paths)


@pytest.mark.integration
@pytest.mark.parametrize(
"spec",
27 changes: 27 additions & 0 deletions tests/integration/test_writes/test_writes.py
@@ -285,6 +285,33 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 1, 0, 0, 0]


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_object_storage_data_files(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
tbl = _create_table(
session_catalog=session_catalog,
identifier="default.object_stored",
properties={"format-version": format_version, TableProperties.OBJECT_STORE_ENABLED: True},
data=[arrow_table_with_null],
)
tbl.append(arrow_table_with_null)

paths = tbl.inspect.data_files().to_pydict()["file_path"]
assert len(paths) == 2

for location in paths:
assert location.startswith("s3://warehouse/default/object_stored/data/")
parts = location.split("/")
assert len(parts) == 11

# Entropy binary directories should have been injected
for dir_name in parts[6:10]:
assert dir_name
assert all(c in "01" for c in dir_name)


@pytest.mark.integration
def test_python_writes_with_spark_snapshot_reads(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
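
To make the index arithmetic in `test_object_storage_data_files` concrete, here is a hypothetical path of the shape the test expects, split the same way (the hash digits are made up):

```python
# Hypothetical path of the shape asserted above.
path = "s3://warehouse/default/object_stored/data/0110/1010/1011/00111101/00000-0-example.parquet"
parts = path.split("/")

assert len(parts) == 11
assert parts[:6] == ["s3:", "", "warehouse", "default", "object_stored", "data"]
# parts[6:10] are the four binary entropy components: three 4-bit directories
# plus the remaining 8 bits of the 20-bit hash; parts[10] is the file name.
assert all(c in "01" for dir_name in parts[6:10] for c in dir_name)
```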