diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 56f2242514..4ad9309666 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2034,16 +2034,18 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
 
 def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
     """
-    Check if the `table_schema` is compatible with `other_schema`.
+    Check if the `table_schema` is compatible with `other_schema` in terms of the Iceberg Schema representation.
 
-    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+    The schemas are compatible if:
+    - All fields in `other_schema` are present in `table_schema` (`other_schema <= table_schema`).
+    - All required fields in `table_schema` are present in `other_schema`.
 
     Raises:
         ValueError: If the schemas are not compatible.
     """
     name_mapping = table_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(
+        provided_schema = pyarrow_to_schema(
             other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
     except ValueError as e:
@@ -2053,7 +2055,10 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
 
-    if table_schema.as_struct() != task_schema.as_struct():
+    fields_missing_from_table = {field for field in provided_schema.fields if field not in table_schema.fields}
+    required_fields_in_table = {field for field in table_schema.fields if field.required}
+    missing_required_fields_in_other = {field for field in required_fields_in_table if field not in provided_schema.fields}
+    if fields_missing_from_table or missing_required_fields_in_other:
         from rich.console import Console
         from rich.table import Table as RichTable
 
@@ -2066,7 +2071,7 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
 
         for lhs in table_schema.fields:
             try:
-                rhs = task_schema.find_field(lhs.field_id)
+                rhs = provided_schema.find_field(lhs.field_id)
                 rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
             except ValueError:
                 rich_table.add_row("❌", str(lhs), "Missing")
@@ -2177,17 +2182,20 @@ def _dataframe_to_data_files(
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
    )
 
+    # Resolve the PyArrow schema against the table's name mapping to derive the Iceberg write schema
+    write_schema = pyarrow_to_schema(df.schema, name_mapping=table_metadata.schema().name_mapping)
+
     if table_metadata.spec().is_unpartitioned():
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
             tasks=iter([
-                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=write_schema)
                 for batches in bin_pack_arrow_table(df, target_file_size)
             ]),
         )
     else:
-        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
+        partitions = _determine_partitions(spec=table_metadata.spec(), schema=write_schema, arrow_table=df)
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 62440c4773..b43dc3206b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
 
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index af626718f7..2fd5a8d4c7 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -963,9 +963,10 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
     assert len(tbl.scan().to_arrow()) == 22
 
 
+@pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    identifier = "default.table_append_subset_of_schema"
+def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_subset_of_schema"
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
     arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
     assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
@@ -975,6 +976,23 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
     assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_out_of_order_schema"
+    # rotate the schema fields by 1
+    fields = list(arrow_table_with_null.schema)
+    rotated_fields = fields[1:] + fields[:1]
+    rotated_schema = pa.schema(rotated_fields)
+    assert arrow_table_with_null.schema != rotated_schema
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)
+
+    tbl.overwrite(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+    # overwrite and then append should produce twice the data
+    assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 326eeff195..d0045854b3 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -1799,6 +1799,25 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
         _check_schema_compatible(table_schema_simple, other_schema)
 
 
+def test_schema_compatible(table_schema_simple: Schema) -> None:
+    try:
+        _check_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
+def test_schema_projection(table_schema_simple: Schema) -> None:
+    # remove optional `baz` field from `table_schema_simple`
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+    ))
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
 def test_schema_downcast(table_schema_simple: Schema) -> None:
     # large_string type is compatible with string type
     other_schema = pa.schema((
@@ -1810,7 +1829,7 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
     try:
         _check_schema_compatible(table_schema_simple, other_schema)
     except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
 
 
 def test_partition_for_demo() -> None:
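Note on the new check's behavior: a write schema that is a pure projection of the table schema passes, while one that drops a required field is rejected. Below is a minimal sketch of both cases, assuming the usual shape of the `table_schema_simple` fixture (optional `foo: string`, required `bar: int`, optional `baz: boolean`) and the `ValueError` documented in the function's docstring:

```python
import pyarrow as pa
import pytest

from pyiceberg.io.pyarrow import _check_schema_compatible
from pyiceberg.schema import Schema
from pyiceberg.types import BooleanType, IntegerType, NestedField, StringType

# Mirror of the `table_schema_simple` fixture: `bar` is the only required field.
table_schema = Schema(
    NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
    NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
    schema_id=1,
    identifier_field_ids=[2],
)

# Dropping the optional `baz` column is a valid projection and passes the check.
_check_schema_compatible(
    table_schema,
    pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
    )),
)

# Dropping the required `bar` column violates the second compatibility rule,
# so the check raises ValueError.
with pytest.raises(ValueError):
    _check_schema_compatible(
        table_schema,
        pa.schema((pa.field("foo", pa.string(), nullable=True),)),
    )
```

Since `_dataframe_to_data_files` now derives the write schema from the dataframe itself, the explicit `df.cast(...)` previously done in `append` and `overwrite` is no longer needed.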