diff --git a/CHANGELOG b/CHANGELOG index 7bd3a538..b78a8c50 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1 +1 @@ -[#154] Refactor remote tests to use a single declaration of remote host/port +[#153] Refactor the constructor of the TraverseEngine and the request method in the FunctionsClient diff --git a/hyperon_das/client.py b/hyperon_das/client.py index 4c34b34d..582ad5f4 100644 --- a/hyperon_das/client.py +++ b/hyperon_das/client.py @@ -1,8 +1,9 @@ +import contextlib import json from typing import Any, Dict, List, Optional, Tuple, Union -import requests from hyperon_das_atomdb import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist +from requests import exceptions, sessions from hyperon_das.logger import logger @@ -15,13 +16,32 @@ def __init__(self, url: str, server_count: int = 0, name: Optional[str] = None): def _send_request(self, payload) -> Any: try: - response = requests.request('POST', url=self.url, data=json.dumps(payload)) + with sessions.Session() as session: + response = session.request(method='POST', url=self.url, data=json.dumps(payload)) + + response.raise_for_status() + + try: + response_data = response.json() + except exceptions.JSONDecodeError as e: + raise Exception(f"JSON decode error: {str(e)}") + if response.status_code == 200: - return response.json() + return response_data else: - return response.json()['error'] - except requests.exceptions.RequestException as e: - raise e + return response_data.get( + 'error', f'Unknown error with status code {response.status_code}' + ) + except exceptions.ConnectionError as e: + raise Exception(f"Connection error: {str(e)}") + except exceptions.Timeout as e: + raise Exception(f"Request timed out: {str(e)}") + except exceptions.HTTPError as e: + with contextlib.suppress(exceptions.JSONDecodeError): + return response.json().get('error') + raise Exception(f"HTTP error occurred: {str(e)}") + except exceptions.RequestException as e: + raise Exception(f"Request exception occurred: {str(e)}") def get_atom(self, handle: str, **kwargs) -> Union[str, Dict]: payload = { diff --git a/hyperon_das/das.py b/hyperon_das/das.py index cebc2b83..eb31487e 100644 --- a/hyperon_das/das.py +++ b/hyperon_das/das.py @@ -496,8 +496,6 @@ def get_traversal_cursor(self, handle: str, **kwargs) -> TraverseEngine: TraverseEngine: The object that allows traversal of the hypergraph """ try: - self.get_atom(handle) + return TraverseEngine(handle, das=self, **kwargs) except AtomDoesNotExist: raise GetTraversalCursorException(message="Cannot start Traversal. Atom does not exist") - - return TraverseEngine(handle, das=self, **kwargs) diff --git a/hyperon_das/query_engines.py b/hyperon_das/query_engines.py index 7fee68cd..95da73a4 100644 --- a/hyperon_das/query_engines.py +++ b/hyperon_das/query_engines.py @@ -2,9 +2,9 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union -import requests from hyperon_das_atomdb import WILDCARD from hyperon_das_atomdb.exceptions import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist +from requests import sessions from hyperon_das.cache import ( AndEvaluator, @@ -279,12 +279,13 @@ def _connect_server(self, host: str, port: Optional[str] = None): def _is_server_connect(self, url: str) -> bool: logger().debug(f'connecting to remote Das {url}') try: - response = requests.request( - 'POST', - url=url, - data=json.dumps({"action": "ping", "input": {}}), - timeout=10, - ) + with sessions.Session() as session: + response = session.request( + method='POST', + url=url, + data=json.dumps({"action": "ping", "input": {}}), + timeout=10, + ) except Exception: return False if response.status_code == 200: diff --git a/hyperon_das/traverse_engines.py b/hyperon_das/traverse_engines.py index 0e2f325d..be4ef422 100644 --- a/hyperon_das/traverse_engines.py +++ b/hyperon_das/traverse_engines.py @@ -11,10 +11,16 @@ class TraverseEngine: def __init__(self, handle: str, **kwargs) -> None: self.das: DistributedAtomSpace = kwargs['das'] - self._cursor = self.das.get_atom(handle) + + try: + atom = self.das.get_atom(handle) + except AtomDoesNotExist as e: + raise e + + self._cursor = atom def get(self) -> Dict[str, Any]: - return self.das.get_atom(self._cursor['handle']) + return self._cursor def get_links(self, **kwargs) -> QueryAnswerIterator: incoming_links = self.das.get_incoming_links( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 584110cb..38f7db25 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -8,7 +8,7 @@ class TestFunctionsClient: @pytest.fixture def mock_request(self): - with patch('requests.request') as mock_request: + with patch('requests.sessions.Session.request') as mock_request: yield mock_request def test_get_atom_success(self, mock_request): @@ -26,7 +26,7 @@ def test_get_atom_success(self, mock_request): result = client.get_atom(handle='123') mock_request.assert_called_with( - 'POST', + method='POST', url='http://example.com', data='{"action": "get_atom", "input": {"handle": "123"}}', ) @@ -48,7 +48,7 @@ def test_get_node_success(self, mock_request): result = client.get_node(node_type='Concept', node_name='human') mock_request.assert_called_with( - 'POST', + method='POST', url='http://example.com', data='{"action": "get_node", "input": {"node_type": "Concept", "node_name": "human"}}', ) @@ -80,7 +80,7 @@ def test_get_link_success(self, mock_request): ) mock_request.assert_called_with( - 'POST', + method='POST', url='http://example.com', data='{"action": "get_link", "input": {"link_type": "Similarity", "link_targets": ["af12f10f9ae2002a1607ba0b47ba8407", "1cdffc6b0b89ff41d68bec237481d1e1"]}}', ) @@ -107,7 +107,7 @@ def test_get_links_success(self, mock_request): ) mock_request.assert_called_with( - 'POST', + method='POST', url='http://example.com', data='{"action": "get_links", "input": {"link_type": "Inheritance", "kwargs": {}, "link_targets": ["4e8e26e3276af8a5c2ac2cc2dc95c6d2", "80aff30094874e75028033a38ce677bb"]}}', ) @@ -162,7 +162,7 @@ def test_count_atoms_success(self, mock_request): result = client.count_atoms() mock_request.assert_called_once_with( - 'POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}' + method='POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}' ) assert result == tuple(expected_response)