Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#223] handling/using/returning set instead of list when possible #226

Merged
merged 10 commits into from
Sep 27, 2024
37 changes: 18 additions & 19 deletions hyperon_das_atomdb/adapters/ram_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
AtomT,
FieldIndexType,
FieldNames,
HandleListT,
HandleSetT,
IncomingLinksT,
LinkParamsT,
LinkT,
Expand Down Expand Up @@ -205,13 +205,13 @@ def _add_incoming_set(self, key: str, targets_hash: list[str]) -> None:
for target_hash in targets_hash:
self.db.incoming_set.setdefault(target_hash, set()).add(key)

def _delete_incoming_set(self, link_handle: str, atoms_handle: list[str]) -> None:
def _delete_incoming_set(self, link_handle: str, atoms_handle: Iterable[str]) -> None:
"""
Delete an incoming set from the database.

Args:
link_handle (str): The handle of the link to delete.
atoms_handle (list[str]): A list of atom handles associated with the link.
atoms_handle (Iterable[str]): A Iterable of atom handles associated with the link.
"""
for atom_handle in atoms_handle:
if handles := self.db.incoming_set.get(atom_handle):
Expand Down Expand Up @@ -301,26 +301,26 @@ def _delete_link_and_update_index(self, link_handle: str) -> None:
if link_document := self._get_and_delete_link(link_handle):
self._update_index(atom=link_document, delete_atom=True)

def _filter_non_toplevel(self, matches: HandleListT) -> HandleListT:
def _filter_non_toplevel(self, matches: HandleSetT) -> HandleSetT:
"""
Filter out non-toplevel matches from the provided list.

Args:
matches (HandleListT): A list of matches
matches (HandleSetT): A set of matches

Returns:
HandleListT: Filtered matches
HandleSetT: Filtered matches
"""
if not self.db.link:
return matches
return [
return {
link_handle
for link_handle in matches
if (link := self.db.link.get(link_handle)) and link.get(FieldNames.IS_TOPLEVEL)
]
}

@staticmethod
def _build_targets_list(link: dict[str, Any]) -> list[Any]:
def _build_targets_list(link: dict[str, Any]) -> list[str]:
"""
Build a list of target handles from the given link document.

Expand Down Expand Up @@ -512,7 +512,7 @@ def get_link_type(self, link_handle: str) -> str | None:
def get_link_targets(self, link_handle: str) -> list[str]:
answer = self.db.outgoing_set.get(link_handle)
if answer is not None:
return list(answer)
return answer
logger().error(
f"Failed to retrieve link targets for {link_handle}. This link may not exist."
)
Expand All @@ -521,21 +521,20 @@ def get_link_targets(self, link_handle: str) -> list[str]:
details=f"link_handle: {link_handle}",
)

def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleListT:
def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleSetT:
if link_type != WILDCARD and WILDCARD not in target_handles:
try:
answer = [self.get_link_handle(link_type, target_handles)]
return {self.get_link_handle(link_type, target_handles)}
except AtomDoesNotExist:
answer = []
return answer
return set()

link_type_hash = (
WILDCARD if link_type == WILDCARD else ExpressionHasher.named_type_hash(link_type)
)

pattern_hash = ExpressionHasher.composite_hash([link_type_hash, *target_handles])

patterns_matched = list(pattern) if (pattern := self.db.patterns.get(pattern_hash)) else []
patterns_matched = self.db.patterns.get(pattern_hash, set())

if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(patterns_matched)
Expand All @@ -548,17 +547,17 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
return list(links)
return [self.get_atom(handle, **kwargs) for handle in links]

def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleListT:
def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
hash_base = self._build_named_type_hash_template(template)
template_hash = ExpressionHasher.composite_hash(hash_base)
templates_matched = list(self.db.templates.get(template_hash, set()))
templates_matched = self.db.templates.get(template_hash, set())
if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_matched_type(self, link_type: str, **kwargs) -> HandleListT:
def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT:
link_type_hash = ExpressionHasher.named_type_hash(link_type)
templates_matched = list(self.db.templates.get(link_type_hash, set()))
templates_matched = self.db.templates.get(link_type_hash, set())
if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(templates_matched)
return templates_matched
Expand Down
42 changes: 21 additions & 21 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
AtomT,
FieldIndexType,
FieldNames,
HandleListT,
HandleSetT,
IncomingLinksT,
LinkParamsT,
LinkT,
Expand Down Expand Up @@ -487,7 +487,7 @@ def _get_document_keys(document: dict[str, Any]) -> list[str]:
index += 1
return answer

def _filter_non_toplevel(self, matches: HandleListT) -> HandleListT:
def _filter_non_toplevel(self, matches: HandleSetT) -> HandleSetT:
"""
Filter out non-toplevel links from the given list of matches.

Expand All @@ -496,16 +496,16 @@ def _filter_non_toplevel(self, matches: HandleListT) -> HandleListT:
are included in the returned list.

Args:
matches (HandleListT): A list of link handles to be filtered.
matches (HandleSetT): A set of link handles to be filtered.

Returns:
HandleListT: A list of handles corresponding to toplevel links.
HandleSetT: A set of handles corresponding to toplevel links.
"""
return [
return {
link_handle
for link_handle in matches
if (link := self._retrieve_document(link_handle)) and link.get(FieldNames.IS_TOPLEVEL)
]
}

def get_node_handle(self, node_type: str, node_name: str) -> str:
node_handle = self.node_handle(node_type, node_name)
Expand Down Expand Up @@ -660,13 +660,13 @@ def get_link_targets(self, link_handle: str) -> list[str]:
raise ValueError(f"Invalid handle: {link_handle}")
return answer

def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleListT:
def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleSetT:
if link_type != WILDCARD and WILDCARD not in target_handles:
try:
link_handle = self.get_link_handle(link_type, target_handles)
return [link_handle]
return {link_handle}
except AtomDoesNotExist:
return []
return set()

link_type_hash = (
WILDCARD if link_type == WILDCARD else ExpressionHasher.named_type_hash(link_type)
Expand All @@ -685,11 +685,11 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
links = self._retrieve_incoming_set(atom_handle, **kwargs)

if kwargs.get("handles_only", False):
return links
return list(links)
else:
return [self.get_atom(handle, **kwargs) for handle in links]

def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleListT:
def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
try:
hash_base: list[str] = self._build_named_type_hash_template(template) # type: ignore
template_hash = ExpressionHasher.composite_hash(hash_base)
Expand All @@ -704,7 +704,7 @@ def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleList
logger().error(f"Failed to get matched type template - Details: {str(exception)}")
raise ValueError(str(exception))

def get_matched_type(self, link_type: str, **kwargs) -> HandleListT:
def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT:
named_type_hash = ExpressionHasher.named_type_hash(link_type)
templates_matched = self._retrieve_hash_targets_value(
KeyPrefix.TEMPLATES, named_type_hash, **kwargs
Expand Down Expand Up @@ -864,7 +864,7 @@ def _apply_index_template(
key.append(WILDCARD if cursor in target_selected_pos else targets[cursor])
return _build_redis_key(KeyPrefix.PATTERNS, ExpressionHasher.composite_hash(key))

def _retrieve_incoming_set(self, handle: str, **kwargs) -> HandleListT:
def _retrieve_incoming_set(self, handle: str, **kwargs) -> HandleSetT:
"""
Retrieve the incoming set for the given handle from Redis.

Expand All @@ -876,10 +876,10 @@ def _retrieve_incoming_set(self, handle: str, **kwargs) -> HandleListT:
**kwargs: Additional keyword arguments.

Returns:
HandleListT: List of members for the given key
HandleSetT: Set of members for the given key
"""
key = _build_redis_key(KeyPrefix.INCOMING_SET, handle)
return list(self._get_redis_members(key, **kwargs))
return self._get_redis_members(key, **kwargs)

def _delete_smember_incoming_set(self, handle: str, smember: str) -> None:
"""
Expand Down Expand Up @@ -968,7 +968,7 @@ def _retrieve_name(self, handle: str) -> str | None:
else:
return None

def _retrieve_hash_targets_value(self, key_prefix: str, handle: str, **kwargs) -> HandleListT:
def _retrieve_hash_targets_value(self, key_prefix: str, handle: str, **kwargs) -> HandleSetT:
"""
Retrieve the hash targets value for the given handle from Redis.

Expand All @@ -983,10 +983,10 @@ def _retrieve_hash_targets_value(self, key_prefix: str, handle: str, **kwargs) -
**kwargs: Additional keyword arguments

Returns:
HandleListT: List of members in the hash targets value.
HandleSetT: Set of members in the hash targets value.
"""
key = _build_redis_key(key_prefix, handle)
return list(self._get_redis_members(key, **kwargs))
return self._get_redis_members(key, **kwargs)

def _delete_smember_template(self, handle: str, smember: str) -> None:
"""
Expand Down Expand Up @@ -1046,17 +1046,17 @@ def _retrieve_custom_index(self, index_id: str) -> dict[str, Any] | None:
logger().error(f"Unexpected error retrieving custom index with ID {index_id}: {e}")
raise e

def _get_redis_members(self, key: str, **kwargs) -> HandleListT:
def _get_redis_members(self, key: str, **kwargs) -> HandleSetT:
"""
Retrieve members from a Redis set.

Args:
key (str): The key of the set in Redis.

Returns:
HandleListT: List of members retrieved from Redis.
HandleSetT: Set of members retrieved from Redis.
"""
return list(self.redis.smembers(key)) # type: ignore
return set(self.redis.smembers(key)) # type: ignore

def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) -> None:
"""
Expand Down
14 changes: 8 additions & 6 deletions hyperon_das_atomdb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@

HandleListT: TypeAlias = list[HandleT]

HandleSetT: TypeAlias = set[HandleT]

IncomingLinksT: TypeAlias = HandleListT | list[AtomT]

# pylint: enable=invalid-name
Expand Down Expand Up @@ -537,7 +539,7 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
"""

@abstractmethod
def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleListT:
def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleSetT:
"""
Retrieve links that match a specified link type and target handles.

Expand All @@ -548,11 +550,11 @@ def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs)
purposes.

Returns:
HandleListT: List of matching link handles.
HandleSetT: Set of matching link handles.
"""

@abstractmethod
def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleListT:
def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
"""
Retrieve links that match a specified type template.

Expand All @@ -562,11 +564,11 @@ def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleList
purposes.

Returns:
HandleListT: List of matching link handles.
HandleSetT: Set of matching link handles.
"""

@abstractmethod
def get_matched_type(self, link_type: str, **kwargs) -> HandleListT:
def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT:
"""
Retrieve links that match a specified link type.

Expand All @@ -576,7 +578,7 @@ def get_matched_type(self, link_type: str, **kwargs) -> HandleListT:
purposes.

Returns:
HandleListT: List of matching link handles.
HandleSetT: Set of matching link handles.
"""

def get_atom(self, handle: str, **kwargs) -> AtomT:
Expand Down
Loading