Skip to content

Commit

Permalink
Merge pull request #226 from singnet/angelo/#223/avoiding-set-to-list-conversions
Browse files Browse the repository at this point in the history

[#223] handling/using/returning `set` instead of `list` when possible
  • Loading branch information
angeloprobst authored Sep 27, 2024
2 parents 264f95b + ff33f90 commit 86252e7
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 274 deletions.
92 changes: 47 additions & 45 deletions hyperon_das_atomdb/adapters/ram_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
FieldIndexType,
FieldNames,
HandleListT,
HandleSetT,
HandleT,
IncomingLinksT,
LinkParamsT,
LinkT,
Expand All @@ -40,10 +42,10 @@ class Database:
atom_type: dict[str, Any] = dc_field(default_factory=dict)
node: dict[str, AtomT] = dc_field(default_factory=dict)
link: dict[str, AtomT] = dc_field(default_factory=dict)
outgoing_set: dict[str, list[str]] = dc_field(default_factory=dict)
incoming_set: dict[str, set[str]] = dc_field(default_factory=dict)
patterns: dict[str, set[str]] = dc_field(default_factory=dict)
templates: dict[str, set[str]] = dc_field(default_factory=dict)
outgoing_set: dict[str, HandleListT] = dc_field(default_factory=dict)
incoming_set: dict[str, HandleSetT] = dc_field(default_factory=dict)
patterns: dict[str, HandleSetT] = dc_field(default_factory=dict)
templates: dict[str, HandleSetT] = dc_field(default_factory=dict)


class InMemoryDB(AtomDB):
Expand All @@ -70,7 +72,7 @@ def __init__(self, database_name: str = "das"):
"""
self.database_name: str = database_name
self.named_type_table: dict[str, str] = {} # keyed by named type hash
self.all_named_types: set[str] = set()
self.all_named_types: HandleSetT = set()
self.db: Database = Database()

def _get_link(self, handle: str) -> dict[str, Any] | None:
Expand Down Expand Up @@ -172,46 +174,46 @@ def _delete_atom_type(self, _name: str) -> None:
self.db.atom_type.pop(key, None)
self.all_named_types.remove(_name)

def _add_outgoing_set(self, key: str, targets_hash: list[str]) -> None:
def _add_outgoing_set(self, key: str, targets_hash: HandleListT) -> None:
"""
Add an outgoing set to the database.
Args:
key (str): The key for the outgoing set.
targets_hash (list[str]): A list of target hashes to be added to the outgoing set.
targets_hash (HandleListT): A list of target hashes to be added to the outgoing set.
"""
self.db.outgoing_set[key] = targets_hash

def _get_and_delete_outgoing_set(self, handle: str) -> list[str] | None:
def _get_and_delete_outgoing_set(self, handle: str) -> HandleListT | None:
"""
Retrieve and delete an outgoing set from the database by its handle.
Args:
handle (str): The handle of the outgoing set to retrieve and delete.
Returns:
list[str] | None: The outgoing set if found and deleted, otherwise None.
HandleListT | None: The outgoing set if found and deleted, otherwise None.
"""
return self.db.outgoing_set.pop(handle, None)

def _add_incoming_set(self, key: str, targets_hash: list[str]) -> None:
def _add_incoming_set(self, key: str, targets_hash: Iterable[HandleT]) -> None:
"""
Add an incoming set to the database.
Args:
key (str): The key for the incoming set.
targets_hash (list[str]): A list of target hashes to be added to the incoming set.
targets_hash (Iterable[HandleT]): Target hashes to be added to the incoming set.
"""
for target_hash in targets_hash:
self.db.incoming_set.setdefault(target_hash, set()).add(key)

def _delete_incoming_set(self, link_handle: str, atoms_handle: list[str]) -> None:
def _delete_incoming_set(self, link_handle: str, atoms_handle: Iterable[HandleT]) -> None:
"""
Delete an incoming set from the database.
Args:
link_handle (str): The handle of the link to delete.
atoms_handle (list[str]): A list of atom handles associated with the link.
atoms_handle (Iterable[HandleT]): Atom handles associated with the link.
"""
for atom_handle in atoms_handle:
if handles := self.db.incoming_set.get(atom_handle):
Expand Down Expand Up @@ -261,14 +263,14 @@ def _delete_templates(self, link_document: dict) -> None:
if len(template_named_type) > 0:
template_named_type.remove(link_document[FieldNames.ID_HASH])

def _add_patterns(self, named_type_hash: str, key: str, targets_hash: list[str]) -> None:
def _add_patterns(self, named_type_hash: str, key: str, targets_hash: HandleListT) -> None:
"""
Add patterns to the database.
Args:
named_type_hash (str): The hash of the named type.
key (str): The key for the pattern.
targets_hash (list[str]): A list of target hashes to be added to the pattern.
targets_hash (HandleListT): A list of target hashes to be added to the pattern.
"""
pattern_keys = build_pattern_keys([named_type_hash, *targets_hash])

Expand All @@ -278,13 +280,13 @@ def _add_patterns(self, named_type_hash: str, key: str, targets_hash: list[str])
set(),
).add(key)

def _delete_patterns(self, link_document: dict, targets_hash: list[str]) -> None:
def _delete_patterns(self, link_document: dict, targets_hash: HandleListT) -> None:
"""
Delete patterns from the database.
Args:
link_document (dict): The document of the link whose patterns are to be deleted.
targets_hash (list[str]): A list of target hashes associated with the link.
targets_hash (HandleListT): A list of target hashes associated with the link.
"""
pattern_keys = build_pattern_keys([link_document[FieldNames.TYPE_NAME_HASH], *targets_hash])
for pattern_key in pattern_keys:
Expand All @@ -301,34 +303,34 @@ def _delete_link_and_update_index(self, link_handle: str) -> None:
if link_document := self._get_and_delete_link(link_handle):
self._update_index(atom=link_document, delete_atom=True)

def _filter_non_toplevel(self, matches: HandleListT) -> HandleListT:
def _filter_non_toplevel(self, matches: HandleSetT) -> HandleSetT:
"""
Filter out non-toplevel matches from the provided list.
Args:
matches (HandleListT): A list of matches
matches (HandleSetT): A set of matches
Returns:
HandleListT: Filtered matches
HandleSetT: Filtered matches
"""
if not self.db.link:
return matches
return [
return {
link_handle
for link_handle in matches
if (link := self.db.link.get(link_handle)) and link.get(FieldNames.IS_TOPLEVEL)
]
}

@staticmethod
def _build_targets_list(link: dict[str, Any]) -> list[Any]:
def _build_targets_list(link: dict[str, Any]) -> HandleListT:
"""
Build a list of target handles from the given link document.
Args:
link (dict[str, Any]): The link document from which to extract target handles.
Returns:
list[Any]: A list of target handles extracted from the link document.
HandleListT: A list of target handles extracted from the link document.
"""
return [
handle
Expand All @@ -341,7 +343,7 @@ def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) ->
Update the indexes for the provided documents.
Args:
documents (Iterable[dict[str, any]]): An iterable of documents to update the indexes for.
documents (Iterable[dict[str, any]]): Documents to update the indexes for.
**kwargs: Additional keyword arguments that may be used for updating the indexes.
"""
for document in documents:
Expand Down Expand Up @@ -453,7 +455,7 @@ def get_node_type(self, node_handle: str) -> str | None:
details=f"node_handle: {node_handle}",
)

def get_node_by_name(self, node_type: str, substring: str) -> list[str]:
def get_node_by_name(self, node_type: str, substring: str) -> HandleListT:
node_type_hash = ExpressionHasher.named_type_hash(node_type)
return [
key
Expand All @@ -478,15 +480,14 @@ def get_all_nodes(self, node_type: str, names: bool = False) -> list[str]:
if node[FieldNames.COMPOSITE_TYPE_HASH] == node_type_hash
]

def get_all_links(self, link_type: str, **kwargs) -> tuple[int | None, list[str]]:
answer = [
def get_all_links(self, link_type: str, **kwargs) -> HandleSetT:
return {
link[FieldNames.ID_HASH]
for _, link in self.db.link.items()
if link[FieldNames.TYPE_NAME] == link_type
]
return kwargs.get("cursor"), answer
}

def get_link_handle(self, link_type: str, target_handles: list[str]) -> str:
def get_link_handle(self, link_type: str, target_handles: HandleListT) -> str:
link_handle = self.link_handle(link_type, target_handles)
if link_handle in self.db.link:
return link_handle
Expand All @@ -509,10 +510,10 @@ def get_link_type(self, link_handle: str) -> str | None:
details=f"link_handle: {link_handle}",
)

def get_link_targets(self, link_handle: str) -> list[str]:
def get_link_targets(self, link_handle: str) -> HandleListT:
answer = self.db.outgoing_set.get(link_handle)
if answer is not None:
return list(answer)
return answer
logger().error(
f"Failed to retrieve link targets for {link_handle}. This link may not exist."
)
Expand All @@ -521,21 +522,22 @@ def get_link_targets(self, link_handle: str) -> list[str]:
details=f"link_handle: {link_handle}",
)

def get_matched_links(self, link_type: str, target_handles: list[str], **kwargs) -> HandleListT:
def get_matched_links(
self, link_type: str, target_handles: HandleListT, **kwargs
) -> HandleSetT:
if link_type != WILDCARD and WILDCARD not in target_handles:
try:
answer = [self.get_link_handle(link_type, target_handles)]
return {self.get_link_handle(link_type, target_handles)}
except AtomDoesNotExist:
answer = []
return answer
return set()

link_type_hash = (
WILDCARD if link_type == WILDCARD else ExpressionHasher.named_type_hash(link_type)
)

pattern_hash = ExpressionHasher.composite_hash([link_type_hash, *target_handles])

patterns_matched = list(pattern) if (pattern := self.db.patterns.get(pattern_hash)) else []
patterns_matched = self.db.patterns.get(pattern_hash, set())

if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(patterns_matched)
Expand All @@ -548,24 +550,24 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:
return list(links)
return [self.get_atom(handle, **kwargs) for handle in links]

def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleListT:
def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
hash_base = self._build_named_type_hash_template(template)
template_hash = ExpressionHasher.composite_hash(hash_base)
templates_matched = list(self.db.templates.get(template_hash, set()))
templates_matched = self.db.templates.get(template_hash, set())
if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_matched_type(self, link_type: str, **kwargs) -> HandleListT:
def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT:
link_type_hash = ExpressionHasher.named_type_hash(link_type)
templates_matched = list(self.db.templates.get(link_type_hash, set()))
templates_matched = self.db.templates.get(link_type_hash, set())
if kwargs.get("toplevel_only", False):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_atoms_by_field(
self, query: list[OrderedDict[str, str]]
) -> list[str]: # pragma: no cover
) -> HandleListT: # pragma: no cover
raise NotImplementedError()

def get_atoms_by_index(
Expand All @@ -582,12 +584,12 @@ def get_atoms_by_text_field(
text_value: str,
field: str | None = None,
text_index_id: str | None = None,
) -> list[str]: # pragma: no cover
) -> HandleListT: # pragma: no cover
raise NotImplementedError()

def get_node_by_name_starting_with(
self, node_type: str, startswith: str
) -> list[str]: # pragma: no cover
) -> HandleListT: # pragma: no cover
raise NotImplementedError()

def _get_atom(self, handle: str) -> AtomT | None:
Expand Down
Loading

0 comments on commit 86252e7

Please sign in to comment.