Skip to content

Commit

Permalink
Fix setting V1 format version for Non-REST catalogs (apache#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh-jahagirdar authored Feb 12, 2024
1 parent fbeeb9a commit dab5d76
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 4 deletions.
2 changes: 2 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class TableProperties:
METRICS_MODE_COLUMN_CONF_PREFIX = "write.metadata.metrics.column"

DEFAULT_NAME_MAPPING = "schema.name-mapping.default"
FORMAT_VERSION = "format-version"
DEFAULT_FORMAT_VERSION = 2


class PropertyUtil:
Expand Down
28 changes: 25 additions & 3 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
The TableMetadata with the defaults applied.
"""
# When the schema doesn't have an ID
if data.get("schema") and "schema_id" not in data["schema"]:
data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
schema = data.get("schema")
if isinstance(schema, dict):
if "schema_id" not in schema and "schema-id" not in schema:
schema["schema_id"] = DEFAULT_SCHEMA_ID

return data

Expand Down Expand Up @@ -335,7 +337,7 @@ def to_v2(self) -> TableMetadataV2:
metadata["format-version"] = 2
return TableMetadataV2.model_validate(metadata)

format_version: Literal[1] = Field(alias="format-version")
format_version: Literal[1] = Field(alias="format-version", default=1)
"""An integer version number for the format. Currently, this can be 1 or 2
based on the spec. Implementations must throw an exception if a table’s
version is higher than the supported version."""
Expand Down Expand Up @@ -404,13 +406,33 @@ def new_table_metadata(
properties: Properties = EMPTY_DICT,
table_uuid: Optional[uuid.UUID] = None,
) -> TableMetadata:
from pyiceberg.table import TableProperties

fresh_schema = assign_fresh_schema_ids(schema)
fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, schema, fresh_schema)
fresh_sort_order = assign_fresh_sort_order_ids(sort_order, schema, fresh_schema)

if table_uuid is None:
table_uuid = uuid.uuid4()

# Remove format-version so it does not get persisted
format_version = int(properties.pop(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION))
if format_version == 1:
return TableMetadataV1(
location=location,
last_column_id=fresh_schema.highest_field_id,
current_schema_id=fresh_schema.schema_id,
schema=fresh_schema,
partition_spec=[field.model_dump() for field in fresh_partition_spec.fields],
partition_specs=[fresh_partition_spec],
default_spec_id=fresh_partition_spec.spec_id,
sort_orders=[fresh_sort_order],
default_sort_order_id=fresh_sort_order.order_id,
properties=properties,
last_partition_id=fresh_partition_spec.last_assigned_field_id,
table_uuid=table_uuid,
)

return TableMetadataV2(
location=location,
schemas=[fresh_schema],
Expand Down
32 changes: 32 additions & 0 deletions tests/catalog/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,38 @@ def test_create_table_with_database_location(
assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"


@mock_aws
def test_create_v1_table(
    _bucket_initialize: None,
    _glue: boto3.client,
    moto_endpoint_url: str,
    table_schema_nested: Schema,
    database_name: str,
    table_name: str,
) -> None:
    """Create a table with ``format-version`` "1" through the Glue catalog.

    Checks that the returned table reports format version 1 and that the
    Glue table entry carries the expected column list and storage location.
    """
    glue_catalog = GlueCatalog("glue", **{"s3.endpoint": moto_endpoint_url})
    glue_catalog.create_namespace(
        namespace=database_name,
        properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"},
    )

    # The format version is requested through the table properties.
    created = glue_catalog.create_table(
        (database_name, table_name),
        table_schema_nested,
        properties={"format-version": "1"},
    )
    assert created.format_version == 1

    # Inspect what was actually registered in Glue.
    glue_response = _glue.get_table(DatabaseName=database_name, Name=table_name)
    descriptor = glue_response["Table"]["StorageDescriptor"]

    glue_columns = descriptor["Columns"]
    assert len(glue_columns) == len(table_schema_nested.fields)

    expected_first_column = {
        "Name": "foo",
        "Type": "string",
        "Parameters": {
            "iceberg.field.id": "1",
            "iceberg.field.optional": "true",
            "iceberg.field.current": "true",
        },
    }
    assert glue_columns[0] == expected_first_column

    assert descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"


@mock_aws
def test_create_table_with_default_warehouse(
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
Expand Down
55 changes: 54 additions & 1 deletion tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
MetadataLogEntry,
Expand Down Expand Up @@ -295,6 +295,59 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
assert metadata.model_dump() == expected.model_dump()


@patch("time.time", MagicMock(return_value=12345))
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
    """Create a table with ``format-version`` "1" through the Hive catalog.

    The metastore client is replaced by a mock. The metadata file written by
    ``create_table`` is read back, parsed, and compared field-for-field with
    a hand-built ``TableMetadataV1``.
    """
    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

    # Stub the metastore client used inside the catalog's ``with`` blocks.
    catalog._client = MagicMock()
    mocked_client = catalog._client.__enter__()
    mocked_client.get_database.return_value = hive_database
    mocked_client.get_table.return_value = hive_table
    mocked_client.create_table.return_value = None

    catalog.create_table(
        ("default", "table"),
        schema=table_schema_simple,
        properties={"owner": "javaberg", "format-version": "1"},
    )

    # The Hive table handed to the client records where the metadata was written.
    submitted_table: HiveTable = mocked_client.create_table.call_args[0][0]
    with open(submitted_table.parameters["metadata_location"], encoding=UTF8) as metadata_file:
        raw_metadata = metadata_file.read()
    parsed_metadata = TableMetadataUtil.parse_raw(raw_metadata)

    schema_v1 = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        schema_id=0,
        identifier_field_ids=[2],
    )
    unpartitioned = PartitionSpec()

    # Generated values (location, uuid, timestamp) are copied from the parsed result.
    reference_metadata = TableMetadataV1(
        format_version=1,
        location=parsed_metadata.location,
        table_uuid=parsed_metadata.table_uuid,
        last_updated_ms=parsed_metadata.last_updated_ms,
        last_column_id=3,
        schema=schema_v1,
        schemas=[schema_v1],
        current_schema_id=0,
        partition_spec=[],
        partition_specs=[unpartitioned],
        default_spec_id=0,
        last_partition_id=1000,
        properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
        current_snapshot_id=None,
        snapshots=[],
        snapshot_log=[],
        metadata_log=[],
        sort_orders=[SortOrder(order_id=0)],
        default_sort_order_id=0,
        refs={},
    )

    assert parsed_metadata.model_dump() == reference_metadata.model_dump()


def test_load_table(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

Expand Down
19 changes: 19 additions & 0 deletions tests/catalog/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.sorting import (
Expand Down Expand Up @@ -158,6 +159,24 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
    'catalog',
    [
        lazy_fixture('catalog_memory'),
        lazy_fixture('catalog_sqlite'),
    ],
)
def test_create_v1_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
    """Create a table with ``format-version`` "1" through each SQL catalog.

    Checks the sort order, the reported format version and the partition
    spec of the created table, then drops the table again.
    """
    namespace, _ = random_identifier
    catalog.create_namespace(namespace)

    v1_table = catalog.create_table(
        random_identifier,
        table_schema_nested,
        properties={"format-version": "1"},
    )

    assert v1_table.sort_order().order_id == 0, "Order ID must match"
    assert v1_table.sort_order().is_unsorted is True, "Order must be unsorted"
    assert v1_table.format_version == 1
    assert v1_table.spec() == UNPARTITIONED_PARTITION_SPEC

    catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
'catalog',
[
Expand Down
69 changes: 69 additions & 0 deletions tests/table/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,75 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
]


def test_new_table_metadata_with_explicit_v1_format() -> None:
    """Build metadata via ``new_table_metadata`` with ``format-version`` "1".

    The inputs use arbitrary field, spec and order IDs. The expected
    ``TableMetadataV1`` uses the reassigned IDs and an empty ``properties``
    dict, and the two are compared through their ``model_dump()`` output.
    """
    table_location = "s3://some_v1_location/"

    # Inputs carry deliberately non-canonical IDs.
    input_schema = Schema(
        NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
        schema_id=10,
        identifier_field_ids=[22],
    )
    input_spec = PartitionSpec(
        PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"),
        spec_id=10,
    )
    input_order = SortOrder(
        SortField(
            source_id=10,
            transform=IdentityTransform(),
            direction=SortDirection.ASC,
            null_order=NullOrder.NULLS_LAST,
        ),
        order_id=10,
    )

    produced = new_table_metadata(
        schema=input_schema,
        partition_spec=input_spec,
        sort_order=input_order,
        location=table_location,
        properties={'format-version': "1"},
    )

    # Expected objects use the IDs the test expects after reassignment.
    fresh_schema = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        schema_id=0,
        identifier_field_ids=[2],
    )
    fresh_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar")
    )
    fresh_order = SortOrder(
        SortField(
            source_id=1,
            transform=IdentityTransform(),
            direction=SortDirection.ASC,
            null_order=NullOrder.NULLS_LAST,
        ),
        order_id=1,
    )

    # Generated values (uuid, timestamp) are copied from the produced metadata.
    reference = TableMetadataV1(
        format_version=1,
        location=table_location,
        table_uuid=produced.table_uuid,
        last_updated_ms=produced.last_updated_ms,
        last_column_id=3,
        schemas=[fresh_schema],
        schema_=fresh_schema,
        current_schema_id=0,
        partition_spec=[partition_field.model_dump() for partition_field in fresh_spec.fields],
        partition_specs=[fresh_spec],
        default_spec_id=0,
        last_partition_id=1000,
        properties={},
        current_snapshot_id=None,
        snapshots=[],
        snapshot_log=[],
        metadata_log=[],
        sort_orders=[fresh_order],
        default_sort_order_id=1,
        refs={},
    )

    assert produced.model_dump() == reference.model_dump()


def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
"""Test the exception when trying to load an unknown version"""

Expand Down

0 comments on commit dab5d76

Please sign in to comment.