diff --git a/hyperon_das/client.py b/hyperon_das/client.py index 14904b9a..8900bb53 100644 --- a/hyperon_das/client.py +++ b/hyperon_das/client.py @@ -3,6 +3,8 @@ import requests +from hyperon_das.logger import logger + class FunctionsClient: def __init__(self, url: str, server_count: int = 0, name: Optional[str] = None): @@ -80,3 +82,18 @@ def commit_changes(self) -> Tuple[int, int]: 'input': {}, } return self._send_request(payload) + + def get_incoming_links( + self, atom_handle: str, **kwargs + ) -> List[Union[Tuple[Dict[str, Any], List[Dict[str, Any]]], Dict[str, Any]]]: + payload = { + 'action': 'get_incoming_links', + 'input': {'atom_handle': atom_handle, 'kwargs': kwargs}, + } + response = self._send_request(payload) + if response and 'error' in response: + logger().debug( + f'Error during `get_incoming_links` request on remote Das: {response["error"]}' + ) + return [] + return response diff --git a/hyperon_das/das.py b/hyperon_das/das.py index 5d2b2e6f..dc94fa1c 100644 --- a/hyperon_das/das.py +++ b/hyperon_das/das.py @@ -188,6 +188,20 @@ def get_links( """ return self.query_engine.get_links(link_type, target_types, link_targets) + def get_incoming_links(self, atom_handle: str, **kwargs) -> List[Union[Dict[str, Any], str]]: + """Retrieve all links pointing to Atom + + Args: + atom_handle (str): The unique handle of the atom + kwargs (optional): You can send `handles_only` as a bool value. + True returns only atom handles. + + Returns: + List[Union[Dict[str, Any], str]]: A list of dictionaries containing detailed + atom information or a list of strings containing only the atom identifiers + """ + return self.query_engine.get_incoming_links(atom_handle, **kwargs) + def count_atoms(self) -> Tuple[int, int]: """ This method is useful for returning the count of atoms in the database. diff --git a/hyperon_das/query_engines.py b/hyperon_das/query_engines.py index 7a47506f..9af6f43d 100644 --- a/hyperon_das/query_engines.py +++ b/hyperon_das/query_engines.py @@ -37,6 +37,12 @@ def get_links( ) -> Union[List[str], List[Dict]]: ... + @abstractmethod + def get_incoming_links( + self, atom_handle: str, **kwargs + ) -> List[Union[dict, str, Tuple[dict, List[dict]]]]: + ... + @abstractmethod def query( self, @@ -155,6 +161,11 @@ def get_links( return self._to_link_dict_list(db_answer) + def get_incoming_links( + self, atom_handle: str, **kwargs + ) -> List[Union[dict, str, Tuple[dict, List[dict]]]]: + return self.local_backend.get_incoming_links(atom_handle, **kwargs) + def query( self, query: Dict[str, Any], @@ -255,6 +266,46 @@ def get_links( if not local: return self.remote_das.get_links(link_type, target_types, link_targets) + def get_incoming_links( + self, atom_handle: str, **kwargs + ) -> List[Union[dict, str, Tuple[dict, List[dict]]]]: + local_links = self.local_query_engine.get_incoming_links(atom_handle, **kwargs) + remote_links = self.remote_das.get_incoming_links(atom_handle, **kwargs) + + if not local_links and remote_links: + return remote_links + elif local_links and not remote_links: + return local_links + elif not local_links and not remote_links: + return [] + + if kwargs.get('handles_only', False): + return list(set(local_links + remote_links)) + else: + answer = [] + + if isinstance(remote_links[0], dict): + remote_links_dict = {link['handle']: link for link in remote_links} + else: + remote_links_dict = { + link['handle']: (link, targets) for link, targets in remote_links + } + + for local_link in local_links: + if isinstance(local_link, dict): + handle = local_link['handle'] + else: + handle = local_link[0]['handle'] + local_link = local_link[0] + + answer.append(local_link) + + remote_links_dict.pop(handle, None) + + answer.extend(remote_links_dict.values()) + + return answer + def query( self, query: Dict[str, Any], diff --git a/tests/unit/mock.py b/tests/unit/mock.py index e024d81e..f3000efc 100644 --- a/tests/unit/mock.py +++ b/tests/unit/mock.py @@ -4,7 +4,7 @@ from hyperon_das_atomdb import WILDCARD, AtomDB from hyperon_das import DistributedAtomSpace -from hyperon_das.das import LocalQueryEngine +from hyperon_das.das import LocalQueryEngine, RemoteQueryEngine def _build_node_handle(node_type: str, node_name: str) -> str: @@ -23,15 +23,16 @@ def _build_link_handle(link_type: str, target_handles: List[str]) -> str: class DistributedAtomSpaceMock(DistributedAtomSpace): - def __init__(self) -> None: - self.backend = DatabaseMock() - self.query_engine = LocalQueryEngine(self.backend) + def __init__(self, query_engine: Optional[str] = 'local', **kwargs) -> None: + self.backend = DatabaseAnimals() + if query_engine == 'remote': + self.query_engine = RemoteQueryEngine(self.backend, kwargs) + else: + self.query_engine = LocalQueryEngine(self.backend) class DatabaseMock(AtomDB): - def __init__(self, name: str = 'das'): - self.database_name = name - + def __init__(self): human = _build_node_handle('Concept', 'human') monkey = _build_node_handle('Concept', 'monkey') chimp = _build_node_handle('Concept', 'chimp') @@ -112,6 +113,7 @@ def __init__(self, name: str = 'das'): ] self.template_index = {} + self.incoming_set = {} for link in self.all_links: key = [link[0]] @@ -122,12 +124,21 @@ def __init__(self, name: str = 'das'): v = self.template_index.get(key, []) v.append([_build_link_handle(link[0], link[1:]), link[1:]]) self.template_index[key] = v + self._add_incoming_set(str(link), link[1:]) self.all_links.append(nested_link) def __repr__(self): return "" + def _add_incoming_set(self, key, targets): + for target in targets: + incoming_set = self.incoming_set.get(target) + if incoming_set is None: + self.incoming_set[target] = [key] + else: + self.incoming_set[target].append(key) + def node_exists(self, node_type: str, node_name: str) -> bool: return _build_node_handle(node_type, node_name) in self.all_nodes @@ -289,3 +300,90 @@ def add_link(self, link_params: Dict[str, Any], toplevel: bool = True) -> Dict[s def add_node(self, node_params: Dict[str, Any]) -> Dict[str, Any]: assert False + + def get_incoming_links(self, atom_handle: str, **kwargs): + links = self.incoming_set.get(atom_handle) + + if not links: + return [] + + return links + + def get_atom_type(self, handle: str) -> str: + pass + + +class DatabaseAnimals(DatabaseMock): + def __init__(self): + human = _build_node_handle('Concept', 'human') + monkey = _build_node_handle('Concept', 'monkey') + chimp = _build_node_handle('Concept', 'chimp') + snake = _build_node_handle('Concept', 'snake') + earthworm = _build_node_handle('Concept', 'earthworm') + rhino = _build_node_handle('Concept', 'rhino') + triceratops = _build_node_handle('Concept', 'triceratops') + vine = _build_node_handle('Concept', 'vine') + ent = _build_node_handle('Concept', 'ent') + mammal = _build_node_handle('Concept', 'mammal') + animal = _build_node_handle('Concept', 'animal') + reptile = _build_node_handle('Concept', 'reptile') + dinosaur = _build_node_handle('Concept', 'dinosaur') + plant = _build_node_handle('Concept', 'plant') + + self.all_nodes = [ + human, + monkey, + chimp, + snake, + earthworm, + rhino, + triceratops, + vine, + ent, + mammal, + animal, + reptile, + dinosaur, + plant, + ] + + self.all_links = [ + ['Similarity', human, monkey], + ['Similarity', human, chimp], + ['Similarity', chimp, monkey], + ['Similarity', snake, earthworm], + ['Similarity', rhino, triceratops], + ['Similarity', snake, vine], + ['Similarity', human, ent], + ['Inheritance', human, mammal], + ['Inheritance', monkey, mammal], + ['Inheritance', chimp, mammal], + ['Inheritance', mammal, animal], + ['Inheritance', reptile, animal], + ['Inheritance', snake, reptile], + ['Inheritance', dinosaur, reptile], + ['Inheritance', triceratops, dinosaur], + ['Inheritance', earthworm, animal], + ['Inheritance', rhino, mammal], + ['Inheritance', vine, plant], + ['Inheritance', ent, plant], + ['Similarity', monkey, human], + ['Similarity', chimp, human], + ['Similarity', monkey, chimp], + ['Similarity', earthworm, snake], + ['Similarity', triceratops, rhino], + ['Similarity', vine, snake], + ['Similarity', ent, human], + ] + + self.incoming_set = {} + + for link in self.all_links: + self._add_incoming_set(str(link), link[1:]) + + def add_link(self, link_params: Dict[str, Any], toplevel: bool = True) -> Dict[str, Any]: + if link_params in self.all_links: + index = self.all_links.index(link_params) + self.all_links[index] = link_params + else: + self.all_links.append(link_params) diff --git a/tests/unit/test_das.py b/tests/unit/test_das.py index a8e72785..08e983e7 100644 --- a/tests/unit/test_das.py +++ b/tests/unit/test_das.py @@ -6,6 +6,8 @@ from hyperon_das.das import DistributedAtomSpace, LocalQueryEngine, RemoteQueryEngine from hyperon_das.exceptions import InvalidQueryEngine +from .mock import DistributedAtomSpaceMock + class TestDistributedAtomSpace: def test_create_das(self): @@ -28,3 +30,38 @@ def test_create_das(self): assert exc.value.message == 'The possible values are: `local` or `remote`' assert exc.value.details == 'query_engine=snet' + + def test_get_incoming_links(self): + das = DistributedAtomSpaceMock() + links = das.get_incoming_links('', handles_only=True) + assert len(links) == 7 + + links = das.get_incoming_links('') + assert len(links) == 7 + + with mock.patch( + 'hyperon_das.query_engines.RemoteQueryEngine._connect_server', return_value='fake' + ): + das_remote = DistributedAtomSpaceMock('remote', host='test') + + with mock.patch('hyperon_das.client.FunctionsClient.get_incoming_links', return_value=[]): + links = das_remote.get_incoming_links('') + assert len(links) == 7 + + with mock.patch( + 'hyperon_das.client.FunctionsClient.get_incoming_links', return_value=[1, 2, 3, 4] + ): + links = das_remote.get_incoming_links('') + assert links == [1, 2, 3, 4] + + with mock.patch( + 'hyperon_das.client.FunctionsClient.get_incoming_links', + return_value=["['Inheritance', '', '']"], + ): + links = das_remote.get_incoming_links('', handles_only=True) + assert set(links) == { + "['Inheritance', '', '']", + "['Similarity', '', '']", + "['Similarity', '', '']", + "['Inheritance', '', '']", + }