Skip to content

Commit

Permalink
Add _missing_ to make FileFormat case insensitive (#1411)
Browse files Browse the repository at this point in the history
* Add _missing_ to FileFormat Enum to make it case insensitive

* Combine the manifest test to existing test_manifest.py file

* Fix linting
  • Loading branch information
jiakai-li authored Dec 9, 2024
1 parent d82f8f7 commit 88c4bad
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Optional,
Tuple,
Type,
Union,
)

from cachetools import LRUCache, cached
Expand Down Expand Up @@ -97,6 +98,13 @@ class FileFormat(str, Enum):
PARQUET = "PARQUET"
ORC = "ORC"

@classmethod
def _missing_(cls, value: object) -> Union[None, str]:
for member in cls:
if member.value == str(value).upper():
return member
return None

def __repr__(self) -> str:
"""Return the string representation of the FileFormat class."""
return f"FileFormat.{self.name}"
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,26 @@ def test_write_manifest_list(
assert entry.file_sequence_number == 0 if format_version == 1 else 3
assert entry.snapshot_id == 8744736658442914487
assert entry.status == ManifestEntryStatus.ADDED


@pytest.mark.parametrize(
"raw_file_format,expected_file_format",
[
("avro", FileFormat("AVRO")),
("AVRO", FileFormat("AVRO")),
("parquet", FileFormat("PARQUET")),
("PARQUET", FileFormat("PARQUET")),
("orc", FileFormat("ORC")),
("ORC", FileFormat("ORC")),
("NOT_EXISTS", None),
],
)
def test_file_format_case_insensitive(raw_file_format: str, expected_file_format: FileFormat) -> None:
if expected_file_format:
parsed_file_format = FileFormat(raw_file_format)
assert parsed_file_format == expected_file_format, (
f"File format {raw_file_format}: {parsed_file_format} != {expected_file_format}"
)
else:
with pytest.raises(ValueError):
_ = FileFormat(raw_file_format)

0 comments on commit 88c4bad

Please sign in to comment.