Skip to content

Commit

Permalink
Merge pull request #157 from singnet/improvement/refactor-traverseeng…
Browse files Browse the repository at this point in the history
…ine-request

[#153] Refactor the constructor of the TraverseEngine and the request method in the FunctionsClient
  • Loading branch information
marcocapozzoli authored Feb 28, 2024
2 parents c232450 + 8f2b946 commit 3ac9c76
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 25 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[#154] Refactor remote tests to use a single declaration of remote host/port
[#153] Refactor the constructor of the TraverseEngine and the request method in the FunctionsClient
32 changes: 26 additions & 6 deletions hyperon_das/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import contextlib
import json
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from hyperon_das_atomdb import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import exceptions, sessions

from hyperon_das.logger import logger

Expand All @@ -15,13 +16,32 @@ def __init__(self, url: str, server_count: int = 0, name: Optional[str] = None):

def _send_request(self, payload) -> Any:
try:
response = requests.request('POST', url=self.url, data=json.dumps(payload))
with sessions.Session() as session:
response = session.request(method='POST', url=self.url, data=json.dumps(payload))

response.raise_for_status()

try:
response_data = response.json()
except exceptions.JSONDecodeError as e:
raise Exception(f"JSON decode error: {str(e)}")

if response.status_code == 200:
return response.json()
return response_data
else:
return response.json()['error']
except requests.exceptions.RequestException as e:
raise e
return response_data.get(
'error', f'Unknown error with status code {response.status_code}'
)
except exceptions.ConnectionError as e:
raise Exception(f"Connection error: {str(e)}")
except exceptions.Timeout as e:
raise Exception(f"Request timed out: {str(e)}")
except exceptions.HTTPError as e:
with contextlib.suppress(exceptions.JSONDecodeError):
return response.json().get('error')
raise Exception(f"HTTP error occurred: {str(e)}")
except exceptions.RequestException as e:
raise Exception(f"Request exception occurred: {str(e)}")

def get_atom(self, handle: str, **kwargs) -> Union[str, Dict]:
payload = {
Expand Down
4 changes: 1 addition & 3 deletions hyperon_das/das.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,6 @@ def get_traversal_cursor(self, handle: str, **kwargs) -> TraverseEngine:
TraverseEngine: The object that allows traversal of the hypergraph
"""
try:
self.get_atom(handle)
return TraverseEngine(handle, das=self, **kwargs)
except AtomDoesNotExist:
raise GetTraversalCursorException(message="Cannot start Traversal. Atom does not exist")

return TraverseEngine(handle, das=self, **kwargs)
15 changes: 8 additions & 7 deletions hyperon_das/query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import requests
from hyperon_das_atomdb import WILDCARD
from hyperon_das_atomdb.exceptions import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import sessions

from hyperon_das.cache import (
AndEvaluator,
Expand Down Expand Up @@ -279,12 +279,13 @@ def _connect_server(self, host: str, port: Optional[str] = None):
def _is_server_connect(self, url: str) -> bool:
logger().debug(f'connecting to remote Das {url}')
try:
response = requests.request(
'POST',
url=url,
data=json.dumps({"action": "ping", "input": {}}),
timeout=10,
)
with sessions.Session() as session:
response = session.request(
method='POST',
url=url,
data=json.dumps({"action": "ping", "input": {}}),
timeout=10,
)
except Exception:
return False
if response.status_code == 200:
Expand Down
10 changes: 8 additions & 2 deletions hyperon_das/traverse_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
class TraverseEngine:
def __init__(self, handle: str, **kwargs) -> None:
self.das: DistributedAtomSpace = kwargs['das']
self._cursor = self.das.get_atom(handle)

try:
atom = self.das.get_atom(handle)
except AtomDoesNotExist as e:
raise e

self._cursor = atom

def get(self) -> Dict[str, Any]:
return self.das.get_atom(self._cursor['handle'])
return self._cursor

def get_links(self, **kwargs) -> QueryAnswerIterator:
incoming_links = self.das.get_incoming_links(
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class TestFunctionsClient:
@pytest.fixture
def mock_request(self):
with patch('requests.request') as mock_request:
with patch('requests.sessions.Session.request') as mock_request:
yield mock_request

def test_get_atom_success(self, mock_request):
Expand All @@ -26,7 +26,7 @@ def test_get_atom_success(self, mock_request):
result = client.get_atom(handle='123')

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_atom", "input": {"handle": "123"}}',
)
Expand All @@ -48,7 +48,7 @@ def test_get_node_success(self, mock_request):
result = client.get_node(node_type='Concept', node_name='human')

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_node", "input": {"node_type": "Concept", "node_name": "human"}}',
)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_get_link_success(self, mock_request):
)

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_link", "input": {"link_type": "Similarity", "link_targets": ["af12f10f9ae2002a1607ba0b47ba8407", "1cdffc6b0b89ff41d68bec237481d1e1"]}}',
)
Expand All @@ -107,7 +107,7 @@ def test_get_links_success(self, mock_request):
)

mock_request.assert_called_with(
'POST',
method='POST',
url='http://example.com',
data='{"action": "get_links", "input": {"link_type": "Inheritance", "kwargs": {}, "link_targets": ["4e8e26e3276af8a5c2ac2cc2dc95c6d2", "80aff30094874e75028033a38ce677bb"]}}',
)
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_count_atoms_success(self, mock_request):
result = client.count_atoms()

mock_request.assert_called_once_with(
'POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}'
method='POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}'
)

assert result == tuple(expected_response)

0 comments on commit 3ac9c76

Please sign in to comment.