Skip to content

Commit

Permalink
[das-serverless-function#90] Fix serializing/deserializing query answ…
Browse files Browse the repository at this point in the history
…ers (#194)

* Fix serializing/deserializing query answers

* Update CHANGELOG

* Fix unit tests

* update CHANGELOG

* Rename serialize and deserialize methods

* Fix get incoming links action
  • Loading branch information
levisingularity authored Mar 21, 2024
1 parent bb9bdf7 commit 9d981a3
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[#180] Fix in the test_metta_api.py integration test
[#136] Implement methods in the DAS API to create indexes in the database
[#BUGFIX] Fix Mock in unit tests
[#90] OpenFaas is not serializing/deserializing query answers
22 changes: 14 additions & 8 deletions hyperon_das/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import contextlib
import json
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union

from hyperon_das_atomdb import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import exceptions, sessions

from hyperon_das.utils import serialize, deserialize
from hyperon_das.exceptions import ConnectionError, HTTPError, RequestError, TimeoutError
from hyperon_das.logger import logger

Expand All @@ -17,15 +17,21 @@ def __init__(self, url: str, server_count: int = 0, name: Optional[str] = None):

def _send_request(self, payload) -> Any:
try:
payload_serialized = serialize(payload)

with sessions.Session() as session:
response = session.request(method='POST', url=self.url, data=json.dumps(payload))
response = session.request(
method='POST',
url=self.url,
data=payload_serialized,
)

response.raise_for_status()

try:
response_data = response.json()
except exceptions.JSONDecodeError as e:
raise Exception(f"JSON decode error: {str(e)}")
response_data = deserialize(response.content)
except pickle.UnpicklingError as e:
raise Exception(f"Unpickling error: {str(e)}")

if response.status_code == 200:
return response_data
Expand All @@ -44,8 +50,8 @@ def _send_request(self, payload) -> Any:
details=str(e),
)
except exceptions.HTTPError as e:
with contextlib.suppress(exceptions.JSONDecodeError):
return response.json().get('error')
with contextlib.suppress(pickle.UnpicklingError):
return deserialize(response.content).get('error')
raise HTTPError(
message=f"HTTP error occurred for URL: '{self.url}' with payload: '{payload}'",
details=str(e),
Expand Down
39 changes: 19 additions & 20 deletions hyperon_das/query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,53 +32,52 @@
UnexpectedQueryFormat,
)
from hyperon_das.logger import logger
from hyperon_das.utils import Assignment, QueryAnswer, get_package_version # noqa: F401
from hyperon_das.utils import Assignment, QueryAnswer, get_package_version, serialize # noqa: F401


class QueryEngine(ABC):
@abstractmethod
def get_atom(self, handle: str) -> Union[Dict[str, Any], None]:
... # pragma no cover
def get_atom(self, handle: str) -> Union[Dict[str, Any], None]: ... # pragma no cover

@abstractmethod
def get_node(self, node_type: str, node_name: str) -> Union[Dict[str, Any], None]:
... # pragma no cover
def get_node(
self, node_type: str, node_name: str
) -> Union[Dict[str, Any], None]: ... # pragma no cover

@abstractmethod
def get_link(self, link_type: str, targets: List[str]) -> Union[Dict[str, Any], None]:
... # pragma no cover
def get_link(
self, link_type: str, targets: List[str]
) -> Union[Dict[str, Any], None]: ... # pragma no cover

@abstractmethod
def get_links(
self, link_type: str, target_types: List[str] = None, link_targets: List[str] = None
) -> Union[List[str], List[Dict]]:
... # pragma no cover
) -> Union[List[str], List[Dict]]: ... # pragma no cover

@abstractmethod
def get_incoming_links(
self, atom_handle: str, **kwargs
) -> List[Union[dict, str, Tuple[dict, List[dict]]]]:
... # pragma no cover
) -> List[Union[dict, str, Tuple[dict, List[dict]]]]: ... # pragma no cover

@abstractmethod
def query(
self,
query: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = {},
) -> Union[QueryAnswerIterator, List[Tuple[Assignment, Dict[str, str]]]]:
... # pragma no cover
) -> Union[QueryAnswerIterator, List[Tuple[Assignment, Dict[str, str]]]]: ... # pragma no cover

@abstractmethod
def count_atoms(self) -> Tuple[int, int]:
... # pragma no cover
def count_atoms(self) -> Tuple[int, int]: ... # pragma no cover

@abstractmethod
def reindex(self, pattern_index_templates: Optional[Dict[str, Dict[str, Any]]]):
... # pragma no cover
def reindex(
self, pattern_index_templates: Optional[Dict[str, Dict[str, Any]]]
): ... # pragma no cover

@abstractmethod
def create_field_index(self, atom_type: str, field: str, type: str = None) -> str:
... # pragma no cover
def create_field_index(
self, atom_type: str, field: str, type: str = None
) -> str: ... # pragma no cover


class LocalQueryEngine(QueryEngine):
Expand Down Expand Up @@ -339,7 +338,7 @@ def _is_server_connect(self, url: str) -> bool:
response = session.request(
method='POST',
url=url,
data=json.dumps({"action": "ping", "input": {}}),
data=serialize({"action": "ping", "input": {}}),
timeout=10,
)
except Exception:
Expand Down
9 changes: 9 additions & 0 deletions hyperon_das/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Dict, FrozenSet, List, Optional, Set, Union
import pickle

from hyperon_das.exceptions import InvalidAssignment

Expand Down Expand Up @@ -101,3 +102,11 @@ class QueryAnswer:
def get_package_version(package_name: str) -> str:
package_module = import_module(package_name)
return getattr(package_module, '__version__', None)


def serialize(payload: Any) -> bytes:
return pickle.dumps(payload)


def deserialize(payload: bytes) -> Any:
return pickle.loads(payload)
2 changes: 1 addition & 1 deletion tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,4 @@ def test_get_incoming_links(self, server: FunctionsClient):
assert len(response_atoms_targets) == 8
for atom_targets in response_atoms_targets:
if len(atom_targets[0]["targets"]) == 3:
assert atom_targets in expected_atoms_targets
assert list(atom_targets) in expected_atoms_targets
80 changes: 65 additions & 15 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from hyperon_das.client import FunctionsClient
from hyperon_das.exceptions import ConnectionError, RequestError, TimeoutError
from hyperon_das.utils import serialize


class TestFunctionsClient:
Expand All @@ -15,6 +16,7 @@ def mock_request(self):
yield mock_request

def test_get_atom_success(self, mock_request):
expected_request_data = {"action": "get_atom", "input": {"handle": "123"}}
expected_response = {
"handle": "af12f10f9ae2002a1607ba0b47ba8407",
"composite_type_hash": "d99a604c79ce3c2e76a2f43488d5d4c3",
Expand All @@ -23,20 +25,24 @@ def test_get_atom_success(self, mock_request):
}

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
result = client.get_atom(handle='123')

mock_request.assert_called_with(
method='POST',
url='http://example.com',
data='{"action": "get_atom", "input": {"handle": "123"}}',
data=serialize(expected_request_data),
)

assert result == expected_response

def test_get_node_success(self, mock_request):
expected_request_data = {
"action": "get_node",
"input": {"node_type": "Concept", "node_name": "human"},
}
expected_response = {
"handle": "af12f10f9ae2002a1607ba0b47ba8407",
"composite_type_hash": "d99a604c79ce3c2e76a2f43488d5d4c3",
Expand All @@ -45,20 +51,30 @@ def test_get_node_success(self, mock_request):
}

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
result = client.get_node(node_type='Concept', node_name='human')

mock_request.assert_called_with(
method='POST',
url='http://example.com',
data='{"action": "get_node", "input": {"node_type": "Concept", "node_name": "human"}}',
data=serialize(expected_request_data),
)

assert result == expected_response

def test_get_link_success(self, mock_request):
expected_request_data = {
"action": "get_link",
"input": {
"link_type": "Similarity",
"link_targets": [
"af12f10f9ae2002a1607ba0b47ba8407",
"1cdffc6b0b89ff41d68bec237481d1e1",
],
},
}
expected_response = {
"handle": "bad7472f41a0e7d601ca294eb4607c3a",
"composite_type_hash": "ed73ea081d170e1d89fc950820ce1cee",
Expand All @@ -74,7 +90,7 @@ def test_get_link_success(self, mock_request):
}

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
result = client.get_link(
Expand All @@ -85,12 +101,23 @@ def test_get_link_success(self, mock_request):
mock_request.assert_called_with(
method='POST',
url='http://example.com',
data='{"action": "get_link", "input": {"link_type": "Similarity", "link_targets": ["af12f10f9ae2002a1607ba0b47ba8407", "1cdffc6b0b89ff41d68bec237481d1e1"]}}',
data=serialize(expected_request_data),
)

assert result == expected_response

def test_get_links_success(self, mock_request):
expected_request_data = {
"action": "get_links",
"input": {
"link_type": "Inheritance",
"kwargs": {},
"link_targets": [
"4e8e26e3276af8a5c2ac2cc2dc95c6d2",
"80aff30094874e75028033a38ce677bb",
],
},
}
expected_response = [
{
"handle": "ee1c03e6d1f104ccd811cfbba018451a",
Expand All @@ -101,7 +128,7 @@ def test_get_links_success(self, mock_request):
]

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
result = client.get_links(
Expand All @@ -112,12 +139,26 @@ def test_get_links_success(self, mock_request):
mock_request.assert_called_with(
method='POST',
url='http://example.com',
data='{"action": "get_links", "input": {"link_type": "Inheritance", "kwargs": {}, "link_targets": ["4e8e26e3276af8a5c2ac2cc2dc95c6d2", "80aff30094874e75028033a38ce677bb"]}}',
data=serialize(expected_request_data),
)

assert result == expected_response

def test_query_success(self, mock_request):
expected_request_data = {
"action": "query",
"input": {
"query": {
"atom_type": "link",
"type": "Similarity",
"targets": [
{"atom_type": "node", "type": "Concept", "name": "human"},
{"atom_type": "node", "type": "Concept", "name": "monkey"},
],
},
"parameters": [],
},
}
expected_response = [
{
"handle": "bad7472f41a0e7d601ca294eb4607c3a",
Expand All @@ -139,7 +180,7 @@ def test_query_success(self, mock_request):
]

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
query = {
Expand All @@ -153,24 +194,34 @@ def test_query_success(self, mock_request):

result = client.query(query, parameters=[])

mock_request.assert_called_with(
method='POST',
url='http://example.com',
data=serialize(expected_request_data),
)

assert result == expected_response

def test_count_atoms_success(self, mock_request):
expected_request_data = {"action": "count_atoms", "input": {}}
expected_response = (14, 26)

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
result = client.count_atoms()

mock_request.assert_called_once_with(
method='POST', url='http://example.com', data='{"action": "count_atoms", "input": {}}'
method='POST',
url='http://example.com',
data=serialize(expected_request_data),
)

assert result == tuple(expected_response)
assert result == expected_response

def test_send_request_success(self, mock_request):
payload = {"action": "get_atom", "input": {"handle": "123"}}
expected_response = {
"handle": "af12f10f9ae2002a1607ba0b47ba8407",
"composite_type_hash": "d99a604c79ce3c2e76a2f43488d5d4c3",
Expand All @@ -179,16 +230,15 @@ def test_send_request_success(self, mock_request):
}

mock_request.return_value.status_code = 200
mock_request.return_value.json.return_value = expected_response
mock_request.return_value.content = serialize(expected_response)

client = FunctionsClient(url='http://example.com')
payload = {"action": "get_atom", "input": {"handle": "123"}}
result = client._send_request(payload)

mock_request.assert_called_with(
method='POST',
url='http://example.com',
data=json.dumps(payload),
data=serialize(payload),
)

assert result == expected_response
Expand Down

0 comments on commit 9d981a3

Please sign in to comment.