From 88c4bad1d95175c3672bf7c47cba22ca803efcca Mon Sep 17 00:00:00 2001 From: Jiakai Li <50531391+jiakai-li@users.noreply.github.com> Date: Tue, 10 Dec 2024 03:33:33 +1300 Subject: [PATCH] Add `_missing_` to make `FileFormat` case insensitive (#1411) * Add _missing_ to FileFormat Enum to make it case insensitive * Combine the manifest test to existing test_manifest.py file * Fix linting --- pyiceberg/manifest.py | 8 ++++++++ tests/utils/test_manifest.py | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 6774499f2..a56da5fc0 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -30,6 +30,7 @@ Optional, Tuple, Type, + Union, ) from cachetools import LRUCache, cached @@ -97,6 +98,13 @@ class FileFormat(str, Enum): PARQUET = "PARQUET" ORC = "ORC" + @classmethod + def _missing_(cls, value: object) -> Union[None, str]: + for member in cls: + if member.value == str(value).upper(): + return member + return None + def __repr__(self) -> str: """Return the string representation of the FileFormat class.""" return f"FileFormat.{self.name}" diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 97c88a99e..154671c92 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -604,3 +604,26 @@ def test_write_manifest_list( assert entry.file_sequence_number == 0 if format_version == 1 else 3 assert entry.snapshot_id == 8744736658442914487 assert entry.status == ManifestEntryStatus.ADDED + + +@pytest.mark.parametrize( + "raw_file_format,expected_file_format", + [ + ("avro", FileFormat("AVRO")), + ("AVRO", FileFormat("AVRO")), + ("parquet", FileFormat("PARQUET")), + ("PARQUET", FileFormat("PARQUET")), + ("orc", FileFormat("ORC")), + ("ORC", FileFormat("ORC")), + ("NOT_EXISTS", None), + ], +) +def test_file_format_case_insensitive(raw_file_format: str, expected_file_format: FileFormat) -> None: + if expected_file_format: + parsed_file_format = FileFormat(raw_file_format) + assert parsed_file_format == expected_file_format, ( + f"File format {raw_file_format}: {parsed_file_format} != {expected_file_format}" + ) + else: + with pytest.raises(ValueError): + _ = FileFormat(raw_file_format)