@staticmethod
def _create_hive_client(properties: Dict[str, str]) -> _HiveClient:
    """Build a Hive metastore client from the first usable URI.

    ``properties["uri"]`` may hold several comma-separated metastore URIs
    (HA metastore setup). The URIs are tried in the order given and the
    first one for which ``_HiveClient`` can be constructed is returned.

    Args:
        properties: Catalog properties. ``uri`` is required; ``ugi`` is
            optional and is forwarded to ``_HiveClient`` unchanged.

    Returns:
        The client created for the first URI whose construction succeeds.

    Raises:
        KeyError: If ``properties`` has no ``uri`` entry.
        ValueError: If ``uri`` contains no non-blank entry.
        Exception: The error raised for the last URI when every attempt fails.
    """
    # Loop-invariant, so look it up once.
    ugi = properties.get("ugi")

    last_exception = None
    for raw_uri in properties["uri"].split(","):
        # Accept "a, b" and trailing commas instead of handing blank or
        # whitespace-padded strings to the client.
        uri = raw_uri.strip()
        if not uri:
            continue
        try:
            return _HiveClient(uri, ugi)
        # Catch Exception, not BaseException: KeyboardInterrupt / SystemExit
        # must abort the fail-over loop rather than be silently replaced by
        # a connection to the next URI.
        except Exception as e:
            last_exception = e

    if last_exception is None:
        # No URI was attempted at all (empty or blank-only "uri" value).
        raise ValueError(f"Unable to connect to hive using uri: {properties['uri']}")
    raise last_exception
def test_create_hive_client_success() -> None:
    """A single URI results in exactly one ``_HiveClient`` construction."""
    catalog_properties = {"uri": "thrift://localhost:10000", "ugi": "user"}

    with patch("pyiceberg.catalog.hive._HiveClient", return_value=MagicMock()) as hive_client_cls:
        created = HiveCatalog._create_hive_client(catalog_properties)

    # The mock keeps its call record after the patch is undone.
    hive_client_cls.assert_called_once_with("thrift://localhost:10000", "user")
    assert created is not None


def test_create_hive_client_multiple_uris() -> None:
    """When the first URI fails, the second one is tried and its client returned."""
    catalog_properties = {"uri": "thrift://localhost:10000,thrift://localhost:10001", "ugi": "user"}
    expected_calls = [
        call("thrift://localhost:10000", "user"),
        call("thrift://localhost:10001", "user"),
    ]

    with patch("pyiceberg.catalog.hive._HiveClient") as hive_client_cls:
        # First construction raises, second one yields a client.
        hive_client_cls.side_effect = [Exception("Connection failed"), MagicMock()]
        created = HiveCatalog._create_hive_client(catalog_properties)

    assert hive_client_cls.call_count == 2
    hive_client_cls.assert_has_calls(expected_calls)
    assert created is not None


def test_create_hive_client_failure() -> None:
    """When every URI fails, the error is raised after all URIs were tried."""
    catalog_properties = {"uri": "thrift://localhost:10000,thrift://localhost:10001", "ugi": "user"}
    always_failing = patch("pyiceberg.catalog.hive._HiveClient", side_effect=Exception("Connection failed"))

    with always_failing as hive_client_cls:
        with pytest.raises(Exception, match="Connection failed"):
            HiveCatalog._create_hive_client(catalog_properties)

    assert hive_client_cls.call_count == 2