diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 86e0a39713..4ad9309666 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2034,7 +2034,7 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None: """ - Check if the `table_schema` is compatible with `other_schema`. + Check if the `table_schema` is compatible with `other_schema` in terms of the Iceberg Schema representation. The schemas are compatible if: - All fields in `other_schema` are present in `table_schema`. (other_schema <= table_schema) @@ -2043,22 +2043,22 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down Raises: ValueError: If the schemas are not compatible. """ - from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema - name_mapping = table_schema.name_mapping try: - other_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping) + other_schema = pyarrow_to_schema( + other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) except ValueError as e: - other_schema = _pyarrow_to_schema_without_ids(other_schema) + other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) additional_names = set(other_schema.column_names) - set(table_schema.column_names) raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e - missing_table_schema_fields = {field for field in other_schema.fields if field not in table_schema.fields} - required_table_schema_fields = {field for field in table_schema.fields if field.required} - missing_required_fields = {field for field in required_table_schema_fields if field not in other_schema.fields} - if missing_table_schema_fields or missing_required_fields: + fields_missing_from_table = {field for field in other_schema.fields if field not in table_schema.fields} + required_fields_in_table = {field for field in table_schema.fields if field.required} + missing_required_fields_in_other = {field for field in required_fields_in_table if field not in other_schema.fields} + if fields_missing_from_table or missing_required_fields_in_other: from rich.console import Console from rich.table import Table as RichTable @@ -2182,17 +2182,20 @@ def _dataframe_to_data_files( default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) + # projects schema to match the pyarrow table + write_schema = pyarrow_to_schema(df.schema, name_mapping=table_metadata.schema().name_mapping) + if table_metadata.spec().is_unpartitioned(): yield from write_file( io=io, table_metadata=table_metadata, tasks=iter([ - WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=write_schema) for batches in bin_pack_arrow_table(df, target_file_size) ]), ) else: - partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) + partitions = _determine_partitions(spec=table_metadata.spec(), schema=write_schema, arrow_table=df) yield from write_file( io=io, table_metadata=table_metadata, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 3ce38555aa..d0045854b3 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1832,20 +1832,6 @@ def test_schema_downcast(table_schema_simple: Schema) -> None: pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`") -def test_schema_downcast(table_schema_simple: Schema) -> None: - # large_string type is compatible with string type - other_schema = pa.schema(( - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - )) - - try: - _check_schema_compatible(table_schema_simple, other_schema) - except Exception: - pytest.fail("Unexpected Exception raised when calling `_check_schema`") - - def test_partition_for_demo() -> None: test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) test_schema = Schema(