Skip to content

Commit

Permalink
Use VisitorWithPartner for name-mapping (#1014)
Browse files Browse the repository at this point in the history
* Use `VisitorWithPartner` for name-mapping

This will correctly handle fields with `.` in the name.

* Fix versions in deprecation

Co-authored-by: Sung Yun <[email protected]>

* Use full path in error

---------

Co-authored-by: Sung Yun <[email protected]>
  • Loading branch information
Fokko and sungwy authored Aug 13, 2024
1 parent f05b1ae commit 5cce906
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 13 deletions.
16 changes: 6 additions & 10 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
visit_with_partner,
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
Expand Down Expand Up @@ -818,14 +818,14 @@ def pyarrow_to_schema(
) -> Schema:
has_ids = visit_pyarrow(schema, _HasIds())
if has_ids:
visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
return visit_pyarrow(schema, _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
elif name_mapping is not None:
visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
schema_without_ids = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
return apply_name_mapping(schema_without_ids, name_mapping)
else:
raise ValueError(
"Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
)
return visit_pyarrow(schema, visitor)


def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
Expand Down Expand Up @@ -1002,17 +1002,13 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
"""Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""

_field_names: List[str]
_name_mapping: Optional[NameMapping]

def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
self._field_names = []
self._name_mapping = name_mapping
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _field_id(self, field: pa.Field) -> int:
if self._name_mapping:
return self._name_mapping.find(*self._field_names).field_id
elif (field_id := _get_field_id(field)) is not None:
if (field_id := _get_field_id(field)) is not None:
return field_id
else:
raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
Expand Down
134 changes: 132 additions & 2 deletions pyiceberg/table/name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@

from pydantic import Field, conlist, field_validator, model_serializer

from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
from pyiceberg.utils.deprecated import deprecated


class MappedField(IcebergBaseModel):
Expand Down Expand Up @@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
def _field_by_name(self) -> Dict[str, MappedField]:
return visit_name_mapping(self, _IndexByName())

@deprecated(
deprecated_in="0.8.0",
removed_in="0.9.0",
help_message="Please use `apply_name_mapping` instead",
)
def find(self, *names: str) -> MappedField:
name = ".".join(names)
try:
Expand Down Expand Up @@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> NameMapping:

def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))


class NameMappingAccessor(PartnerAccessor[MappedField]):
def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
return partner

def field_partner(
self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
) -> Optional[MappedField]:
if partner_struct is not None:
if isinstance(partner_struct, MappedField):
partner_struct = partner_struct.fields

for field in partner_struct:
if field_name in field.names:
return field

return None

def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
if partner_list is not None:
for field in partner_list.fields:
if "element" in field.names:
return field
return None

def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
if partner_map is not None:
for field in partner_map.fields:
if "key" in field.names:
return field
return None

def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
if partner_map is not None:
for field in partner_map.fields:
if "value" in field.names:
return field
return None


class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
current_path: List[str]

def __init__(self) -> None:
# For keeping track where we are in case when a field cannot be found
self.current_path = []

def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
self.current_path.append(field.name)

def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
self.current_path.pop()

def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
self.current_path.append("element")

def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
self.current_path.pop()

def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
self.current_path.append("key")

def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
self.current_path.pop()

def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
self.current_path.append("value")

def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
self.current_path.pop()

def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
return Schema(*struct_result.fields, schema_id=schema.schema_id)

def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
return StructType(*field_results)

def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
if field_partner is None:
raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")

return NestedField(
field_id=field_partner.field_id,
name=field.name,
field_type=field_result,
required=field.required,
doc=field.doc,
initial_default=field.initial_default,
initial_write=field.write_default,
)

def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
if list_partner is None:
raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

element_id = next(field for field in list_partner.fields if "element" in field.names).field_id
return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)

def map(
self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
) -> IcebergType:
if map_partner is None:
raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

key_id = next(field for field in map_partner.fields if "key" in field.names).field_id
value_id = next(field for field in map_partner.fields if "value" in field.names).field_id
return MapType(
key_id=key_id,
key_type=key_result,
value_id=value_id,
value_type=value_result,
value_required=map_type.value_required,
)

def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
if primitive_partner is None:
raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")

return primitive


def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
return visit_with_partner(schema_without_ids, name_mapping, NameMappingProjectionVisitor(), NameMappingAccessor()) # type: ignore
52 changes: 51 additions & 1 deletion tests/table/test_name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from pyiceberg.table.name_mapping import (
MappedField,
NameMapping,
apply_name_mapping,
create_mapping_from_schema,
parse_mapping_from_json,
update_mapping,
)
from pyiceberg.types import NestedField, StringType
from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, MapType, NestedField, StringType, StructType


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -321,3 +322,52 @@ def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
MappedField(field_id=18, names=["add_18"]),
])
assert update_mapping(table_name_mapping_nested, updates, adds) == expected


def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None:
schema_without_ids = Schema(
NestedField(field_id=0, name="foo", field_type=StringType(), required=False),
NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False),
NestedField(
field_id=0,
name="qux",
field_type=ListType(element_id=0, element_type=StringType(), element_required=True),
required=True,
),
NestedField(
field_id=0,
name="quux",
field_type=MapType(
key_id=0,
key_type=StringType(),
value_id=0,
value_type=MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True),
value_required=True,
),
required=True,
),
NestedField(
field_id=0,
name="location",
field_type=ListType(
element_id=0,
element_type=StructType(
NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False),
NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False),
),
element_required=True,
),
required=True,
),
NestedField(
field_id=0,
name="person",
field_type=StructType(
NestedField(field_id=0, name="name", field_type=StringType(), required=False),
NestedField(field_id=0, name="age", field_type=IntegerType(), required=True),
),
required=False,
),
)
assert apply_name_mapping(schema_without_ids, table_name_mapping_nested).fields == table_schema_nested.fields

0 comments on commit 5cce906

Please sign in to comment.