Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow writing pa.Table that are either a subset of table schema or in arbitrary order, and support type promotion on write #921

Merged
merged 13 commits into from
Jul 17, 2024
71 changes: 33 additions & 38 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
pre_order_visit,
promote,
prune_columns,
Expand Down Expand Up @@ -1407,7 +1408,7 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)

value_array = self._cast_if_needed(list_type.element_field, value_array)
sungwy marked this conversation as resolved.
Show resolved Hide resolved
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
Expand All @@ -1417,6 +1418,8 @@ def map(
self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
key_result = self._cast_if_needed(map_type.key_field, key_result)
value_result = self._cast_if_needed(map_type.value_field, value_result)
arrow_field = pa.map_(
self._construct_field(map_type.key_field, key_result.type),
self._construct_field(map_type.value_field, value_result.type),
Expand Down Expand Up @@ -1549,9 +1552,16 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc

expected_physical_type = _primitive_to_physical(iceberg_type)
if expected_physical_type != physical_type_string:
raise ValueError(
f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
)
# Allow promotable physical types
# INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
if (physical_type_string == "INT32" and expected_physical_type == "INT64") or (
physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE"
Copy link
Collaborator Author

@sungwy sungwy Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've put in this logic to allow StatsAggregator to collect stats for files that are added through add_files that have file field types that map to broader Iceberg Schema types. This feels overly specific, and I feel as though I am duplicating the type promote mappings in a different format. I'm open to other ideas if we want to keep this check on the parquet physical types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we get more promotions (which is expected in V3 with the upcoming variant type, this might need some reworking), but I think we're good for now 👍

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! There’s unfortunately one more case where we are failing for add_files so will that test case added shortly

):
pass
else:
raise ValueError(
f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
)

self.primitive_type = iceberg_type

Expand Down Expand Up @@ -1998,8 +2008,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema

Fokko marked this conversation as resolved.
Show resolved Hide resolved
table_schema = table_metadata.schema()
# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
Expand All @@ -2011,7 +2020,7 @@ def write_parquet(task: WriteTask) -> DataFile:
batches = [
_to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
file_schema=task.schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
include_field_ids=True,
Expand Down Expand Up @@ -2070,47 +2079,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
return bin_packed_record_batches


def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
def _check_pyarrow_schema_compatible(
requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
) -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.
Check if the `requested_schema` is compatible with `provided_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
name_mapping = requested_schema.name_mapping
sungwy marked this conversation as resolved.
Show resolved Hide resolved
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
provided_schema = pyarrow_to_schema(
provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
raise ValueError(
sungwy marked this conversation as resolved.
Show resolved Hide resolved
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
_check_schema_compatible(requested_schema, provided_schema)


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
Expand All @@ -2124,7 +2116,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
_check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
_check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
Expand Down Expand Up @@ -2205,7 +2197,7 @@ def _dataframe_to_data_files(
Returns:
An iterable that supplies datafiles that represent the table.
"""
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask

counter = counter or itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()
Expand All @@ -2214,13 +2206,16 @@ def _dataframe_to_data_files(
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
name_mapping = table_metadata.schema().name_mapping
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use table's schema.name-mapping.default if set and fallback to schema's name-mapping if not?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question - I'm not sure actually.

When we are writing a dataframe into an Iceberg table, I think we are making the assumption that its names match the current names of the Iceberg table, so I think using the table_metadata.schema().name_mapping is appropriate in supporting that.

table_metadata.name_mapping is used to map the field names of written data files hat do not have field IDs to the Iceberg SChema, so I don't think we need to extend our write APIs to also support mapping other names in the name mapping. Since this will be a new feature, I'm of the opinion that we should leave it out until we can think of valid use cases for that from our user base.

I'm curious to hear what others' thoughts are, and whether anyone has a workflow in mind that would benefit from this change!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great! I initially raised this because we’re assigning field IDs for the input dataframe, which aligns with the general purpose of name mapping - to provide fallback IDs. On second thought, schema.name-mapping.default is more for the read side, so using it here may silently introduce unwanted side effects during write. I agree, let’s hold off on this for a while and wait for more discussions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great 👍 thank you for the review!

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is necessary to ensure that we are comparing the Schema that matches that arrow table's schema versus the Table Schema in order to properly invoke promote on write

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch 👍


if table_metadata.spec().is_unpartitioned():
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)
Expand All @@ -2235,7 +2230,7 @@ def _dataframe_to_data_files(
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
schema=task_schema,
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
Expand Down
100 changes: 100 additions & 0 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
return read_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")


def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
"""
Check if the `provided_schema` is compatible with `requested_schema`.

Both Schemas must have valid IDs and share the same ID for the same field names.

Two schemas are considered compatible when:
1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`

Raises:
ValueError: If the schemas are not compatible.
"""
pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))


class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
provided_schema: Schema

def __init__(self, provided_schema: Schema):
from rich.console import Console
from rich.table import Table as RichTable

self.provided_schema = provided_schema
self.rich_table = RichTable(show_header=True, header_style="bold")
self.rich_table.add_column("")
self.rich_table.add_column("Table field")
self.rich_table.add_column("Dataframe field")
self.console = Console(record=True)

def _is_field_compatible(self, lhs: NestedField) -> bool:
# Validate nullability first.
# An optional field can be missing in the provided schema
# But a required field must exist as a required field
try:
rhs = self.provided_schema.find_field(lhs.field_id)
except ValueError:
if lhs.required:
self.rich_table.add_row("❌", str(lhs), "Missing")
return False
else:
self.rich_table.add_row("✅", str(lhs), "Missing")
return True

if lhs.required and not rhs.required:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False

# Check type compatibility
if lhs.field_type == rhs.field_type:
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
# We only check that the parent node is also of the same type.
# We check the type of the child nodes when we traverse them later.
elif any(
(isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
for container_type in {StructType, MapType, ListType}
):
sungwy marked this conversation as resolved.
Show resolved Hide resolved
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
else:
try:
# If type can be promoted to the requested schema
# it is considered compatible
promote(rhs.field_type, lhs.field_type)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This succeeds for append or overwrite, but will cause the add_files operation to fail even though it passes the schema checks. This is because the StatsAggregator will fail to collect stats because it makes an assertion that the physical type of the file be the same as the expected physical type of the IcebergType.

INT32 != physical_type(LongType())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a test to reproduce this? This is interesting since for Python int and long are both a int in Python.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll write up a test 👍

The comparison isn't between python types, but between parquet physical types: https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/pyarrow.py#L1503-L1507

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could get away with just removing this check, since we are running a comprehensive type compatibility check already?

expected_physical_type = _primitive_to_physical(iceberg_type)
if expected_physical_type != physical_type_string:
raise ValueError(
f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed that this would work @Fokko let me know if we are good to move forward with this change!

self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
except ResolveError:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False

def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
if not (result := struct_result()):
self.console.print(self.rich_table)
raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
sungwy marked this conversation as resolved.
Show resolved Hide resolved
return result

def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
results = [result() for result in field_results]
return all(results)

def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(field) and field_result()

def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(list_type.element_field) and element_result()

def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
return all([
self._is_field_compatible(map_type.key_field),
self._is_field_compatible(map_type.value_field),
key_result(),
value_result(),
])

def primitive(self, primitive: PrimitiveType) -> bool:
return True
15 changes: 10 additions & 5 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
Expand Down Expand Up @@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

Expand All @@ -481,8 +482,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
_check_pyarrow_schema_compatible(
self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

manifest_merge_enabled = PropertyUtil.property_as_bool(
Expand Down Expand Up @@ -528,6 +529,8 @@ def overwrite(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

Expand All @@ -538,8 +541,8 @@ def overwrite(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
_check_pyarrow_schema_compatible(
self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
Expand All @@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
== TableProperties.DELETE_MODE_MERGE_ON_READ
Expand Down
59 changes: 59 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,3 +2506,62 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
)


@pytest.fixture(scope="session")
def table_schema_with_promoted_types() -> Schema:
"""Iceberg table Schema with longs, doubles and uuid in simple and nested types."""
return Schema(
NestedField(field_id=1, name="long", field_type=LongType(), required=False),
NestedField(
field_id=2,
name="list",
field_type=ListType(element_id=4, element_type=LongType(), element_required=False),
required=True,
),
NestedField(
field_id=3,
name="map",
field_type=MapType(
key_id=5,
key_type=StringType(),
value_id=6,
value_type=LongType(),
value_required=False,
),
required=True,
),
NestedField(field_id=7, name="double", field_type=DoubleType(), required=False),
NestedField(field_id=8, name="uuid", field_type=UUIDType(), required=False),
)


@pytest.fixture(scope="session")
def pyarrow_schema_with_promoted_types() -> "pa.Schema":
"""Pyarrow Schema with longs, doubles and uuid in simple and nested types."""
import pyarrow as pa

return pa.schema((
pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long
pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long
pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long
pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double
pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting float to double
))


@pytest.fixture(scope="session")
def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Schema") -> "pa.Table":
"""Pyarrow table with longs, doubles and uuid in simple and nested types."""
import pyarrow as pa

return pa.Table.from_pydict(
{
"long": [1, 9],
"list": [[1, 1], [2, 2]],
"map": [{"a": 1}, {"b": 2}],
"double": [1.1, 9.2],
"uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
},
schema=pyarrow_schema_with_promoted_types,
)
Loading
Loading