Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#348] supporting set instead of list #349

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions hyperon_das/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, host: str, port: int, name: Optional[str] = None) -> None:
if not host and not port:
das_error(ValueError("'host' and 'port' are mandatory parameters"))
self.name = name if name else f'client_{host}:{port}'
self.url = connect_to_server(host, port)
self.status_code, self.url = connect_to_server(host, port)

def _send_request(self, payload) -> Any:
try:
Expand Down Expand Up @@ -164,9 +164,7 @@ def commit_changes(self, **kwargs) -> Tuple[int, int]:
else:
raise e

def get_incoming_links(
self, atom_handle: str, **kwargs
) -> tuple[int | None, IncomingLinksT | Iterator]:
def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT | Iterator:
payload = {
'action': 'get_incoming_links',
'input': {'atom_handle': atom_handle, 'kwargs': kwargs},
Expand All @@ -175,7 +173,7 @@ def get_incoming_links(
return self._send_request(payload)
except HTTPError as e:
logger().debug(f'Error during `get_incoming_links` request on remote Das: {str(e)}')
return None, []
return []

def create_field_index(
self,
Expand Down
146 changes: 75 additions & 71 deletions hyperon_das/das.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion hyperon_das/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def wrapper(*args, **kwargs):
logger().debug(
f'{retry_count + 1} successful connection attempt at [host={args[1]}]'
)
return response
return status, response
except Exception as e:
raise RetryConnectionError(
message="An error occurs while connecting to the server",
Expand Down
1 change: 0 additions & 1 deletion hyperon_das/link_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class FlatTypeTemplate(LinkFilter):
def __init__(
self, target_types: list[str], link_type: str = WILDCARD, toplevel_only: bool = False
):

self.filter_type = LinkFilterType.FLAT_TYPE_TEMPLATE
self.link_type = link_type
self.target_types = target_types
Expand Down
33 changes: 20 additions & 13 deletions hyperon_das/query_engines/local_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

from hyperon_das_atomdb import WILDCARD, AtomDB
from hyperon_das_atomdb.adapters import InMemoryDB
from hyperon_das_atomdb.database import AtomT, HandleListT, IncomingLinksT, LinkT
from hyperon_das_atomdb.database import (
AtomT,
HandleListT,
HandleSetT,
HandleT,
IncomingLinksT,
LinkT,
)
from hyperon_das_atomdb.exceptions import AtomDoesNotExist

from hyperon_das.cache.cache_controller import CacheController
Expand Down Expand Up @@ -86,9 +93,9 @@ def _get_related_links(
self,
link_type: str,
target_types: list[str] | None = None,
link_targets: list[str] | None = None,
link_targets: HandleListT | None = None,
**kwargs,
) -> HandleListT:
) -> HandleSetT:
if link_type != WILDCARD and target_types is not None:
return self.local_backend.get_matched_type_template(
[link_type, *target_types], **kwargs
Expand All @@ -97,7 +104,7 @@ def _get_related_links(
try:
return self.local_backend.get_matched_links(link_type, link_targets, **kwargs)
except AtomDoesNotExist:
return None, []
return set()
elif link_type != WILDCARD:
return self.local_backend.get_all_links(link_type, **kwargs)
else:
Expand Down Expand Up @@ -145,10 +152,11 @@ def _process_link(self, query: dict) -> List[dict]:

def _generate_target_handles(
self, targets: List[Dict[str, Any]]
) -> list[str | list[str] | list[Any]]: # multiple levels of nested lists due to recursion
targets_hash: list[str | list[str] | list[Any]] = []
) -> list[HandleT | HandleListT | list[Any]]: # multiple levels of nested lists due to
# recursion
targets_hash: list[HandleT | HandleListT | list[Any]] = []
for target in targets:
handle: str | list[str] | None = None
handle: HandleT | HandleListT | None = None
if target["atom_type"] == "node":
handle = self.local_backend.node_handle(target["type"], target["name"])
elif target["atom_type"] == "link":
Expand Down Expand Up @@ -190,7 +198,7 @@ def get_atom(self, handle: str, **kwargs) -> Dict[str, Any]:
def get_atoms(self, handles: str, **kwargs) -> List[Dict[str, Any]]:
return [self.local_backend.get_atom(handle, **kwargs) for handle in handles]

def get_link_handles(self, link_filter: LinkFilter) -> List[str]:
def get_link_handles(self, link_filter: LinkFilter) -> HandleSetT:
if link_filter.filter_type == LinkFilterType.FLAT_TYPE_TEMPLATE:
return self.local_backend.get_matched_type_template(
[link_filter.link_type, *link_filter.target_types],
Expand All @@ -201,10 +209,9 @@ def get_link_handles(self, link_filter: LinkFilter) -> List[str]:
link_filter.link_type, link_filter.targets, toplevel_only=link_filter.toplevel_only
)
elif link_filter.filter_type == LinkFilterType.NAMED_TYPE:
_, answer = self.local_backend.get_all_links(
return self.local_backend.get_all_links(
link_filter.link_type, toplevel_only=link_filter.toplevel_only
)
return answer
else:
das_error(ValueError("Invalid LinkFilterType: {link_filter.filter_type}"))

Expand Down Expand Up @@ -318,13 +325,13 @@ def create_context(
) -> Context: # type: ignore
das_error(NotImplementedError("Contexts are not implemented for non-server local DAS"))

def get_atoms_by_field(self, query: list[OrderedDict[str, str]]) -> List[str]:
def get_atoms_by_field(self, query: list[OrderedDict[str, str]]) -> HandleListT:
return self.local_backend.get_atoms_by_field(query)

def get_atoms_by_text_field(
self, text_value: str, field: Optional[str] = None, text_index_id: Optional[str] = None
) -> List[str]:
) -> HandleListT:
return self.local_backend.get_atoms_by_text_field(text_value, field, text_index_id)

def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> List[str]:
def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> HandleListT:
return self.local_backend.get_node_by_name_starting_with(node_type, startswith)
43 changes: 25 additions & 18 deletions hyperon_das/query_engines/query_engine_protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Union

from hyperon_das_atomdb.database import IncomingLinksT, LinkT
from hyperon_das_atomdb.database import (
AtomT,
HandleListT,
HandleSetT,
HandleT,
IncomingLinksT,
LinkT,
)

from hyperon_das.context import Context
from hyperon_das.link_filters import LinkFilter
Expand All @@ -11,23 +18,23 @@

class QueryEngine(ABC):
@abstractmethod
def get_atom(self, handle: str) -> Dict[str, Any]:
def get_atom(self, handle: HandleT) -> AtomT:
"""
Retrieves an atom from the database using its unique handle.

This method searches the database for an atom with the specified handle. If found, it returns
the atom's data as a dictionary. If no atom with the given handle exists, an exception is thrown.

Args:
handle (str): The unique handle of the atom to retrieve.
handle (HandleT): The unique handle of the atom to retrieve.

Returns:
Dict[str, Any]: A dictionary containing the atom's data.
AtomT: A dictionary containing the atom's data.
"""
...

@abstractmethod
def get_atoms(self, handles: List[str], **kwargs) -> List[Dict[str, Any]]:
def get_atoms(self, handles: HandleListT, **kwargs) -> List[AtomT]:
"""
Retrieves atoms from the database using their unique handles.

Expand All @@ -37,15 +44,15 @@ def get_atoms(self, handles: List[str], **kwargs) -> List[Dict[str, Any]]:
Remote query engines do a single request to remote DAS in order to get all the requested atoms.

Args:
handles (List[str]): Unique handle of the atoms to retrieve.
handles (HandleListT): List of atoms handles to retrieve.

Returns:
List[Dict[str, Any]]: List with requested atoms.
List[AtomT]: List with requested atoms.
"""
...

@abstractmethod
def get_links(link_filter: LinkFilter) -> List[LinkT]:
def get_links(self, link_filter: LinkFilter) -> List[LinkT]:
"""
Retrieves all links that match the passed filtering criteria.

Expand All @@ -58,20 +65,20 @@ def get_links(link_filter: LinkFilter) -> List[LinkT]:
...

@abstractmethod
def get_link_handles(link_filter: LinkFilter) -> List[LinkT]:
def get_link_handles(self, link_filter: LinkFilter) -> HandleSetT:
"""
Retrieve the handle of all links that match the passed filtering criteria.

Args:
link_filter (LinkFilter): Filtering criteria to be used to select links

Returns:
List[str]: A list of link handles
HandleSetT: Link handles
"""
...

@abstractmethod
def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
def get_incoming_links(self, atom_handle: HandleT, **kwargs) -> IncomingLinksT:
"""
Retrieves incoming links for a specified atom handle.

Expand All @@ -81,7 +88,7 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
implementation and the provided keyword arguments.

Args:
atom_handle (str): The unique handle of the atom for which incoming links are to be
atom_handle (HandleT): The unique handle of the atom for which incoming links are to be
retrieved.

Keyword Args:
Expand Down Expand Up @@ -291,7 +298,7 @@ def commit(self, **kwargs) -> None:
...

@abstractmethod
def get_atoms_by_field(self, query: Query) -> List[str]:
def get_atoms_by_field(self, query: Query) -> HandleListT:
"""
Retrieves a list of atom handles based on a specified field query.

Expand All @@ -303,14 +310,14 @@ def get_atoms_by_field(self, query: Query) -> List[str]:
query (Query): The query specifying the field and value(s) to filter atoms by.

Returns:
List[str]: A list of atom handles that match the query criteria.
HandleListT: A list of atom handles that match the query criteria.
"""
...

@abstractmethod
def get_atoms_by_text_field(
self, text_value: str, field: Optional[str] = None, text_index_id: Optional[str] = None
) -> List[str]:
) -> HandleListT:
"""
Retrieves a list of atom handles based on a text field value, with optional field and index ID.

Expand All @@ -326,12 +333,12 @@ def get_atoms_by_text_field(
optimize the search process if provided. Defaults to None.

Returns:
List[str]: A list of atom handles that match the search criteria.
HandleListT: A list of atom handles that match the search criteria.
"""
...

@abstractmethod
def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> List[str]:
def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> HandleListT:
"""
Retrieves a list of node handles where the node name starts with a specified string.

Expand All @@ -343,6 +350,6 @@ def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> Lis
startswith (str): The initial string of the node names to match.

Returns:
List[str]: A list of handles for the nodes that match the search criteria.
HandleListT: A list of handles for the nodes that match the search criteria.
"""
...
32 changes: 19 additions & 13 deletions hyperon_das/query_engines/remote_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional

from hyperon_das_atomdb.database import IncomingLinksT, LinkT
from hyperon_das_atomdb.database import (
AtomT,
HandleListT,
HandleSetT,
HandleT,
IncomingLinksT,
LinkT,
)
from hyperon_das_atomdb.exceptions import AtomDoesNotExist

from hyperon_das.cache.cache_controller import CacheController
Expand Down Expand Up @@ -46,7 +53,7 @@ def __init__(
def mode(self):
return self.__mode

def get_atom(self, handle: str, **kwargs) -> Dict[str, Any]:
def get_atom(self, handle: HandleT, **kwargs) -> AtomT:
atom = self.cache_controller.get_atom(handle)
if atom is None:
try:
Expand All @@ -58,7 +65,7 @@ def get_atom(self, handle: str, **kwargs) -> Dict[str, Any]:
das_error(exception)
return atom

def get_atoms(self, handles: List[str], **kwargs) -> List[Dict[str, Any]]:
def get_atoms(self, handles: HandleListT, **kwargs) -> List[AtomT]:
return self.cache_controller.get_atoms(handles)

def get_links(self, link_filter: LinkFilter) -> List[LinkT]:
Expand All @@ -67,11 +74,11 @@ def get_links(self, link_filter: LinkFilter) -> List[LinkT]:
links.extend(remote_links)
return links

def get_link_handles(self, link_filter: LinkFilter) -> List[str]:
def get_link_handles(self, link_filter: LinkFilter) -> HandleSetT:
# TODO Implement get_link_handles() in faas client
return [link['handle'] for link in self.get_links(link_filter)]
return {link['handle'] for link in self.get_links(link_filter)}

def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
def get_incoming_links(self, atom_handle: HandleT, **kwargs) -> IncomingLinksT:
links = self.local_query_engine.get_incoming_links(atom_handle, **kwargs)
remote_links = self.remote_das.get_incoming_links(atom_handle, **kwargs)
links.extend(remote_links)
Expand All @@ -89,10 +96,9 @@ def custom_query(self, index_id: str, query: Query, **kwargs) -> Iterator:
kwargs.pop('no_iterator', None)
if kwargs.get('cursor') is None:
kwargs['cursor'] = 0
cursor, answer = self.remote_das.custom_query(index_id, query=query, **kwargs)
answer = self.remote_das.custom_query(index_id, query=query, **kwargs)
kwargs['backend'] = self.remote_das
kwargs['index_id'] = index_id
kwargs['cursor'] = cursor
kwargs['is_remote'] = True
return CustomQuery(ListIterator(answer), **kwargs)

Expand Down Expand Up @@ -147,8 +153,8 @@ def count_atoms(self, parameters: Optional[Dict[str, Any]] = None) -> Dict[str,
def commit(self, **kwargs) -> None:
if self.__mode == 'read-write':
if self.local_query_engine.has_buffer():
return self.remote_das.commit_changes(buffer=self.local_query_engine.buffer)
return self.remote_das.commit_changes()
self.remote_das.commit_changes(buffer=self.local_query_engine.buffer)
self.remote_das.commit_changes()
elif self.__mode == 'read-only':
das_error(PermissionError("Commit can't be executed in read mode"))
else:
Expand Down Expand Up @@ -189,13 +195,13 @@ def create_context(
) -> Context:
return self.remote_das.create_context(name, queries)

def get_atoms_by_field(self, query: Query) -> List[str]:
def get_atoms_by_field(self, query: Query) -> HandleListT:
return self.remote_das.get_atoms_by_field(query)

def get_atoms_by_text_field(
self, text_value: str, field: Optional[str] = None, text_index_id: Optional[str] = None
) -> List[str]:
) -> HandleListT:
return self.remote_das.get_atoms_by_text_field(text_value, field, text_index_id)

def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> List[str]:
def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> HandleListT:
return self.remote_das.get_node_by_name_starting_with(node_type, startswith)
2 changes: 1 addition & 1 deletion tests/unit/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def successful_function(self, host, port):

result = successful_function({}, 'localhost', 80)

assert result == 'Success'
assert result == (200, 'Success')


@patch('hyperon_das.logger')
Expand Down
Loading