Skip to content

Commit

Permalink
StructType field optional by default (#592)
Browse files Browse the repository at this point in the history
* StructType field optional by default

* changes in Integration tests to accomodate required field changes

* mkdoc changes to accomodate required field changes

* rebase from main

* fix additional test cases

* Remove unrelated change

* Remove unrelated change

---------

Co-authored-by: Fokko Driesprong <[email protected]>
  • Loading branch information
MehulBatra and Fokko authored Apr 17, 2024
1 parent 87656fb commit aa850ef
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 41 deletions.
8 changes: 4 additions & 4 deletions pyiceberg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
... NestedField(1, "required_field", StringType(), True),
... NestedField(2, "optional_field", IntegerType())
... ))
'struct<1: required_field: optional string, 2: optional_field: optional int>'
'struct<1: required_field: required string, 2: optional_field: optional int>'
Notes:
- https://iceberg.apache.org/#spec/#primitive-types
- https://iceberg.apache.org/spec/#primitive-types
"""

from __future__ import annotations
Expand Down Expand Up @@ -289,7 +289,7 @@ class NestedField(IcebergType):
field_id: int = Field(alias="id")
name: str = Field()
field_type: SerializeAsAny[IcebergType] = Field(alias="type")
required: bool = Field(default=True)
required: bool = Field(default=False)
doc: Optional[str] = Field(default=None, repr=False)
initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False)
write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore
Expand All @@ -299,7 +299,7 @@ def __init__(
field_id: Optional[int] = None,
name: Optional[str] = None,
field_type: Optional[IcebergType] = None,
required: bool = True,
required: bool = False,
doc: Optional[str] = None,
initial_default: Optional[Any] = None,
write_default: Optional[L] = None,
Expand Down
25 changes: 14 additions & 11 deletions tests/avro/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FloatReader,
IntegerReader,
MapReader,
OptionReader,
StringReader,
StructReader,
)
Expand Down Expand Up @@ -91,28 +92,30 @@ def test_resolver() -> None:
"location",
location_struct,
),
NestedField(1, "id", LongType()),
NestedField(1, "id", LongType(), required=False),
NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())),
schema_id=1,
)
read_tree = resolve_reader(write_schema, read_schema)

assert read_tree == StructReader(
(
(1, IntegerReader()),
(None, StringReader()),
(1, OptionReader(option=IntegerReader())),
(None, OptionReader(option=StringReader())),
(
0,
StructReader(
(
(0, DoubleReader()),
(1, DoubleReader()),
OptionReader(
option=StructReader(
(
(0, OptionReader(option=DoubleReader())),
(1, OptionReader(option=DoubleReader())),
),
Record,
location_struct,
),
Record,
location_struct,
),
),
(2, MapReader(StringReader(), StringReader())),
(2, OptionReader(option=MapReader(StringReader(), StringReader()))),
),
Record,
read_schema.as_struct(),
Expand Down Expand Up @@ -309,7 +312,7 @@ def test_resolver_initial_value() -> None:

assert resolve_reader(write_schema, read_schema) == StructReader(
(
(None, StringReader()), # The one we skip
(None, OptionReader(option=StringReader())), # The one we skip
(0, DefaultReader("vo")),
),
Record,
Expand Down
6 changes: 3 additions & 3 deletions tests/catalog/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ def catalog(tmp_path: PosixPath) -> InMemoryCatalog:
TEST_TABLE_NAMESPACE = ("com", "organization", "department")
TEST_TABLE_NAME = "my_table"
TEST_TABLE_SCHEMA = Schema(
NestedField(1, "x", LongType()),
NestedField(2, "y", LongType(), doc="comment"),
NestedField(3, "z", LongType()),
NestedField(1, "x", LongType(), required=True),
NestedField(2, "y", LongType(), doc="comment", required=True),
NestedField(3, "z", LongType(), required=True),
)
TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x", transform=IdentityTransform(), source_id=1, field_id=1000))
TEST_TABLE_PROPERTIES = {"key1": "value1", "key2": "value2"}
Expand Down
2 changes: 1 addition & 1 deletion tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def test_create_table(
NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
),
required=True,
required=False,
),
schema_id=0,
identifier_field_ids=[2],
Expand Down
6 changes: 3 additions & 3 deletions tests/cli/test_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def mock_datetime_now(monkeypatch: pytest.MonkeyPatch) -> None:
TEST_NAMESPACE_PROPERTIES = {"location": "s3://warehouse/database/location"}
TEST_TABLE_NAME = "my_table"
TEST_TABLE_SCHEMA = Schema(
NestedField(1, "x", LongType()),
NestedField(2, "y", LongType(), doc="comment"),
NestedField(3, "z", LongType()),
NestedField(1, "x", LongType(), required=True),
NestedField(2, "y", LongType(), doc="comment", required=True),
NestedField(3, "z", LongType(), required=True),
)
TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x", transform=IdentityTransform(), source_id=1, field_id=1000))
TEST_TABLE_PROPERTIES = {"read.split.target.size": "134217728"}
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_rest_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ def test_add_nested_map_of_structs(catalog: Catalog) -> None:
tbl = _create_table_with_schema(
catalog,
Schema(
NestedField(field_id=1, name="foo", field_type=StringType()),
NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
),
)

Expand Down Expand Up @@ -1031,7 +1031,7 @@ def test_add_nested_list_of_structs(catalog: Catalog) -> None:
tbl = _create_table_with_schema(
catalog,
Schema(
NestedField(field_id=1, name="foo", field_type=StringType()),
NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
),
)

Expand Down
27 changes: 16 additions & 11 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,9 @@ def schema_map_of_structs() -> Schema:
key_id=51,
value_id=52,
key_type=StringType(),
value_type=StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType())),
value_type=StructType(
NestedField(511, "lat", DoubleType(), required=True), NestedField(512, "long", DoubleType(), required=True)
),
element_required=False,
),
required=False,
Expand Down Expand Up @@ -1038,7 +1040,7 @@ def test_projection_add_column_struct_required(file_int: str) -> None:
def test_projection_rename_column(schema_int: Schema, file_int: str) -> None:
schema = Schema(
# Reuses the id 1
NestedField(1, "other_name", IntegerType())
NestedField(1, "other_name", IntegerType(), required=True)
)
result_table = project(schema, [file_int])
assert len(result_table.columns[0]) == 3
Expand Down Expand Up @@ -1071,7 +1073,7 @@ def test_projection_filter(schema_int: Schema, file_int: str) -> None:
def test_projection_filter_renamed_column(file_int: str) -> None:
schema = Schema(
# Reuses the id 1
NestedField(1, "other_id", IntegerType())
NestedField(1, "other_id", IntegerType(), required=True)
)
result_table = project(schema, [file_int], GreaterThan("other_id", 1))
assert len(result_table.columns[0]) == 1
Expand All @@ -1089,7 +1091,7 @@ def test_projection_filter_add_column(schema_int: Schema, file_int: str, file_st


def test_projection_filter_add_column_promote(file_int: str) -> None:
schema_long = Schema(NestedField(1, "id", LongType()))
schema_long = Schema(NestedField(1, "id", LongType(), required=True))
result_table = project(schema_long, [file_int])

for actual, expected in zip(result_table.columns[0], [0, 1, 2]):
Expand All @@ -1111,9 +1113,10 @@ def test_projection_nested_struct_subset(file_struct: str) -> None:
4,
"location",
StructType(
NestedField(41, "lat", DoubleType()),
NestedField(41, "lat", DoubleType(), required=True),
# long is missing!
),
required=True,
)
)

Expand All @@ -1138,6 +1141,7 @@ def test_projection_nested_new_field(file_struct: str) -> None:
StructType(
NestedField(43, "null", DoubleType(), required=False), # Whoa, this column doesn't exist in the file
),
required=True,
)
)

Expand All @@ -1163,6 +1167,7 @@ def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> No
NestedField(43, "null", DoubleType(), required=False),
NestedField(42, "long", DoubleType(), required=False),
),
required=True,
)
)

Expand Down Expand Up @@ -1194,8 +1199,8 @@ def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of
ListType(
51,
StructType(
NestedField(511, "latitude", DoubleType()),
NestedField(512, "longitude", DoubleType()),
NestedField(511, "latitude", DoubleType(), required=True),
NestedField(512, "longitude", DoubleType(), required=True),
NestedField(513, "altitude", DoubleType(), required=False),
),
element_required=False,
Expand Down Expand Up @@ -1239,9 +1244,9 @@ def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_s
value_id=52,
key_type=StringType(),
value_type=StructType(
NestedField(511, "latitude", DoubleType()),
NestedField(512, "longitude", DoubleType()),
NestedField(513, "altitude", DoubleType(), required=False),
NestedField(511, "latitude", DoubleType(), required=True),
NestedField(512, "longitude", DoubleType(), required=True),
NestedField(513, "altitude", DoubleType()),
),
element_required=False,
),
Expand Down Expand Up @@ -1308,7 +1313,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:


def test_projection_filter_on_unprojected_field(schema_int_str: Schema, file_int_str: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType()))
schema = Schema(NestedField(1, "id", IntegerType(), required=True))

result_table = project(schema, [file_int_str], GreaterThan("data", "1"), schema_int_str)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_schema_str(table_schema_simple: Schema) -> None:
def test_schema_repr_single_field() -> None:
"""Test schema representation"""
actual = repr(schema.Schema(NestedField(field_id=1, name="foo", field_type=StringType()), schema_id=1))
expected = "Schema(NestedField(field_id=1, name='foo', field_type=StringType(), required=True), schema_id=1, identifier_field_ids=[])"
expected = "Schema(NestedField(field_id=1, name='foo', field_type=StringType(), required=False), schema_id=1, identifier_field_ids=[])"
assert expected == actual


Expand All @@ -104,7 +104,7 @@ def test_schema_repr_two_fields() -> None:
schema_id=1,
)
)
expected = "Schema(NestedField(field_id=1, name='foo', field_type=StringType(), required=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), required=False), schema_id=1, identifier_field_ids=[])"
expected = "Schema(NestedField(field_id=1, name='foo', field_type=StringType(), required=False), NestedField(field_id=2, name='bar', field_type=IntegerType(), required=False), schema_id=1, identifier_field_ids=[])"
assert expected == actual


Expand Down
8 changes: 4 additions & 4 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test_serialization_struct() -> None:
expected = (
'{"type":"struct","fields":['
'{"id":1,"name":"required_field","type":"string","required":true,"doc":"this is a doc"},'
'{"id":2,"name":"optional_field","type":"int","required":true}'
'{"id":2,"name":"optional_field","type":"int","required":false}'
"]}"
)
assert actual == expected
Expand All @@ -545,7 +545,7 @@ def test_deserialization_struct() -> None:
"id": 2,
"name": "optional_field",
"type": "int",
"required": true
"required": false
}
]
}
Expand All @@ -560,13 +560,13 @@ def test_deserialization_struct() -> None:


def test_str_struct(simple_struct: StructType) -> None:
assert str(simple_struct) == "struct<1: required_field: required string (this is a doc), 2: optional_field: required int>"
assert str(simple_struct) == "struct<1: required_field: required string (this is a doc), 2: optional_field: optional int>"


def test_repr_struct(simple_struct: StructType) -> None:
assert (
repr(simple_struct)
== "StructType(fields=(NestedField(field_id=1, name='required_field', field_type=StringType(), required=True), NestedField(field_id=2, name='optional_field', field_type=IntegerType(), required=True),))"
== "StructType(fields=(NestedField(field_id=1, name='required_field', field_type=StringType(), required=True), NestedField(field_id=2, name='optional_field', field_type=IntegerType(), required=False),))"
)


Expand Down

0 comments on commit aa850ef

Please sign in to comment.