From 19f0ecc923d97e4081f492547d6fc8168af39257 Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Tue, 22 Oct 2024 13:54:02 -0300 Subject: [PATCH 1/6] syncing Atom DB Python with C++ implementation --- Makefile | 3 +- hyperon_das_atomdb/adapters/ram_only.py | 708 +------------- hyperon_das_atomdb/adapters/redis_mongo_db.py | 248 ++--- hyperon_das_atomdb/database.py | 924 +----------------- hyperon_das_atomdb/exceptions.py | 56 +- hyperon_das_atomdb/index.py | 5 +- hyperon_das_atomdb/utils/expression_hasher.py | 12 +- hyperon_das_atomdb/utils/patterns.py | 81 -- pyproject.toml | 2 + tests/helpers.py | 62 ++ .../integration/adapters/test_redis_mongo.py | 382 ++++---- tests/integration/scripts/mongo-down.sh | 5 +- tests/unit/adapters/test_ram_only.py | 907 +++++++---------- tests/unit/adapters/test_ram_only_extra.py | 3 + tests/unit/adapters/test_redis_mongo_db.py | 124 ++- tests/unit/adapters/test_redis_mongo_extra.py | 3 +- tests/unit/test_database_private_methods.py | 97 +- tests/unit/test_database_public_methods.py | 385 +++++--- 18 files changed, 1183 insertions(+), 2824 deletions(-) delete mode 100644 hyperon_das_atomdb/utils/patterns.py create mode 100644 tests/helpers.py diff --git a/Makefile b/Makefile index 85046197..e3438500 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ pylint: mypy: @unbuffer mypy --color-output --config-file mypy.ini ./hyperon_das_atomdb -lint: isort black flake8 +lint: isort black flake8 pylint mypy unit-tests: @py.test -sx -vv ./tests/unit @@ -25,3 +25,4 @@ integration-tests: @py.test -sx -vv ./tests/integration pre-commit: lint unit-tests-coverage unit-tests integration-tests + diff --git a/hyperon_das_atomdb/adapters/ram_only.py b/hyperon_das_atomdb/adapters/ram_only.py index 70841e0e..2ee25202 100644 --- a/hyperon_das_atomdb/adapters/ram_only.py +++ b/hyperon_das_atomdb/adapters/ram_only.py @@ -1,707 +1,3 @@ -""" -This module provides an in-memory implementation of the AtomDB interface using hashtables (dict). +from hyperon_das_atomdb_cpp.adapters import InMemoryDB -The InMemoryDB class offers methods for managing nodes and links, including adding, deleting, -and retrieving them. It also supports indexing and pattern matching for efficient querying. - -Classes: - Database: A dataclass representing the structure of the in-memory database. - InMemoryDB: A concrete implementation of the AtomDB interface using hashtables. -""" - -from collections import OrderedDict -from dataclasses import dataclass -from dataclasses import field as dc_field -from typing import Any, Iterable - -from hyperon_das_atomdb.database import ( - WILDCARD, - AtomDB, - AtomT, - FieldIndexType, - FieldNames, - HandleListT, - HandleSetT, - HandleT, - IncomingLinksT, - LinkParamsT, - LinkT, - NodeParamsT, - NodeT, -) -from hyperon_das_atomdb.exceptions import AtomDoesNotExist -from hyperon_das_atomdb.logger import logger -from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher -from hyperon_das_atomdb.utils.patterns import build_pattern_keys - - -@dataclass -class Database: - """Dataclass representing the structure of the in-memory database""" - - atom_type: dict[str, Any] = dc_field(default_factory=dict) - node: dict[str, AtomT] = dc_field(default_factory=dict) - link: dict[str, AtomT] = dc_field(default_factory=dict) - outgoing_set: dict[str, HandleListT] = dc_field(default_factory=dict) - incoming_set: dict[str, HandleSetT] = dc_field(default_factory=dict) - patterns: dict[str, HandleSetT] = dc_field(default_factory=dict) - templates: dict[str, HandleSetT] = dc_field(default_factory=dict) - - -class InMemoryDB(AtomDB): - """A concrete implementation using hashtable (dict)""" - - def __repr__(self) -> str: - """ - Return a string representation of the InMemoryDB instance. - - This method is intended to provide a human-readable representation of the - InMemoryDB instance, which can be useful for debugging and logging purposes. - - Returns: - str: A string representing the InMemoryDB instance. - """ - return "" # pragma no cover - - def __init__(self, database_name: str = "das"): - """ - Initialize an InMemoryDB instance. - - Args: - database_name (str): The name of the database. Defaults to "das". - """ - self.database_name: str = database_name - self.named_type_table: dict[str, str] = {} # keyed by named type hash - self.all_named_types: HandleSetT = set() - self.db: Database = Database() - - def _get_link(self, handle: str) -> dict[str, Any] | None: - """ - Retrieve a link from the database by its handle. - - Args: - handle (str): The handle of the link to retrieve. - - Returns: - dict[str, Any] | None: The link document if found, otherwise None. - """ - return self.db.link.get(handle, None) - - def _get_and_delete_link(self, link_handle: str) -> dict[str, Any] | None: - """ - Retrieve and delete a link from the database by its handle. - - Args: - link_handle (str): The handle of the link to retrieve and delete. - - Returns: - dict[str, Any] | None: The link document if found and deleted, otherwise None. - """ - return self.db.link.pop(link_handle, None) - - def _build_named_type_hash_template(self, template: str | list[Any]) -> str | list[Any]: - """ - Build a named type hash template from the given template. - - Args: - template (str | list[Any]): The template to build the named type hash from. It can be - either a string or a list of elements. - - Returns: - str | list[Any]: The named type hash if the template is a string, or a list of named - type hashes if the template is a list. - """ - if isinstance(template, str): - return ExpressionHasher.named_type_hash(template) - return [self._build_named_type_hash_template(element) for element in template] - - @staticmethod - def _build_atom_type_key_hash(_name: str) -> str: - """ - Build a hash key for the given atom type name. - - Args: - _name (str): The name of the atom type. - - Returns: - str: The hash key for the atom type. - """ - name_hash = ExpressionHasher.named_type_hash(_name) - type_hash = ExpressionHasher.named_type_hash("Type") - typedef_mark_hash = ExpressionHasher.named_type_hash(":") - return ExpressionHasher.expression_hash(typedef_mark_hash, [name_hash, type_hash]) - - def _add_atom_type(self, atom_type_name: str, atom_type: str = "Type") -> None: - """ - Add a type atom to the database. - - Args: - atom_type_name (str): The name of the atom to add. - atom_type (str): The type of the atom. Defaults to "Type". - """ - if atom_type_name in self.all_named_types: - return - - self.all_named_types.add(atom_type_name) - name_hash = ExpressionHasher.named_type_hash(atom_type_name) - type_hash = ExpressionHasher.named_type_hash(atom_type) - typedef_mark_hash = ExpressionHasher.named_type_hash(":") - - key = ExpressionHasher.expression_hash(typedef_mark_hash, [name_hash, type_hash]) - - _atom_type = self.db.atom_type.get(key) - if _atom_type is None: - base_type_hash = ExpressionHasher.named_type_hash("Type") - composite_type = [typedef_mark_hash, type_hash, base_type_hash] - composite_type_hash = ExpressionHasher.composite_hash(composite_type) - _atom_type = { - FieldNames.ID_HASH: key, - FieldNames.COMPOSITE_TYPE_HASH: composite_type_hash, - FieldNames.TYPE_NAME: atom_type_name, - FieldNames.TYPE_NAME_HASH: name_hash, - } - self.db.atom_type[key] = _atom_type - self.named_type_table[name_hash] = atom_type_name - - def _delete_atom_type(self, _name: str) -> None: - """ - Delete an atom type from the database. - - Args: - _name (str): The name of the atom type to delete. - """ - key = self._build_atom_type_key_hash(_name) - self.db.atom_type.pop(key, None) - self.all_named_types.remove(_name) - - def _add_outgoing_set(self, key: str, targets_hash: HandleListT) -> None: - """ - Add an outgoing set to the database. - - Args: - key (str): The key for the outgoing set. - targets_hash (HandleListT): A list of target hashes to be added to the outgoing set. - """ - self.db.outgoing_set[key] = targets_hash - - def _get_and_delete_outgoing_set(self, handle: str) -> HandleListT | None: - """ - Retrieve and delete an outgoing set from the database by its handle. - - Args: - handle (str): The handle of the outgoing set to retrieve and delete. - - Returns: - HandleListT | None: The outgoing set if found and deleted, otherwise None. - """ - return self.db.outgoing_set.pop(handle, None) - - def _add_incoming_set(self, key: str, targets_hash: Iterable[HandleT]) -> None: - """ - Add an incoming set to the database. - - Args: - key (str): The key for the incoming set. - targets_hash (Iterable[HandleT]): Target hashes to be added to the incoming set. - """ - for target_hash in targets_hash: - self.db.incoming_set.setdefault(target_hash, set()).add(key) - - def _delete_incoming_set(self, link_handle: str, atoms_handle: Iterable[HandleT]) -> None: - """ - Delete an incoming set from the database. - - Args: - link_handle (str): The handle of the link to delete. - atoms_handle (Iterable[HandleT]): Atom handles associated with the link. - """ - for atom_handle in atoms_handle: - if handles := self.db.incoming_set.get(atom_handle): - handles.remove(link_handle) - - def _add_templates( - self, - composite_type_hash: str, - named_type_hash: str, - key: str, - ) -> None: - """ - Add templates to the database. - - Args: - composite_type_hash (str): The hash of the composite type. - named_type_hash (str): The hash of the named type. - key (str): The key for the template. - """ - template_composite_type_hash = self.db.templates.get(composite_type_hash) - template_named_type_hash = self.db.templates.get(named_type_hash) - - if template_composite_type_hash is not None: - template_composite_type_hash.add(key) - else: - self.db.templates[composite_type_hash] = {key} - - if template_named_type_hash is not None: - template_named_type_hash.add(key) - else: - self.db.templates[named_type_hash] = {key} - - def _delete_templates(self, link_document: dict) -> None: - """ - Delete templates from the database. - - Args: - link_document (dict): The document of the link whose templates are to be deleted. - """ - template_composite_type = self.db.templates.get( - link_document[FieldNames.COMPOSITE_TYPE_HASH], set() - ) - if len(template_composite_type) > 0: - template_composite_type.remove(link_document[FieldNames.ID_HASH]) - - template_named_type = self.db.templates.get(link_document[FieldNames.TYPE_NAME_HASH], set()) - if len(template_named_type) > 0: - template_named_type.remove(link_document[FieldNames.ID_HASH]) - - def _add_patterns(self, named_type_hash: str, key: str, targets_hash: HandleListT) -> None: - """ - Add patterns to the database. - - Args: - named_type_hash (str): The hash of the named type. - key (str): The key for the pattern. - targets_hash (HandleListT): A list of target hashes to be added to the pattern. - """ - pattern_keys = build_pattern_keys([named_type_hash, *targets_hash]) - - for pattern_key in pattern_keys: - self.db.patterns.setdefault( - pattern_key, - set(), - ).add(key) - - def _delete_patterns(self, link_document: dict, targets_hash: HandleListT) -> None: - """ - Delete patterns from the database. - - Args: - link_document (dict): The document of the link whose patterns are to be deleted. - targets_hash (HandleListT): A list of target hashes associated with the link. - """ - pattern_keys = build_pattern_keys([link_document[FieldNames.TYPE_NAME_HASH], *targets_hash]) - for pattern_key in pattern_keys: - if pattern := self.db.patterns.get(pattern_key): - pattern.remove(link_document[FieldNames.ID_HASH]) - - def _delete_link_and_update_index(self, link_handle: str) -> None: - """ - Delete a link from the database and update the index. - - Args: - link_handle (str): The handle of the link to delete. - """ - if link_document := self._get_and_delete_link(link_handle): - self._update_index(atom=link_document, delete_atom=True) - - def _filter_non_toplevel(self, matches: HandleSetT) -> HandleSetT: - """ - Filter out non-toplevel matches from the provided list. - - Args: - matches (HandleSetT): A set of matches - - Returns: - HandleSetT: Filtered matches - """ - if not self.db.link: - return matches - return { - link_handle - for link_handle in matches - if (link := self.db.link.get(link_handle)) and link.get(FieldNames.IS_TOPLEVEL) - } - - @staticmethod - def _build_targets_list(link: dict[str, Any]) -> HandleListT: - """ - Build a list of target handles from the given link document. - - Args: - link (dict[str, Any]): The link document from which to extract target handles. - - Returns: - HandleListT: A list of target handles extracted from the link document. - """ - return [ - handle - for count in range(len(link)) - if (handle := link.get(f"key_{count}", None)) is not None - ] - - def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) -> None: - """ - Update the indexes for the provided documents. - - Args: - documents (Iterable[dict[str, any]]): Documents to update the indexes for. - **kwargs: Additional keyword arguments that may be used for updating the indexes. - """ - for document in documents: - self._update_index(document, **kwargs) - - def _delete_atom_index(self, atom: AtomT) -> None: - """ - Delete an atom from the index. - - Args: - atom (AtomT): The atom to delete from the index. - - Raises: - AtomDoesNotExist: If the atom does not exist in the database. - """ - link_handle = atom[FieldNames.ID_HASH] - - handles = self.db.incoming_set.pop(link_handle, None) - - if handles: - for handle in handles: - self._delete_link_and_update_index(handle) - - outgoing_atoms = self._get_and_delete_outgoing_set(link_handle) - - if outgoing_atoms: - self._delete_incoming_set(link_handle, outgoing_atoms) - - targets_hash = self._build_targets_list(atom) - self._delete_templates(atom) - self._delete_patterns(atom, targets_hash) - - def _add_atom_index(self, atom: AtomT) -> None: - """ - Add an atom to the index. - - Args: - atom (AtomT): The atom to add to the index. - - Raises: - AtomDoesNotExist: If the atom does not exist in the database. - """ - atom_type_name = atom[FieldNames.TYPE_NAME] - self._add_atom_type(atom_type_name=atom_type_name) - if FieldNames.NODE_NAME not in atom: - handle = atom[FieldNames.ID_HASH] - targets_hash = self._build_targets_list(atom) - self._add_outgoing_set(handle, targets_hash) - self._add_incoming_set(handle, targets_hash) - self._add_templates( - atom[FieldNames.COMPOSITE_TYPE_HASH], - atom[FieldNames.TYPE_NAME_HASH], - handle, - ) - self._add_patterns(atom[FieldNames.TYPE_NAME_HASH], handle, targets_hash) - - def _update_index(self, atom: AtomT, **kwargs) -> None: - """ - Update the index for the provided atom. - - Args: - atom (AtomT): The atom document to update the index for. - **kwargs: Additional keyword arguments that may be used for updating the index. - - delete_atom (bool): If True, the atom will be deleted from the index. - - Raises: - AtomDoesNotExist: If the atom does not exist when attempting to delete it. - """ - if kwargs.get("delete_atom", False): - self._delete_atom_index(atom) - else: - self._add_atom_index(atom) - - def get_node_handle(self, node_type: str, node_name: str) -> str: - node_handle = self.node_handle(node_type, node_name) - if node_handle in self.db.node: - return node_handle - logger().error( - f"Failed to retrieve node handle for {node_type}:{node_name}. " - "This node may not exist." - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"{node_type}:{node_name}", - ) - - def get_node_name(self, node_handle: str) -> str: - node = self.db.node.get(node_handle) - if node is None: - logger().error( - f"Failed to retrieve node name for handle: {node_handle}. This node may not exist." - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"node_handle: {node_handle}", - ) - return node[FieldNames.NODE_NAME] - - def get_node_type(self, node_handle: str) -> str | None: - node = self.db.node.get(node_handle) - # TODO(angelo): here should we return None if `node` is `None` like redis_mongo_db does? - if node is not None: - return node[FieldNames.TYPE_NAME] - logger().error( - f"Failed to retrieve node type for handle: {node_handle}. This node may not exist." - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"node_handle: {node_handle}", - ) - - def get_node_by_name(self, node_type: str, substring: str) -> HandleListT: - node_type_hash = ExpressionHasher.named_type_hash(node_type) - return [ - key - for key, value in self.db.node.items() - if substring in value[FieldNames.NODE_NAME] - and node_type_hash == value[FieldNames.COMPOSITE_TYPE_HASH] - ] - - def get_all_nodes(self, node_type: str, names: bool = False) -> list[str]: - node_type_hash = ExpressionHasher.named_type_hash(node_type) - - if names: - return [ - node[FieldNames.NODE_NAME] - for node in self.db.node.values() - if node[FieldNames.COMPOSITE_TYPE_HASH] == node_type_hash - ] - - return [ - handle - for handle, node in self.db.node.items() - if node[FieldNames.COMPOSITE_TYPE_HASH] == node_type_hash - ] - - def get_all_links(self, link_type: str, **kwargs) -> HandleSetT: - return { - link[FieldNames.ID_HASH] - for _, link in self.db.link.items() - if link[FieldNames.TYPE_NAME] == link_type - } - - def get_link_handle(self, link_type: str, target_handles: HandleListT) -> str: - link_handle = self.link_handle(link_type, target_handles) - if link_handle in self.db.link: - return link_handle - logger().error( - f"Failed to retrieve link handle for {link_type}:{target_handles}. " - f"This link may not exist." - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"{link_type}:{target_handles}", - ) - - def get_link_type(self, link_handle: str) -> str | None: - link = self._get_link(link_handle) - if link is not None: - return link[FieldNames.TYPE_NAME] - logger().error(f"Failed to retrieve link type for {link_handle}. This link may not exist.") - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"link_handle: {link_handle}", - ) - - def get_link_targets(self, link_handle: str) -> HandleListT: - answer = self.db.outgoing_set.get(link_handle) - if answer is not None: - return answer - logger().error( - f"Failed to retrieve link targets for {link_handle}. This link may not exist." - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"link_handle: {link_handle}", - ) - - def get_matched_links( - self, link_type: str, target_handles: HandleListT, **kwargs - ) -> HandleSetT: - if link_type != WILDCARD and WILDCARD not in target_handles: - try: - return {self.get_link_handle(link_type, target_handles)} - except AtomDoesNotExist: - return set() - - link_type_hash = ( - WILDCARD if link_type == WILDCARD else ExpressionHasher.named_type_hash(link_type) - ) - - pattern_hash = ExpressionHasher.composite_hash([link_type_hash, *target_handles]) - - patterns_matched = self.db.patterns.get(pattern_hash, set()) - - if kwargs.get("toplevel_only", False): - return self._filter_non_toplevel(patterns_matched) - - return patterns_matched - - def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT: - links = self.db.incoming_set.get(atom_handle, set()) - if kwargs.get("handles_only", False): - return list(links) - return [self.get_atom(handle, **kwargs) for handle in links] - - def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT: - hash_base = self._build_named_type_hash_template(template) - template_hash = ExpressionHasher.composite_hash(hash_base) - templates_matched = self.db.templates.get(template_hash, set()) - if kwargs.get("toplevel_only", False): - return self._filter_non_toplevel(templates_matched) - return templates_matched - - def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT: - link_type_hash = ExpressionHasher.named_type_hash(link_type) - templates_matched = self.db.templates.get(link_type_hash, set()) - if kwargs.get("toplevel_only", False): - return self._filter_non_toplevel(templates_matched) - return templates_matched - - def get_atoms_by_field( - self, query: list[OrderedDict[str, str]] - ) -> HandleListT: # pragma: no cover - raise NotImplementedError() - - def get_atoms_by_index( - self, - index_id: str, - query: list[OrderedDict[str, str]], - cursor: int = 0, - chunk_size: int = 500, - ) -> tuple[int, list[AtomT]]: # pragma: no cover - raise NotImplementedError() - - def get_atoms_by_text_field( - self, - text_value: str, - field: str | None = None, - text_index_id: str | None = None, - ) -> HandleListT: # pragma: no cover - raise NotImplementedError() - - def get_node_by_name_starting_with( - self, node_type: str, startswith: str - ) -> HandleListT: # pragma: no cover - raise NotImplementedError() - - def _get_atom(self, handle: str) -> AtomT | None: - return self.db.node.get(handle) or self._get_link(handle) - - def get_atom_type(self, handle: str) -> str | None: - atom = node if (node := self.db.node.get(handle)) else self._get_link(handle) - return atom.get(FieldNames.TYPE_NAME) if atom else None - - def get_atom_as_dict(self, handle: str, arity: int | None = 0) -> dict[str, Any]: - atom = self.db.node.get(handle) - if atom is not None: - return { - "handle": atom[FieldNames.ID_HASH], - "type": atom[FieldNames.TYPE_NAME], - "name": atom[FieldNames.NODE_NAME], - } - atom = self._get_link(handle) - if atom is not None: - return { - "handle": atom[FieldNames.ID_HASH], - "type": atom[FieldNames.TYPE_NAME], - "targets": self._build_targets_list(atom), - } - logger().error(f"Failed to retrieve atom for handle: {handle}. This link may not exist.") - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"handle: {handle}", - ) - - def count_atoms(self, parameters: dict[str, Any] | None = None) -> dict[str, int]: - node_count = len(self.db.node) - link_count = len(self.db.link) - atom_count = node_count + link_count - return { - "atom_count": atom_count, - "node_count": node_count, - "link_count": link_count, - } - - def clear_database(self) -> None: - self.named_type_table = {} - self.all_named_types = set() - self.db = Database() - - def add_node(self, node_params: NodeParamsT) -> NodeT | None: - handle, node = self._build_node(node_params) - self.db.node[handle] = node - self._update_index(node) - return node - - def add_link(self, link_params: LinkParamsT, toplevel: bool = True) -> LinkT | None: - r = self._build_link(link_params, toplevel) - if r is None: - return None - handle, link, _ = r - self.db.link[handle] = link - self._update_index(link) - return link - - def reindex( - self, pattern_index_templates: dict[str, list[dict[str, Any]]] | None = None - ) -> None: # pragma: no cover - raise NotImplementedError() - - def delete_atom(self, handle: str, **kwargs) -> None: - node = self.db.node.pop(handle, None) - - if node: - handles = self.db.incoming_set.pop(handle) - if handles: - for h in handles: - self._delete_link_and_update_index(h) - else: - try: - self._delete_link_and_update_index(handle) - except AtomDoesNotExist as ex: - logger().error( - f"Failed to delete atom for handle: {handle}. " - f"This atom may not exist. - Details: {kwargs}" - ) - ex.details = f"handle: {handle}" - raise ex - - def create_field_index( - self, - atom_type: str, - fields: list[str], - named_type: str | None = None, - composite_type: list[Any] | None = None, - index_type: FieldIndexType | None = None, - ) -> str: # pragma: no cover - raise NotImplementedError() - - def bulk_insert(self, documents: list[AtomT]) -> None: - try: - for document in documents: - handle = document[FieldNames.ID_HASH] - if FieldNames.NODE_NAME in document: - self.db.node[handle] = document - else: - self.db.link[handle] = document - self._update_index(document) - except Exception as e: # pylint: disable=broad-except - logger().error(f"Error bulk inserting documents: {str(e)}") - - def retrieve_all_atoms(self) -> list[AtomT]: - try: - return list(self.db.node.values()) + list(self.db.link.values()) - except Exception as e: - logger().error(f"Error retrieving all atoms: {str(e)}") - raise e - - def commit(self, **kwargs) -> None: # pragma: no cover - raise NotImplementedError() +__all__ = ["InMemoryDB"] diff --git a/hyperon_das_atomdb/adapters/redis_mongo_db.py b/hyperon_das_atomdb/adapters/redis_mongo_db.py index 6d4438fc..928d14c0 100644 --- a/hyperon_das_atomdb/adapters/redis_mongo_db.py +++ b/hyperon_das_atomdb/adapters/redis_mongo_db.py @@ -13,7 +13,7 @@ import sys from copy import deepcopy from enum import Enum -from typing import Any, Iterable, Mapping, Optional, OrderedDict +from typing import Any, Iterable, Mapping, Optional, OrderedDict, TypeAlias from pymongo import ASCENDING, MongoClient from pymongo import errors as pymongo_errors @@ -30,10 +30,7 @@ FieldNames, HandleListT, HandleSetT, - IncomingLinksT, - LinkParamsT, LinkT, - NodeParamsT, NodeT, ) from hyperon_das_atomdb.exceptions import ( @@ -45,6 +42,11 @@ from hyperon_das_atomdb.logger import logger from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher +# pylint: disable=invalid-name +# Type aliases +DocumentT: TypeAlias = dict[str, Any] +# pylint: enable=invalid-name + def _build_redis_key(prefix: str, key: str) -> str: """ @@ -90,7 +92,7 @@ class MongoIndexType(str, Enum): class _HashableDocument: """Class for making documents hashable.""" - def __init__(self, base: dict[str, Any]): + def __init__(self, base: DocumentT): self.base = base def __hash__(self) -> int: @@ -175,6 +177,7 @@ def __repr__(self) -> str: def __init__(self, **kwargs: Optional[dict[str, Any]]) -> None: """Initialize an instance of a custom class with Redis and MongoDB connections.""" + super().__init__() self.database_name = "das" self._setup_databases(**kwargs) @@ -185,23 +188,16 @@ def __init__(self, **kwargs: Optional[dict[str, Any]]) -> None: (MongoCollectionNames.ATOMS, self.mongo_atoms_collection), (MongoCollectionNames.ATOM_TYPES, self.mongo_types_collection), ] - self.pattern_index_templates: dict[str, list[dict[str, Any]]] | None = None + self.pattern_index_templates: dict[str, list[DocumentT]] | None = None self.mongo_das_config_collection: Collection | None = None if MongoCollectionNames.DAS_CONFIG in self.mongo_db.list_collection_names(): self.mongo_das_config_collection = self.mongo_db.get_collection( MongoCollectionNames.DAS_CONFIG ) - # TODO(angelo,andre): remove '_' from `ExpressionHasher._compute_hash` method? - self.wildcard_hash = ExpressionHasher._compute_hash( - WILDCARD - ) # pylint: disable=protected-access - self.typedef_mark_hash = ExpressionHasher._compute_hash( - ":" - ) # pylint: disable=protected-access - self.typedef_base_type_hash = ExpressionHasher._compute_hash( - "Type" - ) # pylint: disable=protected-access + self.wildcard_hash = ExpressionHasher.compute_hash(WILDCARD) + self.typedef_mark_hash = ExpressionHasher.compute_hash(":") + self.typedef_base_type_hash = ExpressionHasher.compute_hash("Type") self.named_type_hash: dict[str, str] = {} self.hash_length = len(self.typedef_base_type_hash) @@ -413,7 +409,7 @@ def _setup_indexes(self) -> None: # NOTE creating index for name search self.create_field_index("node", fields=["name"]) - def _retrieve_document(self, handle: str) -> dict[str, Any] | None: + def _retrieve_document(self, handle: str) -> DocumentT | None: """ Retrieve a document from the MongoDB collection using the given handle. @@ -425,7 +421,7 @@ def _retrieve_document(self, handle: str) -> dict[str, Any] | None: handle (str): The unique identifier for the document to be retrieved. Returns: - dict[str, Any] | None: The retrieved document if found, otherwise None. + DocumentT | None: The retrieved document if found, otherwise None. """ mongo_filter = {FieldNames.ID_HASH: handle} if document := self.mongo_atoms_collection.find_one(mongo_filter): @@ -463,7 +459,7 @@ def _build_named_type_hash_template(self, template: str | list[Any]) -> str | li return [self._build_named_type_hash_template(element) for element in template] @staticmethod - def _get_document_keys(document: dict[str, Any]) -> HandleListT: + def _get_document_keys(document: DocumentT) -> HandleListT: """ Retrieve the keys from the given document. @@ -472,18 +468,20 @@ def _get_document_keys(document: dict[str, Any]) -> HandleListT: a specific prefix pattern. Args: - document (dict[str, Any]): The document from which to retrieve the keys. + document (DocumentT): The document from which to retrieve the keys. Returns: HandleListT: A list of keys extracted from the document. """ - answer: HandleListT | None = document.get(FieldNames.KEYS, None) - if answer is not None: + answer = document.get(FieldNames.TARGETS, document.get(FieldNames.KEYS, None)) + if isinstance(answer, list): return answer + elif isinstance(answer, dict): + return list(answer.values()) answer = [] index = 0 - while (key := document.get(f"{FieldNames.KEY_PREFIX.value}_{index}", None)) is not None: + while (key := document.get(f"{FieldNames.KEY_PREFIX}_{index}", None)) is not None: answer.append(key) index += 1 return answer @@ -505,7 +503,10 @@ def _filter_non_toplevel(self, matches: HandleSetT) -> HandleSetT: return { link_handle for link_handle in matches - if (link := self._retrieve_document(link_handle)) and link.get(FieldNames.IS_TOPLEVEL) + if ( + (link := self._retrieve_document(link_handle)) + and link.get(FieldNames.IS_TOPLEVEL, False) + ) } def get_node_handle(self, node_type: str, node_name: str) -> str: @@ -518,10 +519,7 @@ def get_node_handle(self, node_type: str, node_name: str) -> str: f"Failed to retrieve node handle for {node_type}:{node_name}. " f"This node may not exist." ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"{node_type}:{node_name}", - ) + raise AtomDoesNotExist("Nonexistent atom", f"{node_type}:{node_name}") def get_node_name(self, node_handle: str) -> str: answer = self._retrieve_name(node_handle) @@ -535,7 +533,7 @@ def get_node_name(self, node_handle: str) -> str: def get_node_type(self, node_handle: str) -> str | None: document = self.get_atom(node_handle) - return document[FieldNames.TYPE_NAME] + return document.named_type if isinstance(document, NodeT) else None # type: ignore def get_node_by_name(self, node_type: str, substring: str) -> HandleListT: node_type_hash = ExpressionHasher.named_type_hash(node_type) @@ -558,7 +556,7 @@ def get_atoms_by_field(self, query: list[OrderedDict[str, str]]) -> HandleListT: def get_atoms_by_index( self, index_id: str, - query: list[OrderedDict[str, str]], + query: list[dict[str, Any]], cursor: int = 0, chunk_size: int = 500, ) -> tuple[int, list[AtomT]]: @@ -604,17 +602,17 @@ def get_node_by_name_starting_with(self, node_type: str, startswith: str): for document in self.mongo_atoms_collection.find(mongo_filter) ] - def get_all_nodes(self, node_type: str, names: bool = False) -> list[str]: - if names: - return [ - document[FieldNames.NODE_NAME] - for document in self.mongo_atoms_collection.find({FieldNames.TYPE_NAME: node_type}) - ] - else: - return [ - document[FieldNames.ID_HASH] - for document in self.mongo_atoms_collection.find({FieldNames.TYPE_NAME: node_type}) - ] + def get_all_nodes_handles(self, node_type: str) -> list[str]: + return [ + document[FieldNames.ID_HASH] + for document in self.mongo_atoms_collection.find({FieldNames.TYPE_NAME: node_type}) + ] + + def get_all_nodes_names(self, node_type: str) -> list[str]: + return [ + document[FieldNames.NODE_NAME] + for document in self.mongo_atoms_collection.find({FieldNames.TYPE_NAME: node_type}) + ] def get_all_links(self, link_type: str, **kwargs) -> HandleSetT: pymongo_cursor = self.mongo_atoms_collection.find({FieldNames.TYPE_NAME: link_type}) @@ -630,10 +628,7 @@ def get_link_handle(self, link_type: str, target_handles: HandleListT) -> str: f"Failed to retrieve link handle for {link_type}:{target_handles}. " "This link may not exist." ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"{link_type}:{target_handles}", - ) + raise AtomDoesNotExist("Nonexistent atom", f"{link_type}:{target_handles}") def get_link_targets(self, link_handle: str) -> HandleListT: answer = self._retrieve_outgoing_set(link_handle) @@ -660,28 +655,26 @@ def get_matched_links( ) pattern_hash = ExpressionHasher.composite_hash([link_type_hash, *target_handles]) - patterns_matched = self._retrieve_hash_targets_value( - KeyPrefix.PATTERNS, pattern_hash, **kwargs - ) + patterns_matched = self._retrieve_hash_targets_value(KeyPrefix.PATTERNS, pattern_hash) if kwargs.get("toplevel_only", False): return self._filter_non_toplevel(patterns_matched) else: return patterns_matched - def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT: + def get_incoming_links_handles(self, atom_handle: str, **kwargs) -> HandleListT: links = self._retrieve_incoming_set(atom_handle, **kwargs) + return list(links) - if kwargs.get("handles_only", False): - return list(links) - else: - return [self.get_atom(handle, **kwargs) for handle in links] + def get_incoming_links_atoms(self, atom_handle: str, **kwargs) -> list[AtomT]: + links = self._retrieve_incoming_set(atom_handle, **kwargs) + return [self.get_atom(handle, **kwargs) for handle in links] def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT: try: hash_base: HandleListT = self._build_named_type_hash_template(template) # type: ignore template_hash = ExpressionHasher.composite_hash(hash_base) templates_matched = self._retrieve_hash_targets_value( - KeyPrefix.TEMPLATES, template_hash, **kwargs + KeyPrefix.TEMPLATES, template_hash ) if kwargs.get("toplevel_only", False): return self._filter_non_toplevel(templates_matched) @@ -693,9 +686,7 @@ def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT: named_type_hash = ExpressionHasher.named_type_hash(link_type) - templates_matched = self._retrieve_hash_targets_value( - KeyPrefix.TEMPLATES, named_type_hash, **kwargs - ) + templates_matched = self._retrieve_hash_targets_value(KeyPrefix.TEMPLATES, named_type_hash) if kwargs.get("toplevel_only", False): return self._filter_non_toplevel(templates_matched) else: @@ -703,13 +694,47 @@ def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT: def get_link_type(self, link_handle: str) -> str | None: document = self.get_atom(link_handle) - return document[FieldNames.TYPE_NAME] + return document.named_type if isinstance(document, LinkT) else None # type: ignore + + def _build_atom_from_dict(self, document: DocumentT) -> AtomT: + """ + Builds an Atom object from a dictionary. + + Args: + document (DocumentT): The dictionary representing the atom. + + Returns: + AtomT: The constructed Atom object. + """ + if "targets" in document: + link = LinkT( + handle=document[FieldNames.ID_HASH], + _id=document[FieldNames.ID_HASH], + named_type=document[FieldNames.TYPE_NAME], + targets=document[FieldNames.TARGETS], + composite_type=document[FieldNames.COMPOSITE_TYPE], + is_toplevel=document.get(FieldNames.IS_TOPLEVEL, True), + named_type_hash=document[FieldNames.TYPE_NAME_HASH], + composite_type_hash=document[FieldNames.COMPOSITE_TYPE_HASH], + custom_attributes=document.get(FieldNames.CUSTOM_ATTRIBUTES, dict()), + ) + return link + else: + node = NodeT( + handle=document[FieldNames.ID_HASH], + _id=document[FieldNames.ID_HASH], + named_type=document[FieldNames.TYPE_NAME], + name=document[FieldNames.NODE_NAME], + composite_type_hash=document[FieldNames.COMPOSITE_TYPE_HASH], + custom_attributes=document.get(FieldNames.CUSTOM_ATTRIBUTES, dict()), + ) + return node def _get_atom(self, handle: str) -> AtomT | None: - try: - return self.get_atom_as_dict(handle) - except AtomDoesNotExist: + document = self._retrieve_document(handle) + if not document: return None + return self._build_atom_from_dict(document) def get_atom_type(self, handle: str) -> str | None: atom = self._retrieve_document(handle) @@ -717,22 +742,6 @@ def get_atom_type(self, handle: str) -> str | None: return None return atom[FieldNames.TYPE_NAME] - def get_atom_as_dict(self, handle: str, arity: int | None = 0) -> AtomT: - document = self._retrieve_document(handle) - if document: - document["handle"] = document[FieldNames.ID_HASH] - document["type"] = document[FieldNames.TYPE_NAME] - if "targets" in document: - document["targets"] = document["targets"] - else: - document["name"] = document["name"] - return document - logger().error(f"Failed to retrieve atom for handle: {handle}. This link may not exist.") - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"handle: {handle}", - ) - def count_atoms(self, parameters: dict[str, Any] | None = None) -> dict[str, int]: atom_count = self.mongo_atoms_collection.estimated_document_count() return_count = {"atom_count": atom_count} @@ -792,30 +801,36 @@ def commit(self, **kwargs) -> None: buffer.clear() - def add_node(self, node_params: NodeParamsT) -> NodeT | None: - _, node = self._build_node(node_params) - if sys.getsizeof(node_params["name"]) < self.max_mongo_db_document_size: + def add_node(self, node_params: NodeT) -> NodeT | None: + node: NodeT = self._build_node(node_params) + if sys.getsizeof(node_params.name) < self.max_mongo_db_document_size: _, buffer = self.mongo_bulk_insertion_buffer[MongoCollectionNames.ATOMS] - buffer.add(_HashableDocument(node)) + buffer.add(_HashableDocument(node.to_dict())) if len(buffer) >= self.mongo_bulk_insertion_limit: self.commit() return node else: - logger().warning("Discarding atom whose name is too large: {node_name}") + logger().warning(f"Discarding atom whose name is too large: {node.name}") return None - def add_link(self, link_params: LinkParamsT, toplevel: bool = True) -> LinkT | None: - result = self._build_link(link_params, toplevel) - if result is None: + def _build_link(self, link_params: LinkT, toplevel: bool = True) -> LinkT | None: + # This is necessary because `_build_link` in the parent class (implemented in C++) + # calls back to `add_link`. Without this, `nanobind` is not able to find `add_link` + # implementation in the child class, and raises a `RuntimeError` with the message that + # it is trying to call an abstract method (virtual pure). + return super()._build_link(link_params, toplevel) + + def add_link(self, link_params: LinkT, toplevel: bool = True) -> LinkT | None: + link: LinkT | None = self._build_link(link_params, toplevel) + if link is None: return None - link = result[1] _, buffer = self.mongo_bulk_insertion_buffer[MongoCollectionNames.ATOMS] - buffer.add(_HashableDocument(link)) + buffer.add(_HashableDocument(link.to_dict())) if len(buffer) >= self.mongo_bulk_insertion_limit: self.commit() return link - def _get_and_delete_links_by_handles(self, handles: HandleListT) -> list[dict[str, Any]]: + def _get_and_delete_links_by_handles(self, handles: HandleListT) -> list[DocumentT]: documents = [] for handle in handles: if document := self.mongo_atoms_collection.find_one_and_delete( @@ -866,7 +881,7 @@ def _retrieve_incoming_set(self, handle: str, **kwargs) -> HandleSetT: HandleSetT: Set of members for the given key """ key = _build_redis_key(KeyPrefix.INCOMING_SET, handle) - return self._get_redis_members(key, **kwargs) + return self._get_redis_members(key) def _delete_smember_incoming_set(self, handle: str, smember: str) -> None: """ @@ -955,7 +970,7 @@ def _retrieve_name(self, handle: str) -> str | None: else: return None - def _retrieve_hash_targets_value(self, key_prefix: str, handle: str, **kwargs) -> HandleSetT: + def _retrieve_hash_targets_value(self, key_prefix: str, handle: str) -> HandleSetT: """ Retrieve the hash targets value for the given handle from Redis. @@ -967,13 +982,12 @@ def _retrieve_hash_targets_value(self, key_prefix: str, handle: str, **kwargs) - key_prefix (str): The prefix to be used in the Redis key. handle (str): The unique identifier for the atom whose hash targets value is to be retrieved. - **kwargs: Additional keyword arguments Returns: HandleSetT: Set of members in the hash targets value. """ key = _build_redis_key(key_prefix, handle) - return self._get_redis_members(key, **kwargs) + return self._get_redis_members(key) def _delete_smember_template(self, handle: str, smember: str) -> None: """ @@ -1033,7 +1047,7 @@ def _retrieve_custom_index(self, index_id: str) -> dict[str, Any] | None: logger().error(f"Unexpected error retrieving custom index with ID {index_id}: {e}") raise e - def _get_redis_members(self, key: str, **kwargs) -> HandleSetT: + def _get_redis_members(self, key: str) -> HandleSetT: """ Retrieve members from a Redis set. @@ -1045,7 +1059,7 @@ def _get_redis_members(self, key: str, **kwargs) -> HandleSetT: """ return set(self.redis.smembers(key)) # type: ignore - def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) -> None: + def _update_atom_indexes(self, documents: Iterable[DocumentT], **kwargs) -> None: """ Update the indexes for the given documents in the database. @@ -1054,7 +1068,7 @@ def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) -> it updates the node index. Args: - documents (Iterable[dict[str, any]]): An iterable of documents to be indexed. + documents (Iterable[DocumentT): An iterable of documents to be indexed. **kwargs: Additional keyword arguments for index updates. """ for document in documents: @@ -1063,7 +1077,7 @@ def _update_atom_indexes(self, documents: Iterable[dict[str, Any]], **kwargs) -> else: self._update_node_index(document, **kwargs) - def _update_node_index(self, document: dict[str, Any], **kwargs) -> None: + def _update_node_index(self, document: DocumentT, **kwargs) -> None: """ Update the index for the given node document in the database. @@ -1073,7 +1087,7 @@ def _update_node_index(self, document: dict[str, Any], **kwargs) -> None: links for the node. Args: - document (dict[str, Any]): The node document to be indexed. + document (DocumentT): The node document to be indexed. **kwargs: Additional keyword arguments for index updates. Supports `delete_atom` to indicate whether the node should be deleted from the index. """ @@ -1089,7 +1103,7 @@ def _update_node_index(self, document: dict[str, Any], **kwargs) -> None: else: self.redis.set(key, node_name) - def _update_link_index(self, document: dict[str, Any], **kwargs) -> None: + def _update_link_index(self, document: DocumentT, **kwargs) -> None: """ Update the index for the given link document in the database. @@ -1099,7 +1113,7 @@ def _update_link_index(self, document: dict[str, Any], **kwargs) -> None: links for the link. Args: - document (dict[str, Any]): The link document to be indexed. + document (DocumentT): The link document to be indexed. **kwargs: Additional keyword arguments for index updates. Supports `delete_atom` to indicate whether the link should be deleted from the index. """ @@ -1166,7 +1180,7 @@ def _update_link_index(self, document: dict[str, Any], **kwargs) -> None: self.redis.sadd(key, *incoming_buffer[handle]) @staticmethod - def _is_document_link(document: dict[str, Any]) -> bool: + def _is_document_link(document: DocumentT) -> bool: """ Determine if the given document is a link. @@ -1174,7 +1188,7 @@ def _is_document_link(document: dict[str, Any]) -> bool: indicates that the document is a link. Args: - document (dict[str, Any]): The document to be checked. + document (DocumentT): The document to be checked. Returns: bool: True if the document is a link, False otherwise. @@ -1199,9 +1213,11 @@ def _calculate_composite_type_hash(composite_type: list[Any]) -> str: def calculate_composite_type_hashes(_composite_type: list[Any]) -> HandleListT: return [ - ExpressionHasher.composite_hash(calculate_composite_type_hashes(t)) - if isinstance(t, list) - else ExpressionHasher.named_type_hash(t) + ( + ExpressionHasher.composite_hash(calculate_composite_type_hashes(t)) + if isinstance(t, list) + else ExpressionHasher.named_type_hash(t) + ) for t in _composite_type ] @@ -1210,7 +1226,7 @@ def calculate_composite_type_hashes(_composite_type: list[Any]) -> HandleListT: def _retrieve_documents_by_index( self, collection: Collection, index_id: str, **kwargs - ) -> tuple[int, list[dict[str, Any]]]: + ) -> tuple[int, list[DocumentT]]: """ Retrieve documents from the specified MongoDB collection using the given index. @@ -1226,7 +1242,7 @@ def _retrieve_documents_by_index( - chunk_size (int, optional): The number of documents to retrieve per chunk. Returns: - tuple[int, list[dict[str, Any]]]: A tuple containing the cursor position and a list of + tuple[int, list[DocumentT]]: A tuple containing the cursor position and a list of retrieved documents. Raises: @@ -1268,9 +1284,7 @@ def _retrieve_documents_by_index( else: raise ValueError(f"Index '{index_id}' does not exist in collection '{collection}'") - def reindex( - self, pattern_index_templates: dict[str, list[dict[str, Any]]] | None = None - ) -> None: + def reindex(self, pattern_index_templates: dict[str, list[DocumentT]] | None = None) -> None: if pattern_index_templates is not None: self.pattern_index_templates = deepcopy(pattern_index_templates) self.redis.flushall() @@ -1281,9 +1295,7 @@ def delete_atom(self, handle: str, **kwargs) -> None: mongo_filter: dict[str, str] = {FieldNames.ID_HASH: handle} - document: dict[str, Any] | None = self.mongo_atoms_collection.find_one_and_delete( - mongo_filter - ) + document: DocumentT | None = self.mongo_atoms_collection.find_one_and_delete(mongo_filter) if not document: logger().error( @@ -1386,7 +1398,10 @@ def _get_atoms_by_index(self, index_id: str, **kwargs) -> tuple[int, list[AtomT] def retrieve_all_atoms(self) -> list[AtomT]: try: - return list(self.mongo_atoms_collection.find()) + return [ + self._build_atom_from_dict(document) + for document in self.mongo_atoms_collection.find() + ] except Exception as e: logger().error(f"Error retrieving all atoms: {str(e)}") raise e @@ -1400,7 +1415,7 @@ def bulk_insert(self, documents: list[AtomT]) -> None: Additional keyword arguments can be used to customize the insertion behavior. Args: - documents (list[dict[str, Any]]): A list of documents to be inserted into the collection. + documents (list[AtomT]): A list of atoms to be inserted into the collection. Raises: pymongo.errors.BulkWriteError: If there is an error during the bulk write operation. @@ -1408,8 +1423,9 @@ def bulk_insert(self, documents: list[AtomT]) -> None: """ try: _id = FieldNames.ID_HASH - for document in documents: + docs: list[DocumentT] = [d.to_dict() for d in documents] + for document in docs: self.mongo_atoms_collection.replace_one({_id: document[_id]}, document, upsert=True) - self._update_atom_indexes(documents) + self._update_atom_indexes(docs) except Exception as e: # pylint: disable=broad-except logger().error(f"Error bulk inserting documents: {str(e)}") diff --git a/hyperon_das_atomdb/database.py b/hyperon_das_atomdb/database.py index 8aa6b27f..4fb76208 100644 --- a/hyperon_das_atomdb/database.py +++ b/hyperon_das_atomdb/database.py @@ -1,49 +1,23 @@ -""" -This module defines the abstract base class for Atom databases and provides various -utility methods for managing nodes and links. - -The AtomDB class includes methods for adding, deleting, and retrieving nodes and links, -as well as methods for querying the database by different criteria. It also supports -indexing and pattern matching for efficient querying. - -Classes: - AtomDB: An abstract base class for Atom databases, providing a common interface - for different implementations. - FieldNames: An enumeration of field names used in the database. - FieldIndexType: An enumeration of index types used in the database. - -Constants: - WILDCARD: A constant representing a wildcard character. - -Type Aliases: - IncomingLinksT: A type alias for incoming links. -""" - -import re -from abc import ABC, abstractmethod -from collections import OrderedDict -from enum import Enum -from typing import Any, TypeAlias - -from hyperon_das_atomdb.exceptions import AddLinkException, AddNodeException, AtomDoesNotExist -from hyperon_das_atomdb.logger import logger -from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher - -WILDCARD = "*" +from typing import TypeAlias + +from hyperon_das_atomdb_cpp.constants import ( + TYPE_HASH, + TYPEDEF_MARK_HASH, + WILDCARD, + WILDCARD_HASH, + FieldIndexType, + FieldNames, +) +from hyperon_das_atomdb_cpp.database import AtomDB +from hyperon_das_atomdb_cpp.document_types import Atom, Link, Node # pylint: disable=invalid-name -HandleT: TypeAlias = str - -AtomT: TypeAlias = dict[str, Any] - -NodeT: TypeAlias = AtomT - -NodeParamsT: TypeAlias = NodeT - -LinkT: TypeAlias = AtomT +AtomT: TypeAlias = Atom +NodeT: TypeAlias = Node +LinkT: TypeAlias = Link -LinkParamsT: TypeAlias = LinkT +HandleT: TypeAlias = str HandleListT: TypeAlias = list[HandleT] @@ -54,853 +28,19 @@ # pylint: enable=invalid-name -class FieldNames(str, Enum): - """Enumeration of field names used in the AtomDB.""" - - ID_HASH = "_id" - COMPOSITE_TYPE = "composite_type" - COMPOSITE_TYPE_HASH = "composite_type_hash" - NODE_NAME = "name" - TYPE_NAME = "named_type" - TYPE_NAME_HASH = "named_type_hash" - KEY_PREFIX = "key" - KEYS = "keys" - IS_TOPLEVEL = "is_toplevel" - - -class FieldIndexType(str, Enum): - """Enumeration of index types used in the AtomDB.""" - - BINARY_TREE = "binary_tree" - TOKEN_INVERTED_LIST = "token_inverted_list" - - -class AtomDB(ABC): - """ - Abstract class for Atom databases. - """ - - key_pattern = re.compile(r"key_\d+") - - def __repr__(self) -> str: - """ - Magic method for string representation of the class. - Returns a string representation of the AtomDB class. - """ - return "" # pragma no cover - - @staticmethod - def node_handle(node_type: str, node_name: str) -> str: - """ - Generate a unique handle for a node based on its type and name. - - Args: - node_type (str): The type of the node. - node_name (str): The name of the node. - - Returns: - str: A unique handle for the node. - """ - return ExpressionHasher.terminal_hash(node_type, node_name) - - @staticmethod - def link_handle(link_type: str, target_handles: HandleListT) -> str: - """ - Generate a unique handle for a link based on its type and target handles. - - Args: - link_type (str): The type of the link. - target_handles (HandleListT): A list of target handles for the link. - - Returns: - str: A unique handle for the link. - """ - named_type_hash = ExpressionHasher.named_type_hash(link_type) - return ExpressionHasher.expression_hash(named_type_hash, target_handles) - - def _reformat_document(self, document: AtomT, **kwargs) -> AtomT: - """ - Transform a document to the target format. - - Args: - document (AtomT): The document to transform. - **kwargs: Additional keyword arguments that may be used for transformation. - - targets_document (bool, optional): If True, include the `targets_document` in the - response. Defaults to False. - - deep_representation (bool, optional): If True, include a deep representation of - the targets. Defaults to False. - - Returns: - AtomT: The transformed document in the target format. - """ - answer: AtomT = document - if kwargs.get("targets_document", False): - targets_document = [self.get_atom(target) for target in answer["targets"]] - answer["targets_document"] = targets_document - - if kwargs.get("deep_representation", False): - - def _recursive_targets(targets, **_kwargs): - return [self.get_atom(target, **_kwargs) for target in targets] - - if "targets" in answer: - deep_targets = _recursive_targets(answer["targets"], **kwargs) - answer["targets"] = deep_targets - - return answer - - def _build_node(self, node_params: NodeParamsT) -> tuple[str, NodeT]: - """ - Build a node with the specified parameters. - - Args: - node_params (NodeParamsT): A mapping containing node parameters. - It should have the following keys: - - 'type': The type of the node. - - 'name': The name of the node. - - Returns: - tuple[str, NodeT]: A tuple containing the handle of the node and the node dictionary. - - Raises: - AddNodeException: If the 'type' or 'name' fields are missing in node_params. - """ - reserved_parameters = ["handle", "_id", "composite_type_hash", "named_type"] - - valid_params = { - key: value for key, value in node_params.items() if key not in reserved_parameters - } - - node_type = valid_params.get("type") - node_name = valid_params.get("name") - - if node_type is None or node_name is None: - raise AddNodeException( - message='The "name" and "type" fields must be sent', - details=f"{valid_params=}", - ) - - handle = self.node_handle(node_type, node_name) - - node: NodeT = { - FieldNames.ID_HASH: handle, - "handle": handle, - FieldNames.COMPOSITE_TYPE_HASH: ExpressionHasher.named_type_hash(node_type), - FieldNames.NODE_NAME: node_name, - FieldNames.TYPE_NAME: node_type, - } - - node.update(valid_params) - - return handle, node - - def _build_link( - self, link_params: LinkParamsT, toplevel: bool = True - ) -> tuple[str, LinkT, HandleListT] | None: - """ - Build a link the specified parameters. - - Args: - link_params (LinkParamsT): A mapping containing link parameters. - It should have the following keys: - - 'type': The type of the link. - - 'targets': A list of target elements. - toplevel (bool): A boolean flag to indicate toplevel links, i.e., links which are not - nested inside other links. Defaults to True. - - Returns: - tuple[str, LinkT, HandleListT] | None: A tuple containing the handle of the link, the - link dictionary, and a list of target hashes. Or None if something went wrong. - - Raises: - AddLinkException: If the 'type' or 'targets' fields are missing in - link_params. - """ - reserved_parameters = [ - "handle", - "targets", - "_id", - "composite_type_hash", - "composite_type", - "is_toplevel", - "named_type", - "named_type_hash", - "key_n", - ] - - valid_params = { - key: value - for key, value in link_params.items() - if key not in reserved_parameters and not re.search(AtomDB.key_pattern, key) - } - - targets = link_params.get("targets") - link_type = link_params.get("type") - - if link_type is None or targets is None: - raise AddLinkException( - message='The "type" and "targets" fields must be sent', - details=f"{valid_params=}", - ) - - link_type_hash = ExpressionHasher.named_type_hash(link_type) - target_handles = [] - composite_type = [link_type_hash] - composite_type_hash = [link_type_hash] - - for target in targets: - if not isinstance(target, dict): - raise ValueError("The target must be a dictionary") - if "targets" not in target: - atom = self.add_node(target) - if atom is None: - return None - atom_hash = atom["composite_type_hash"] - composite_type.append(atom_hash) - else: - atom = self.add_link(target, toplevel=False) - if atom is None: - return None - atom_hash = atom["composite_type_hash"] - composite_type.append(atom["composite_type"]) - composite_type_hash.append(atom_hash) - target_handles.append(atom["_id"]) - - handle = ExpressionHasher.expression_hash(link_type_hash, target_handles) - - link: LinkT = { - FieldNames.ID_HASH: handle, - "handle": handle, - "targets": target_handles, - FieldNames.COMPOSITE_TYPE_HASH: ExpressionHasher.composite_hash(composite_type_hash), - FieldNames.IS_TOPLEVEL: toplevel, - FieldNames.COMPOSITE_TYPE: composite_type, - FieldNames.TYPE_NAME: link_type, - FieldNames.TYPE_NAME_HASH: link_type_hash, - } - - for item in range(len(targets)): - link[f"key_{item}"] = target_handles[item] - - link.update(valid_params) - - return handle, link, target_handles - - def node_exists(self, node_type: str, node_name: str) -> bool: - """ - Check if a node with the specified type and name exists in the database. - - Args: - node_type (str): The node type. - node_name (str): The node name. - - Returns: - bool: True if the node exists, False otherwise. - """ - try: - self.get_node_handle(node_type, node_name) - return True - except AtomDoesNotExist: - return False - - def link_exists(self, link_type: str, target_handles: HandleListT) -> bool: - """ - Check if a link with the specified type and targets exists in the database. - - Args: - link_type (str): The link type. - target_handles (HandleListT): A list of link target identifiers. - - Returns: - bool: True if the link exists, False otherwise. - """ - try: - self.get_link_handle(link_type, target_handles) - return True - except AtomDoesNotExist: - return False - - @abstractmethod - def get_node_handle(self, node_type: str, node_name: str) -> str: - """ - Get the handle of the node with the specified type and name. - - Args: - node_type (str): The node type. - node_name (str): The node name. - - Returns: - str: The node handle. - """ - - @abstractmethod - def get_node_name(self, node_handle: str) -> str: - """ - Get the name of the node with the specified handle. - - Args: - node_handle (str): The node handle. - - Returns: - str: The node name. - """ - - @abstractmethod - def get_node_type(self, node_handle: str) -> str | None: - """ - Get the type of the node with the specified handle. - - Args: - node_handle (str): The node handle. - - Returns: - str | None: The node type. Or None if the node does not exist. - """ - - @abstractmethod - def get_node_by_name(self, node_type: str, substring: str) -> HandleListT: - """ - Get the name of a node of the specified type containing the given substring. - - Args: - node_type (str): The node type. - substring (str): The substring to search for in node names. - - Returns: - HandleListT: list of handles of nodes whose names matched the criteria. - """ - - @abstractmethod - def get_atoms_by_field(self, query: list[OrderedDict[str, str]]) -> HandleListT: - """ - Query the database by field and value, the performance is improved if the database already - have indexes created for the fields, check 'create_field_index' to create indexes. - Ordering the fields as the index previously created can improve performance. - - Args: - query (list[dict[str, str]]): list of dicts containing 'field' and 'value' keys - - Returns: - HandleListT: list of node IDs - """ - - @abstractmethod - def get_atoms_by_index( - self, - index_id: str, - query: list[OrderedDict[str, str]], - cursor: int = 0, - chunk_size: int = 500, - ) -> tuple[int, list[AtomT]]: - """ - Queries the database to return all atoms matching a specific index ID, filtering the - results based on the provided query dictionary. This method is useful for efficiently - retrieving atoms that match certain criteria, especially when the database has been - indexed using the `create_field_index` function. - - Args: - index_id (str): The ID of the index to query against. This index should have been - created previously using the `create_field_index` method. - query (list[OrderedDict[str, str]]): A list of ordered dictionaries, each containing - a "field" and "value" key, representing the criteria for filtering atoms. - cursor (int): An optional cursor indicating the starting point within the result set - from which to return atoms. This can be used for pagination or to resume a - previous query. If not provided, the query starts from the beginning. - chunk_size (int): An optional size indicating the maximum number of atom IDs to - return in one response. Useful for controlling response size and managing large - datasets. If not provided, a default value is used. - - Returns: - tuple[int, list[AtomT]]: A tuple containing the cursor position and a list of - retrieved atoms. - - Note: - The `cursor` and `chunk_size` parameters are particularly useful for handling large - datasets by allowing the retrieval of results in manageable chunks rather than all - at once. - """ - - @abstractmethod - def get_atoms_by_text_field( - self, - text_value: str, - field: str | None = None, - text_index_id: str | None = None, - ) -> HandleListT: - """ - Query the database by a text field, use the text_value arg to query using an existing text - index (text_index_id is optional), if a TOKEN_INVERTED_LIST type of index wasn't previously - created the field arg must be provided, or it will raise an Exception. - When 'text_value' and 'field' value are provided, it will default to a regex search, - creating an index to the field can improve the performance. - - Args: - text_value (str): Value to search for, if only this argument is provided it will use - a TOKEN_INVERTED_LIST index in the search - field (str | None): Field to be used to search, if this argument is provided - it will not use TOKEN_INVERTED_LIST in the search - text_index_id (str | None): TOKEN_INVERTED_LIST index id to search for - - - Returns: - HandleListT: list of node IDs ordered by the closest match - """ - - @abstractmethod - def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> HandleListT: - """ - Query the database by node name starting with 'startswith' value, this query is indexed - and the performance is improved by searching only the index that starts with the - requested value. - - Args: - node_type (str): _description_ - startswith (str): _description_ - - Returns: - HandleListT: list of node IDs - """ - - @abstractmethod - def get_all_nodes(self, node_type: str, names: bool = False) -> list[str]: - """ - Get all nodes of a specific type. - - Args: - node_type (str): The node type. - names (bool): If True, return node names instead of handles. Default is False. - - Returns: - list[str]: A list of node handles or names, depending on the value of 'names'. - """ - - @abstractmethod - def get_all_links(self, link_type: str, **kwargs) -> HandleSetT: - """ - Get all links of a specific type. - - Args: - link_type (str): The type of the link. - **kwargs: Additional arguments that may be used for filtering or other purposes. - - Returns: - HandleSetT: Link handles. - """ - - @abstractmethod - def get_link_handle(self, link_type: str, target_handles: HandleListT) -> str: - """ - Get the handle of the link with the specified type and targets. - - Args: - link_type (str): The link type. - target_handles (HandleListT): A list of link target identifiers. - - Returns: - str: The link handle. - """ - - @abstractmethod - def get_link_type(self, link_handle: str) -> str | None: - """ - Get the type of the link with the specified handle. - - Args: - link_handle (str): The link handle. - - Returns: - str | None: The link type. Or None if the link does not exist. - """ - - @abstractmethod - def get_link_targets(self, link_handle: str) -> HandleListT: - """ - Get the target handles of a link specified by its handle. - - Args: - link_handle (str): The link handle. - - Returns: - HandleListT: A list of target identifiers of the link. - """ - - @abstractmethod - def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT: - """ - Retrieve incoming links for a specified atom handle. - - Args: - atom_handle (str): The handle of the atom for which to retrieve incoming links. - **kwargs: Additional arguments that may be used for filtering or other purposes. - - Returns: - IncomingLinksT: List of incoming links. - """ - - @abstractmethod - def get_matched_links( - self, link_type: str, target_handles: HandleListT, **kwargs - ) -> HandleSetT: - """ - Retrieve links that match a specified link type and target handles. - - Args: - link_type (str): The type of the link to match. - target_handles (HandleListT): A list of target handles to match. - **kwargs: Additional arguments that may be used for filtering or other - purposes. - - Returns: - HandleSetT: Set of matching link handles. - """ - - @abstractmethod - def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT: - """ - Retrieve links that match a specified type template. - - Args: - template (list[Any]): A list representing the type template to match. - **kwargs: Additional arguments that may be used for filtering or other - purposes. - - Returns: - HandleSetT: Set of matching link handles. - """ - - @abstractmethod - def get_matched_type(self, link_type: str, **kwargs) -> HandleSetT: - """ - Retrieve links that match a specified link type. - - Args: - link_type (str): The type of the link to match. - **kwargs: Additional arguments that may be used for filtering or other - purposes. - - Returns: - HandleSetT: Set of matching link handles. - """ - - def get_atom(self, handle: str, **kwargs) -> AtomT: - """ - Retrieve an atom by its handle. - - Args: - handle (str): The handle of the atom to retrieve. - **kwargs: Additional arguments that may be used for filtering or other purposes. - - no_target_format (bool, optional): If True, return the document without - transforming it to the target format. Defaults to False. - - targets_document (bool, optional): If True, include the `targets_document` in the - response. Defaults to False. - - deep_representation (bool, optional): If True, include a deep representation of - the targets. Defaults to False. - - Returns: - AtomT: A dictionary representation of the atom, if found. - - Raises: - AtomDoesNotExist: If the atom with the specified handle does not exist. - """ - document = self._get_atom(handle) - if document: - if not kwargs.get("no_target_format", False): - document = self._reformat_document(document, **kwargs) - return document - else: - logger().error( - f"Failed to retrieve atom for handle: {handle}. " - f"This atom does not exist. - Details: {kwargs}" - ) - raise AtomDoesNotExist( - message="Nonexistent atom", - details=f"handle: {handle}", - ) - - @abstractmethod - def _get_atom(self, handle: str) -> AtomT | None: - """ - Retrieve an atom by its handle. - - Args: - handle (str): The handle of the atom to retrieve. - - Returns: - AtomT | None: A dictionary representation of the atom if found, None otherwise. - - Note: - This method is intended for internal use and should not be called directly. - It must be implemented by subclasses to provide a concrete way to retrieve atoms by - their handles. - """ - - @abstractmethod - def get_atom_type(self, handle: str) -> str | None: - """ - Retrieve the atom's type by its handle. - - Args: - handle (str): The handle of the atom to retrieve the type for. - - Returns: - str | None: The type of the atom. Or None if the atom does not exist. - """ - - @abstractmethod - def get_atom_as_dict(self, handle: str, arity: int | None = 0) -> dict[str, Any]: - """ - Get an atom as a dictionary representation. - - Args: - handle (str): The atom handle. - arity (int | None): The arity of the atom. Defaults to 0. - - Returns: - dict[str, Any]: A dictionary representation of the atom. - """ - - @abstractmethod - def count_atoms(self, parameters: dict[str, Any] | None = None) -> dict[str, int]: - """ - Count the total number of atoms in the database. - If the optional parameter 'precise' is set to True returns the node count and link count - (slow), otherwise return the atom_count (fast). - - Args: - parameters (dict[str, Any] | None): An optional dictionary containing the - following key: - 'precise' (bool) If set to True, the count provides an accurate count - but may be slower. If set to False, the count will be an estimate, which is - faster but less precise. Defaults to None. - - Returns: - dict[str, int]: A dictionary containing the following keys: - 'node_count' (int): The count of node atoms - 'link_count' (int): The count of link atoms - 'atom_count' (int): The total count of all atoms - """ - - @abstractmethod - def clear_database(self) -> None: - """Clear the entire database, removing all data.""" - - @abstractmethod - def add_node(self, node_params: NodeParamsT) -> NodeT | None: - """ - Adds a node to the database. - - This method allows you to add a node to the database with the specified node parameters. - A node must have 'type' and 'name' fields in the node_params dictionary. - - Args: - node_params (NodeParamsT): A mapping containing node parameters. It should have the - following keys: - - 'type': The type of the node. - - 'name': The name of the node. - - Returns: - NodeT | None: The information about the added node, including its unique key and - other details. None if for some reason the node was not added. - - Raises: - AddNodeException: If the 'type' or 'name' fields are missing in node_params. - - Note: - This method creates a unique key for the node based on its type and name. If a node - with the same key already exists, it just returns the node. - - Example: - To add a node, use this method like this: - >>> node_params = { - 'type': 'Reactome', - 'name': 'Reactome:R-HSA-164843', - } - >>> db.add_node(node_params) - """ - - @abstractmethod - def add_link(self, link_params: LinkParamsT, toplevel: bool = True) -> LinkT | None: - """ - Adds a link to the database. - - This method allows to add a link to the database with the specified link parameters. - A link must have a 'type' and 'targets' field in the link_params dictionary. - - Args: - link_params (LinkParamsT): A dictionary containing link parameters. - It should have the following keys: - - 'type': The type of the link. - - 'targets': A list of target elements. - toplevel: boolean flag to indicate toplevel links i.e. links which are not nested - inside other links. - - Returns: - LinkT | None: The information about the added link, including its unique key and - other details. Or None if for some reason the link was not added. - - Raises: - AddLinkException: If the 'type' or 'targets' fields are missing in link_params. - - Note: - This method supports recursion when a target element itself contains links. It - calculates a unique key for the link based on its type and targets. If a link with - the same key already exists, it just returns the link. - - Example: - To add a link, use this method like this: - >>> link_params = { - 'type': 'Evaluation', - 'targets': [ - { - 'type': 'Predicate', - 'name': 'Predicate:has_name' - }, - { - 'type': 'Set', - 'targets': [ - { - 'type': 'Reactome', - 'name': 'Reactome:R-HSA-164843', - }, - { - 'type': 'Concept', - 'name': 'Concept:2-LTR circle formation', - }, - ], - }, - ], - } - >>> db.add_link(link_params) - """ - - @abstractmethod - def reindex( - self, pattern_index_templates: dict[str, list[dict[str, Any]]] | None = None - ) -> None: - """ - Reindex inverted pattern index according to passed templates. - - Args: - pattern_index_templates: indexes are specified by atom type in a dict mapping from atom - types to a pattern template: - - { - : - } - - is a list of dicts, each dict specifies a pattern template for: - - { - "named_type": True/False, - "selected_positions": [n1, n2, ...], - } - - Pattern templates are applied to each link entered in the atom space in order to - determine which entries should be created in the inverted pattern index. Entries - in the inverted pattern index are like patterns where the link type and each of - its targets may be replaced by wildcards. For instance, given a similarity link - Similarity(handle1, handle2) it could be used to create any of the following - entries in the inverted pattern index: - - *(handle1, handle2) - Similarity(*, handle2) - Similarity(handle1, *) - Similarity(*, *) - - If we create all possibilities of index entries to all links, the pattern index size - will grow exponentially, so we limit the entries we want to create by each type of - link. This is what a pattern template for a given link type is. For instance if - we apply this pattern template: - - { - "named_type": False - "selected_positions": [0, 1] - } - - to Similarity(handle1, handle2) we'll create only the following entries: - - Similarity(*, handle2) - Similarity(handle1, *) - Similarity(*, *) - - If we apply this pattern template instead: - - { - "named_type": True - "selected_positions": [1] - } - - We'll have: - - *(handle1, handle2) - Similarity(handle1, *) - """ - - @abstractmethod - def delete_atom(self, handle: str, **kwargs) -> None: - """Delete an atom from the database - - Args: - handle (str): Atom handle - - Raises: - AtomDoesNotExist: If the atom does not exist - """ - - @abstractmethod - def create_field_index( - self, - atom_type: str, - fields: list[str], - named_type: str | None = None, - composite_type: list[Any] | None = None, - index_type: FieldIndexType | None = None, - ) -> str: - """ - Create an index for the specified fields in the database. - - Args: - atom_type (str): The type of the atom for which the index is created. - fields (list[str]): A list of fields to be indexed. - named_type (str | None): The named type of the atom. Defaults to None. - composite_type (list[Any] | None): A list representing the composite type of - the atom. Defaults to None. - index_type (FieldIndexType | None): The type of the index to create. Defaults to None. - - Returns: - str: The ID of the created index. - """ - - @abstractmethod - def bulk_insert(self, documents: list[AtomT]) -> None: - """ - Insert multiple documents into the database. - - Args: - documents (list[AtomT]): A list of dictionaries, each representing a document to be - inserted into the database. - """ - - @abstractmethod - def retrieve_all_atoms(self) -> list[AtomT]: - """ - Retrieve all atoms from the database. - - Returns: - list[AtomT]: A list of dictionaries representing the atoms, or a list of tuples - containing atom handles and their associated data. - """ - - @abstractmethod - def commit(self, **kwargs) -> None: - """Commit the current state of the database. - - This method is intended to be implemented by subclasses to handle the commit operation, - which may involve persisting changes to a storage backend or performing other necessary - actions to finalize the current state of the database. - Updates of atoms aren't allowed on the same transaction. - - Args: - **kwargs: Additional keyword arguments that may be used by the implementation of the - commit operation. - """ +__all__ = [ + "FieldNames", + "FieldIndexType", + "AtomDB", + "WILDCARD", + "WILDCARD_HASH", + "TYPE_HASH", + "TYPEDEF_MARK_HASH", + "AtomT", + "NodeT", + "LinkT", + "HandleT", + "HandleListT", + "HandleSetT", + "IncomingLinksT", +] diff --git a/hyperon_das_atomdb/exceptions.py b/hyperon_das_atomdb/exceptions.py index 90cd0fcf..a136568a 100644 --- a/hyperon_das_atomdb/exceptions.py +++ b/hyperon_das_atomdb/exceptions.py @@ -1,45 +1,27 @@ """Custom exceptions for Atom DB""" - -class AtomDbBaseException(Exception): - """ - Base class for Atom DB exceptions - """ - - def __init__(self, message: str, details: str = ""): - self.message = message - self.details = details - - super().__init__(self.message, self.details) +from hyperon_das_atomdb_cpp.exceptions import ( + AddLinkException, + AddNodeException, + AtomDbBaseException, + AtomDoesNotExist, + InvalidAtomDB, + InvalidOperationException, + RetryException, +) class ConnectionMongoDBException(AtomDbBaseException): """Exception raised for errors in the connection to MongoDB.""" -class AtomDoesNotExist(AtomDbBaseException): - """Exception raised when an atom does not exist.""" - - -class AddNodeException(AtomDbBaseException): - """Exception raised when adding a node fails.""" - - -class AddLinkException(AtomDbBaseException): - """Exception raised when adding a link fails.""" - - -class InvalidOperationException(AtomDbBaseException): - """Exception raised for invalid operations.""" - - -class RetryException(AtomDbBaseException): - """Exception raised for retryable errors.""" - - -class InvalidAtomDB(AtomDbBaseException): - """Exception raised for invalid Atom DB operations.""" - - -class InvalidSQL(AtomDbBaseException): - """Exception raised for invalid SQL operations.""" +__all__ = [ + "ConnectionMongoDBException", + "AtomDbBaseException", + "AtomDoesNotExist", + "AddNodeException", + "AddLinkException", + "InvalidOperationException", + "RetryException", + "InvalidAtomDB", +] diff --git a/hyperon_das_atomdb/index.py b/hyperon_das_atomdb/index.py index 5a7ff0f4..b88a96c5 100644 --- a/hyperon_das_atomdb/index.py +++ b/hyperon_das_atomdb/index.py @@ -20,10 +20,7 @@ def generate_index_id(field: str, conditionals: dict[str, Any]) -> str: Returns: str: The index ID. """ - # TODO(angelo,andre): remove '_' from `ExpressionHasher._compute_hash` method? - return ExpressionHasher._compute_hash( # pylint: disable=protected-access - f"{field}{conditionals}" - ) + return ExpressionHasher.compute_hash(f"{field}{conditionals}") @abstractmethod def create( diff --git a/hyperon_das_atomdb/utils/expression_hasher.py b/hyperon_das_atomdb/utils/expression_hasher.py index c6ed3537..8e885668 100644 --- a/hyperon_das_atomdb/utils/expression_hasher.py +++ b/hyperon_das_atomdb/utils/expression_hasher.py @@ -15,9 +15,9 @@ class ExpressionHasher: compound_separator = " " @staticmethod - def _compute_hash( + def compute_hash( text: str, - ) -> str: # TODO(angelo,andre): remove '_' to make method public? + ) -> str: """ Compute the MD5 hash of the given text. @@ -47,7 +47,7 @@ def named_type_hash(name: str) -> str: Returns: str: The MD5 hash of the named type as a hexadecimal string. """ - return ExpressionHasher._compute_hash(name) + return ExpressionHasher.compute_hash(name) @staticmethod def terminal_hash(named_type: str, terminal_name: str) -> str: @@ -65,7 +65,7 @@ def terminal_hash(named_type: str, terminal_name: str) -> str: Returns: str: The MD5 hash of the terminal expression as a hexadecimal string. """ - return ExpressionHasher._compute_hash( + return ExpressionHasher.compute_hash( ExpressionHasher.compound_separator.join([named_type, terminal_name]) ) @@ -109,7 +109,7 @@ def composite_hash(hash_base: str | list[str]) -> str: if len(hash_base) == 1: return hash_base[0] else: - return ExpressionHasher._compute_hash( + return ExpressionHasher.compute_hash( ExpressionHasher.compound_separator.join(hash_base) ) # TODO unreachable @@ -123,7 +123,7 @@ class StringExpressionHasher: # TODO(angelo,andre): remove this class? it's not """Utility class for generating string representations of expression hashes.""" @staticmethod - def _compute_hash(text: str) -> str: + def compute_hash(text: str) -> str: """Compute the MD5 hash of the given text.""" return str() # TODO(angelo,andre): this seems right? diff --git a/hyperon_das_atomdb/utils/patterns.py b/hyperon_das_atomdb/utils/patterns.py deleted file mode 100644 index 40191430..00000000 --- a/hyperon_das_atomdb/utils/patterns.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -This module provides utility functions for generating binary matrices and manipulating them. - -It includes functions to generate binary matrices of a given size, multiply binary matrices -by string matrices, and build pattern keys using a list of hashes. These utilities are useful -for various operations involving binary and string data manipulation. -""" - -from hyperon_das_atomdb.database import WILDCARD -from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher - -# TODO(angelo,andre): delete this commented function? -# def generate_binary_matrix(numbers: int) -> list: -# """This function is more efficient if numbers are greater than 5""" -# return list(itertools.product([0, 1], repeat=numbers)) - - -def generate_binary_matrix(numbers: int) -> list[list[int]]: - """ - Generate a binary matrix of the given size. - - Args: - numbers (int): The size of the binary matrix to generate. If numbers - is less than or equal to 0, returns a matrix with an empty list. - - Returns: - list[list[int]]: A binary matrix represented as a list of lists, where - each sublist is a row in the matrix. - """ - if numbers <= 0: - return [[]] - smaller_matrix = generate_binary_matrix(numbers - 1) - new_matrix: list[list[int]] = [] - for matrix in smaller_matrix: - new_matrix.append(matrix + [0]) - new_matrix.append(matrix + [1]) - return new_matrix - - -def multiply_binary_matrix_by_string_matrix( - binary_matrix: list[list[int]], string_matrix: list[str] -) -> list[list[str]]: - """ - Multiply a binary matrix by a string matrix. - - Args: - binary_matrix (list[list[int]]): A binary matrix represented as a list - of lists, where each sublist is a row in the matrix. - string_matrix (list[str]): A list of strings to multiply with the - binary matrix. - - Returns: - list[list[str]]: A matrix represented as a list of lists, where each - sublist is a row in the resulting matrix. - """ - result_matrix: list[list[str]] = [] - for binary_row in binary_matrix: - result_row = [ - string if bit == 1 else WILDCARD for bit, string in zip(binary_row, string_matrix) - ] - result_matrix.append(result_row) - return result_matrix[:-1] - - -def build_pattern_keys(hash_list: list[str]) -> list[str]: - """ - Build pattern keys using a list of hashes. - - Args: - hash_list (list[str]): A list of hash strings to build pattern keys from. - - Returns: - list[str]: A list of pattern keys generated from the hash list. - """ - binary_matrix = generate_binary_matrix(len(hash_list)) - result_matrix = multiply_binary_matrix_by_string_matrix(binary_matrix, hash_list) - keys = [ - ExpressionHasher.expression_hash(matrix_item[:1][0], matrix_item[1:]) - for matrix_item in result_matrix - ] - return keys diff --git a/pyproject.toml b/pyproject.toml index 5ab046ae..3cfa6656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ pymongo = "^4.5.0" python-dotenv = "^1.0.0" mongomock = "^4.1.2" setuptools = "^70.2.0" +hyperon-das-atomdb-cpp = "0.0.1" [tool.poetry.group.dev.dependencies] flake8 = "^6.1.0" @@ -29,6 +30,7 @@ pytest = "^7.4.2" pytest-cov = "^4.1.0" flake8-pyproject = "^1.2.3" pre-commit = "^3.5.0" +hyperon-das-atomdb-cpp = "0.0.1" [build-system] requires = ["poetry-core"] diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..ed5439f4 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,62 @@ +from copy import deepcopy +from typing import TypeAlias + +from hyperon_das_atomdb.database import AtomDB, LinkT, NodeT + +CustomAttributesT: TypeAlias = dict[str, str | int | float | bool] + + +def check_handle(handle): + return all((isinstance(handle, str), len(handle) == 32, int(handle, 16))) + + +def add_node( + db: AtomDB, + node_name: str, + node_type: str, + adapter: str, + custom_attributes: CustomAttributesT = {}, +) -> NodeT: + node_params = NodeT(type=node_type, name=node_name, custom_attributes=custom_attributes) + node = db.add_node(node_params) + if adapter != "in_memory_db": + db.commit() + return node + + +def add_link( + db: AtomDB, + link_type: str, + targets: list[NodeT | LinkT], + adapter: str, + is_toplevel: bool = True, + custom_attributes: CustomAttributesT = {}, +) -> LinkT: + link = db.add_link( + LinkT( + type=link_type, + targets=targets, + custom_attributes=custom_attributes, + ), + toplevel=is_toplevel, + ) + if adapter != "in_memory_db": + db.commit() + return link + + +def dict_to_node_params(node_dict: dict) -> NodeT: + return NodeT(**node_dict) + + +def dict_to_link_params(link_dict: dict) -> LinkT: + targets = [ + dict_to_link_params(target) if "targets" in target else dict_to_node_params(target) + for target in link_dict["targets"] + ] + params = deepcopy(link_dict) + params.update({"targets": targets}) + try: + return LinkT(**params) + except TypeError as ex: + raise AssertionError(f"{type(ex)}: {ex} - {params=}") diff --git a/tests/integration/adapters/test_redis_mongo.py b/tests/integration/adapters/test_redis_mongo.py index 3d79b663..3f986b1b 100644 --- a/tests/integration/adapters/test_redis_mongo.py +++ b/tests/integration/adapters/test_redis_mongo.py @@ -2,8 +2,9 @@ from hyperon_das_atomdb.adapters import RedisMongoDB from hyperon_das_atomdb.adapters.redis_mongo_db import KeyPrefix -from hyperon_das_atomdb.database import WILDCARD, AtomDB, FieldIndexType +from hyperon_das_atomdb.database import WILDCARD, AtomDB, FieldIndexType, LinkT, NodeT from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher +from tests.helpers import dict_to_link_params, dict_to_node_params from .animals_kb import ( animal, @@ -34,11 +35,11 @@ def _db(self): def _add_atoms(self, db: RedisMongoDB): for node in node_docs.values(): - db.add_node(node) + db.add_node(dict_to_node_params(node)) for link in inheritance_docs.values(): - db.add_link(link) + db.add_link(dict_to_link_params(link)) for link in similarity_docs.values(): - db.add_link(link) + db.add_link(dict_to_link_params(link)) def _connect_db(self): db = RedisMongoDB( @@ -57,7 +58,7 @@ def _check_basic_patterns(self, db, toplevel_only=False): [WILDCARD, db.node_handle("Concept", "mammal")], toplevel_only=toplevel_only, ) - assert sorted([db.get_atom(answer)["targets"][0] for answer in answers]) == sorted( + assert sorted([db.get_atom(answer).targets[0] for answer in answers]) == sorted( [human, monkey, chimp, rhino] ) answers = db.get_matched_links( @@ -65,13 +66,13 @@ def _check_basic_patterns(self, db, toplevel_only=False): [db.node_handle("Concept", "mammal"), WILDCARD], toplevel_only=toplevel_only, ) - assert sorted([db.get_atom(answer)["targets"][1] for answer in answers]) == sorted([animal]) + assert sorted([db.get_atom(answer).targets[1] for answer in answers]) == sorted([animal]) answers = db.get_matched_links( "Similarity", [WILDCARD, db.node_handle("Concept", "human")], toplevel_only=toplevel_only, ) - assert sorted([db.get_atom(answer)["targets"][0] for answer in answers]) == sorted( + assert sorted([db.get_atom(answer).targets[0] for answer in answers]) == sorted( [monkey, chimp, ent] ) answers = db.get_matched_links( @@ -79,7 +80,7 @@ def _check_basic_patterns(self, db, toplevel_only=False): [db.node_handle("Concept", "human"), WILDCARD], toplevel_only=toplevel_only, ) - assert sorted([db.get_atom(answer)["targets"][1] for answer in answers]) == sorted( + assert sorted([db.get_atom(answer).targets[1] for answer in answers]) == sorted( [monkey, chimp, ent] ) @@ -157,25 +158,28 @@ def test_commit(self, _cleanup, _db: RedisMongoDB): answers = db.get_matched_links( "Inheritance", [WILDCARD, db.node_handle("Concept", "mammal")] ) - assert sorted([db.get_atom(answer)["targets"][0] for answer in answers]) == sorted( + assert sorted([db.get_atom(answer).targets[0] for answer in answers]) == sorted( [human, monkey, chimp, rhino] ) - assert db.get_atom(human)["name"] == node_docs[human]["name"] + assert db.get_atom(human).name == node_docs[human]["name"] link_pre = db.get_atom(inheritance[human][mammal]) - assert "strength" not in link_pre - assert link_pre["named_type"] == "Inheritance" - assert link_pre["targets"] == [human, mammal] + assert link_pre.custom_attributes == dict() + assert link_pre.named_type == "Inheritance" + assert link_pre.targets == [human, mammal] link_new = inheritance_docs[inheritance[human][mammal]].copy() - link_new["strength"] = 1.0 - db.add_link(link_new) + custom_attributes = {"strength": 1.0} + link_new["custom_attributes"] = custom_attributes + db.add_link(dict_to_link_params(link_new)) db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.commit() assert db.count_atoms({"precise": True}) == { @@ -184,19 +188,21 @@ def test_commit(self, _cleanup, _db: RedisMongoDB): "link_count": 27, } link_pos = db.get_atom(inheritance[human][mammal]) - assert link_pos["named_type"] == "Inheritance" - assert link_pos["targets"] == [human, mammal] - assert "strength" in link_pos - assert link_pos["strength"] == 1.0 + assert link_pos.named_type == "Inheritance" + assert link_pos.targets == [human, mammal] + assert isinstance(link_pos.custom_attributes, dict) + assert "strength" in link_pos.custom_attributes + assert isinstance(link_pos.custom_attributes["strength"], float) + assert link_pos.custom_attributes["strength"] == 1.0 dog = db.node_handle("Concept", "dog") assert db.get_node_name(dog) == "dog" new_link_handle = db.get_link_handle("Inheritance", [dog, mammal]) new_link = db.get_atom(new_link_handle) - assert db.get_link_targets(new_link_handle) == new_link["targets"] + assert db.get_link_targets(new_link_handle) == new_link.targets answers = db.get_matched_links( "Inheritance", [WILDCARD, db.node_handle("Concept", "mammal")] ) - assert sorted([db.get_atom(answer)["targets"][0] for answer in answers]) == sorted( + assert sorted([db.get_atom(answer).targets[0] for answer in answers]) == sorted( [human, monkey, chimp, rhino, dog] ) @@ -212,46 +218,52 @@ def test_reindex(self, _cleanup, _db: RedisMongoDB): def test_delete_atom(self, _cleanup, _db: RedisMongoDB): def _add_all_links(): db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.commit() def _add_nested_links(): db.add_link( - { - "type": "Inheritance", - "targets": [ - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - }, - ], - }, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + }, + ], + }, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.commit() @@ -297,9 +309,9 @@ def _check_asserts(): for template in db.default_pattern_index_templates: key = db._apply_index_template( template, - link["named_type_hash"], - link["targets"], - len(link["targets"]), + link.named_type_hash, + link.targets, + len(link.targets), ) keys.add(key) assert set([p for p in db.redis.keys("patterns:*")]) == keys @@ -711,13 +723,15 @@ def _check_asserts_5(): _check_asserts_4() db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.commit() @@ -789,18 +803,21 @@ def test_get_matched_with_pagination(self, _cleanup, _db: RedisMongoDB): ) def test_create_field_index(self, _cleanup, _db: RedisMongoDB): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") db = _db self._add_atoms(db) db.commit() db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.commit() @@ -825,10 +842,10 @@ def test_create_field_index(self, _cleanup, _db: RedisMongoDB): with PyMongoFindExplain(db.mongo_atoms_collection) as explain: _, doc = db.get_atoms_by_index(my_index, [{"field": "tag", "value": "DAS"}]) - assert doc[0]["handle"] == ExpressionHasher.expression_hash( + assert doc[0].handle == ExpressionHasher.expression_hash( ExpressionHasher.named_type_hash("Similarity"), [human, monkey] ) - assert doc[0]["targets"] == [human, monkey] + assert doc[0].targets == [human, monkey] assert explain[0]["executionStats"]["executionSuccess"] assert explain[0]["executionStats"]["executionStages"]["docsExamined"] == 1 assert explain[0]["executionStats"]["executionStages"]["stage"] == "FETCH" @@ -847,14 +864,16 @@ def test_create_text_index(self, _cleanup, _db: RedisMongoDB): db: RedisMongoDB = _db self._add_atoms(db) db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.commit() @@ -876,14 +895,16 @@ def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): db: RedisMongoDB = _db self._add_atoms(db) db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.commit() collection = db.mongo_atoms_collection @@ -898,17 +919,20 @@ def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): assert my_index in collection_index_names def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") db: RedisMongoDB = _db self._add_atoms(db) db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.commit() @@ -920,17 +944,20 @@ def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): assert explain[0]["executionStats"]["totalKeysExamined"] == 0 def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") db: RedisMongoDB = _db self._add_atoms(db) db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.commit() my_index = db.create_field_index(atom_type="link", fields=["tag"]) @@ -953,26 +980,31 @@ def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): ) def test_get_atoms_by_index(self, _cleanup, _db: RedisMongoDB): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") db: RedisMongoDB = _db db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS"}, + } + ) ) db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "mammal"}, - {"type": "Concept", "name": "monkey"}, - ], - "tag": "DAS2", - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "mammal"}, + {"type": "Concept", "name": "monkey"}, + ], + "custom_attributes": {"tag": "DAS2"}, + } + ) ) db.commit() @@ -980,10 +1012,10 @@ def test_get_atoms_by_index(self, _cleanup, _db: RedisMongoDB): with PyMongoFindExplain(db.mongo_atoms_collection) as explain: _, doc = db.get_atoms_by_index(my_index, [{"field": "tag", "value": "DAS2"}]) - assert doc[0]["handle"] == ExpressionHasher.expression_hash( + assert doc[0].handle == ExpressionHasher.expression_hash( ExpressionHasher.named_type_hash("Similarity"), [mammal, monkey] ) - assert doc[0]["targets"] == [mammal, monkey] + assert doc[0].targets == [mammal, monkey] assert explain[0]["executionStats"]["executionSuccess"] assert explain[0]["executionStats"]["nReturned"] == 1 assert explain[0]["executionStats"]["executionStages"]["stage"] == "FETCH" @@ -1053,29 +1085,34 @@ def test_bulk_insert(self, _cleanup, _db: RedisMongoDB): assert db.count_atoms() == {"atom_count": 0} documents = [ - { - "_id": "node1", - "composite_type_hash": "ConceptHash", - "name": "human", - "named_type": "Concept", - }, - { - "_id": "node2", - "composite_type_hash": "ConceptHash", - "name": "monkey", - "named_type": "Concept", - }, - { - "_id": db.link_handle("Similarity", ["node1", "node2"]), - "composite_type_hash": "CompositeTypeHash", - "is_toplevel": True, - "composite_type": ["SimilarityHash", "ConceptHash", "ConceptHash"], - "named_type": "Similarity", - "named_type_hash": "SimilarityHash", - "key_0": "node1", - "key_1": "node2", - }, + NodeT( + _id="node1", + handle="node1", + composite_type_hash="ConceptHash", + name="human", + named_type="Concept", + ), + NodeT( + _id="node2", + handle="node2", + composite_type_hash="ConceptHash", + name="monkey", + named_type="Concept", + ), ] + handle = db.link_handle("Similarity", ["node1", "node2"]) + documents.append( + LinkT( + _id=handle, + handle=handle, + composite_type_hash="CompositeTypeHash", + is_toplevel=True, + composite_type=["SimilarityHash", "ConceptHash", "ConceptHash"], + named_type="Similarity", + named_type_hash="SimilarityHash", + targets=["node1", "node2"], + ), + ) db.bulk_insert(documents) @@ -1085,7 +1122,7 @@ def test_bulk_insert(self, _cleanup, _db: RedisMongoDB): } similarity = db.get_all_links("Similarity") assert similarity == {db.link_handle("Similarity", ["node1", "node2"])} - assert db.get_all_nodes("Concept") == ["node1", "node2"] + assert db.get_all_nodes_handles("Concept") == ["node1", "node2"] def test_retrieve_all_atoms(self, _cleanup, _db: RedisMongoDB): db = _db @@ -1095,7 +1132,7 @@ def test_retrieve_all_atoms(self, _cleanup, _db: RedisMongoDB): inheritance = db.get_all_links("Inheritance") similarity = db.get_all_links("Similarity") links = inheritance.union(similarity) - nodes = db.get_all_nodes("Concept") + nodes = db.get_all_nodes_handles("Concept") assert len(response) == len(links) + len(nodes) def test_add_fields_to_atoms(self, _cleanup, _db: RedisMongoDB): @@ -1108,29 +1145,34 @@ def test_add_fields_to_atoms(self, _cleanup, _db: RedisMongoDB): node_human = db.get_atom(human) - assert node_human["handle"] == human - assert node_human["name"] == "human" - assert node_human["named_type"] == "Concept" + assert node_human.handle == human + assert node_human.name == "human" + assert node_human.named_type == "Concept" - node_human["score"] = 0.5 + node_human_params = node_human + node_human_params.custom_attributes = {"score": 0.5} - db.add_node(node_human) + db.add_node(node_human_params) db.commit() - assert db.get_atom(human)["score"] == 0.5 + assert db.get_atom(human).custom_attributes["score"] == 0.5 link_similarity = db.get_atom(link_handle, deep_representation=True) - assert link_similarity["handle"] == link_handle - assert link_similarity["type"] == "Similarity" - assert link_similarity["targets"] == [db.get_atom(human), db.get_atom(monkey)] + assert link_similarity.handle == link_handle + assert link_similarity.named_type == "Similarity" + assert [target.to_dict() for target in link_similarity.targets_documents] == [ + db.get_atom(human).to_dict(), + db.get_atom(monkey).to_dict(), + ] - link_similarity["score"] = 0.5 + link_params = link_similarity + link_params.custom_attributes = {"score": 0.5} - db.add_link(link_similarity) + db.add_link(link_params) db.commit() - assert db.get_atom(link_handle)["score"] == 0.5 + assert db.get_atom(link_handle).custom_attributes["score"] == 0.5 def test_commit_with_buffer(self, _cleanup, _db: RedisMongoDB): db = _db @@ -1169,9 +1211,9 @@ def test_commit_with_buffer(self, _cleanup, _db: RedisMongoDB): "node_count": 2, "link_count": 1, } - assert db.get_atom("26d35e45817f4270f2b7cff971b04138")["name"] == "dog" - assert db.get_atom("b7db6a9ed2191eb77ee54479570db9a4")["name"] == "cat" - assert db.get_atom("3dab102938606f4549d68405ec9f4f61")["targets"] == [ + assert db.get_atom("26d35e45817f4270f2b7cff971b04138").name == "dog" + assert db.get_atom("b7db6a9ed2191eb77ee54479570db9a4").name == "cat" + assert db.get_atom("3dab102938606f4549d68405ec9f4f61").targets == [ "26d35e45817f4270f2b7cff971b04138", "b7db6a9ed2191eb77ee54479570db9a4", ] diff --git a/tests/integration/scripts/mongo-down.sh b/tests/integration/scripts/mongo-down.sh index b4429e2b..295c6771 100755 --- a/tests/integration/scripts/mongo-down.sh +++ b/tests/integration/scripts/mongo-down.sh @@ -10,4 +10,7 @@ fi echo "Destroying MongoDB container on port $PORT" -docker stop mongo_$PORT && docker rm mongo_$PORT && docker volume rm mongodbdata_$PORT >& /dev/null +docker stop mongo_$PORT && \ + docker kill mongo_$PORT && + docker rm mongo_$PORT && \ + docker volume rm mongodbdata_$PORT >& /dev/null diff --git a/tests/unit/adapters/test_ram_only.py b/tests/unit/adapters/test_ram_only.py index ffb655c7..98c0f401 100644 --- a/tests/unit/adapters/test_ram_only.py +++ b/tests/unit/adapters/test_ram_only.py @@ -1,13 +1,20 @@ +from typing import Any + import pytest from hyperon_das_atomdb import AtomDB -from hyperon_das_atomdb.adapters.ram_only import InMemoryDB +from hyperon_das_atomdb.adapters import InMemoryDB +from hyperon_das_atomdb.database import LinkT, NodeT from hyperon_das_atomdb.exceptions import AddLinkException, AddNodeException, AtomDoesNotExist from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher +from tests.helpers import dict_to_link_params, dict_to_node_params from tests.unit.fixtures import in_memory_db # noqa: F401 class TestInMemoryDB: + all_added_nodes = [] + all_added_links = [] + @pytest.fixture() def database(self, in_memory_db): # noqa: F811 import json @@ -17,13 +24,11 @@ def database(self, in_memory_db): # noqa: F811 db = in_memory_db with open(f"{path}/data/ram_only_nodes.json") as f: - all_nodes = json.load(f) + all_nodes: list[dict[str, Any]] = json.load(f) with open(f"{path}/data/ram_only_links.json") as f: - all_links = json.load(f) - for node in all_nodes: - db.add_node(node) - for link in all_links: - db.add_link(link) + all_links: list[dict[str, Any]] = json.load(f) + self.all_added_nodes = [db.add_node(dict_to_node_params(node)) for node in all_nodes] + self.all_added_links = [db.add_link(dict_to_link_params(link)) for link in all_links] yield db @pytest.mark.parametrize( @@ -63,10 +68,8 @@ def test_get_node_handle(self, node_type, node_name, expected, request): ], ) def test_get_node_handle_not_exist(self, node_type, node_name, database: InMemoryDB): - with pytest.raises(AtomDoesNotExist) as exc_info: + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): database.get_node_handle(node_type=node_type, node_name=node_name) - assert exc_info.type is AtomDoesNotExist - assert exc_info.value.args[0] == "Nonexistent atom" @pytest.mark.parametrize( "targets,link_type,expected", @@ -105,10 +108,8 @@ def test_get_link_handle(self, targets, link_type, expected, request): ], ) def test_get_link_handle_not_exist(self, link_type, target_handles, database: InMemoryDB): - with pytest.raises(AtomDoesNotExist) as exc_info: + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): database.get_link_handle(link_type=link_type, target_handles=target_handles) - assert exc_info.type is AtomDoesNotExist - assert exc_info.value.args[0] == "Nonexistent atom" def test_node_exists_true(self, database: InMemoryDB): ret = database.node_exists(node_type="Concept", node_name="human") @@ -144,10 +145,8 @@ def test_get_link_targets(self, targets, link_type, database: InMemoryDB): assert ret == target_handles def test_get_link_targets_invalid(self, database: InMemoryDB): - with pytest.raises(AtomDoesNotExist) as exc_info: + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): database.get_link_targets("link_handle_Fake") - assert exc_info.type is AtomDoesNotExist - assert exc_info.value.args[0] == "Nonexistent atom" @pytest.mark.parametrize( "targets,link_type,expected", @@ -233,34 +232,36 @@ def test_get_matched_links_link_does_not_exist(self, database: InMemoryDB): def test_get_matched_links_toplevel_only(self, database: InMemoryDB): database.add_link( - { - "type": "Evaluation", - "targets": [ - {"type": "Predicate", "name": "Predicate:has_name"}, - { - "type": "Evaluation", - "targets": [ - { - "type": "Predicate", - "name": "Predicate:has_name", - }, - { - "targets": [ - { - "type": "Reactome", - "name": "Reactome:R-HSA-164843", - }, - { - "type": "Concept", - "name": "Concept:2-LTR circle formation", - }, - ], - "type": "Set", - }, - ], - }, - ], - } + dict_to_link_params( + { + "type": "Evaluation", + "targets": [ + {"type": "Predicate", "name": "Predicate:has_name"}, + { + "type": "Evaluation", + "targets": [ + { + "type": "Predicate", + "name": "Predicate:has_name", + }, + { + "targets": [ + { + "type": "Reactome", + "name": "Reactome:R-HSA-164843", + }, + { + "type": "Concept", + "name": "Concept:2-LTR circle formation", + }, + ], + "type": "Set", + }, + ], + }, + ], + } + ) ) expected = {"661fb5a7c90faabfeada7e1f63805fc0"} actual = database.get_matched_links("Evaluation", ["*", "*"], toplevel_only=True) @@ -269,34 +270,36 @@ def test_get_matched_links_toplevel_only(self, database: InMemoryDB): def test_get_matched_links_wrong_parameter(self, database: InMemoryDB): database.add_link( - { - "type": "Evaluation", - "targets": [ - {"type": "Predicate", "name": "Predicate:has_name"}, - { - "type": "Evaluation", - "targets": [ - { - "type": "Predicate", - "name": "Predicate:has_name", - }, - { - "targets": [ - { - "type": "Reactome", - "name": "Reactome:R-HSA-164843", - }, - { - "type": "Concept", - "name": "Concept:2-LTR circle formation", - }, - ], - "type": "Set", - }, - ], - }, - ], - } + dict_to_link_params( + { + "type": "Evaluation", + "targets": [ + {"type": "Predicate", "name": "Predicate:has_name"}, + { + "type": "Evaluation", + "targets": [ + { + "type": "Predicate", + "name": "Predicate:has_name", + }, + { + "targets": [ + { + "type": "Reactome", + "name": "Reactome:R-HSA-164843", + }, + { + "type": "Concept", + "name": "Concept:2-LTR circle formation", + }, + ], + "type": "Set", + }, + ], + }, + ], + } + ) ) actual = database.get_matched_links("Evaluation", ["*", "*"], toplevel=True) assert len(actual) == 2 @@ -309,25 +312,27 @@ def test_get_matched_links_wrong_parameter(self, database: InMemoryDB): ) def test_get_matched_links_nested_lists(self, database: InMemoryDB): database.add_link( - { - "type": "Connectivity", - "targets": [ - { - "type": "Nearness", - "targets": [ - {"type": "Concept", "name": "chimp"}, - {"type": "Concept", "name": "human"}, - ], - }, - { - "type": "Nearness", - "targets": [ - {"type": "Concept", "name": "chimp"}, - {"type": "Concept", "name": "monkey"}, - ], - }, - ], - } + dict_to_link_params( + { + "type": "Connectivity", + "targets": [ + { + "type": "Nearness", + "targets": [ + {"type": "Concept", "name": "chimp"}, + {"type": "Concept", "name": "human"}, + ], + }, + { + "type": "Nearness", + "targets": [ + {"type": "Concept", "name": "chimp"}, + {"type": "Concept", "name": "monkey"}, + ], + }, + ], + } + ) ) chimp = ExpressionHasher.terminal_hash("Concept", "chimp") human = ExpressionHasher.terminal_hash("Concept", "human") @@ -345,11 +350,11 @@ def test_get_matched_links_nested_lists(self, database: InMemoryDB): assert len(links) == 1 def test_get_all_nodes(self, database): - ret = database.get_all_nodes("Concept") + ret = database.get_all_nodes_handles("Concept") assert len(ret) == 14 - ret = database.get_all_nodes("Concept", True) + ret = database.get_all_nodes_names("Concept") assert len(ret) == 14 - ret = database.get_all_nodes("ConceptFake") + ret = database.get_all_nodes_handles("ConceptFake") assert len(ret) == 0 def test_get_matched_type_template(self, database: InMemoryDB): @@ -368,25 +373,27 @@ def test_get_matched_type_template(self, database: InMemoryDB): def test_get_matched_type_template_toplevel_only(self, database: InMemoryDB): database.add_link( - { - "type": "Evaluation", - "targets": [ - {"type": "Predicate", "name": "Predicate:has_name"}, - { - "type": "Evaluation", - "targets": [ - { - "type": "Reactome", - "name": "Reactome:R-HSA-164843", - }, - { - "type": "Concept", - "name": "Concept:2-LTR circle formation", - }, - ], - }, - ], - } + dict_to_link_params( + { + "type": "Evaluation", + "targets": [ + {"type": "Predicate", "name": "Predicate:has_name"}, + { + "type": "Evaluation", + "targets": [ + { + "type": "Reactome", + "name": "Reactome:R-HSA-164843", + }, + { + "type": "Concept", + "name": "Concept:2-LTR circle formation", + }, + ], + }, + ], + } + ) ) ret = database.get_matched_type_template( @@ -407,25 +414,27 @@ def test_get_matched_type(self, database: InMemoryDB): def test_get_matched_type_toplevel_only(self, database: InMemoryDB): database.add_link( - { - "type": "EvaluationLink", - "targets": [ - {"type": "Predicate", "name": "Predicate:has_name"}, - { - "type": "EvaluationLink", - "targets": [ - { - "type": "Reactome", - "name": "Reactome:R-HSA-164843", - }, - { - "type": "Concept", - "name": "Concept:2-LTR circle formation", - }, - ], - }, - ], - } + dict_to_link_params( + { + "type": "EvaluationLink", + "targets": [ + {"type": "Predicate", "name": "Predicate:has_name"}, + { + "type": "EvaluationLink", + "targets": [ + { + "type": "Reactome", + "name": "Reactome:R-HSA-164843", + }, + { + "type": "Concept", + "name": "Concept:2-LTR circle formation", + }, + ], + }, + ], + } + ) ) ret = database.get_matched_type("EvaluationLink") assert len(ret) == 2 @@ -440,10 +449,8 @@ def test_get_node_name(self, database): assert db_name == "monkey" def test_get_node_name_error(self, database): - with pytest.raises(AtomDoesNotExist) as exc_info: + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): database.get_node_name("handle-test") - assert exc_info.type is AtomDoesNotExist - assert exc_info.value.args[0] == "Nonexistent atom" def test_get_node_type(self, database): handle = database.get_node_handle("Concept", "monkey") @@ -452,10 +459,8 @@ def test_get_node_type(self, database): assert db_type == "Concept" def test_get_node_type_error(self, database): - with pytest.raises(AtomDoesNotExist) as exc_info: + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): database.get_node_type("handle-test") - assert exc_info.type is AtomDoesNotExist - assert exc_info.value.args[0] == "Nonexistent atom" def test_get_matched_node_name(self, database: InMemoryDB): expected = sorted( @@ -472,78 +477,74 @@ def test_get_matched_node_name(self, database: InMemoryDB): assert sorted(database.get_node_by_name("Concept", "blah")) == [] def test_add_node_without_type_parameter(self, database: InMemoryDB): - with pytest.raises(AddNodeException) as exc_info: - database.add_node({"color": "red", "name": "car"}) - assert exc_info.type is AddNodeException - assert exc_info.value.args[0] == 'The "name" and "type" fields must be sent' + with pytest.raises(AddNodeException, match="'type' and 'name' are required."): + database.add_node(dict_to_node_params({"type": "", "name": "car"})) def test_add_node_without_name_parameter(self, database: InMemoryDB): - with pytest.raises(AddNodeException) as exc_info: - database.add_node({"type": "Concept", "color": "red"}) - assert exc_info.type is AddNodeException - assert exc_info.value.args[0] == 'The "name" and "type" fields must be sent' + with pytest.raises(AddNodeException, match="'type' and 'name' are required."): + database.add_node(dict_to_node_params({"type": "Concept", "name": ""})) def test_add_node(self, database: InMemoryDB): - assert len(database.get_all_nodes("Concept")) == 14 - database.add_node({"type": "Concept", "name": "car"}) - assert len(database.get_all_nodes("Concept")) == 15 + assert len(database.get_all_nodes_handles("Concept")) == 14 + database.add_node(dict_to_node_params({"type": "Concept", "name": "car"})) + assert len(database.get_all_nodes_handles("Concept")) == 15 node_handle = database.get_node_handle("Concept", "car") node_name = database.get_node_name(node_handle) assert node_name == "car" def test_add_link_without_type_parameter(self, database: InMemoryDB): - with pytest.raises(AddLinkException) as exc_info: + with pytest.raises(AddLinkException, match="'type' and 'targets' are required."): database.add_link( - { - "targets": [ - {"type": "Concept", "name": "human"}, - {"type": "Concept", "name": "monkey"}, - ], - "quantity": 2, - } + dict_to_link_params( + { + "targets": [ + {"type": "Concept", "name": "human"}, + {"type": "Concept", "name": "monkey"}, + ], + "type": "", + } + ) ) - assert exc_info.type is AddLinkException - assert exc_info.value.args[0] == 'The "type" and "targets" fields must be sent' def test_add_link_without_targets_parameter(self, database: InMemoryDB): - with pytest.raises(AddLinkException) as exc_info: - database.add_link({"source": "fake", "type": "Similarity"}) - assert exc_info.type is AddLinkException - assert exc_info.value.args[0] == 'The "type" and "targets" fields must be sent' + with pytest.raises(AddLinkException, match="'type' and 'targets' are required."): + database.add_link(dict_to_link_params({"targets": [], "type": "Similarity"})) def test_add_nested_links(self, database: InMemoryDB): answer = database.get_matched_type("Evaluation") assert len(answer) == 0 database.add_link( - { - "type": "Evaluation", - "targets": [ - {"type": "Predicate", "name": "Predicate:has_name"}, - { - "type": "Evaluation", - "targets": [ - { - "type": "Predicate", - "name": "Predicate:has_name", - }, - { - "targets": [ - { - "type": "Reactome", - "name": "Reactome:R-HSA-164843", - }, - { - "type": "Concept", - "name": "Concept:2-LTR circle formation", - }, - ], - "type": "Set", - }, - ], - }, - ], - } + dict_to_link_params( + { + "type": "Evaluation", + "targets": [ + {"type": "Predicate", "name": "Predicate:has_name"}, + { + "type": "Evaluation", + "targets": [ + { + "type": "Predicate", + "name": "Predicate:has_name", + }, + { + "targets": [ + { + "type": "Reactome", + "name": "Reactome:R-HSA-164843", + }, + { + "type": "Concept", + "name": "Concept:2-LTR circle formation", + }, + ], + "type": "Set", + }, + ], + }, + ], + } + ) ) answer = database.get_matched_type("Evaluation") assert len(answer) == 2 @@ -557,6 +558,7 @@ def test_get_link_type(self, database: InMemoryDB): ret = database.get_link_type(link_handle=link_handle) assert ret == "Similarity" + @pytest.mark.skip("Removed from C++ implementation") def test_build_targets_list(self, database: InMemoryDB): targets = database._build_targets_list( { @@ -583,19 +585,19 @@ def test_get_atom(self, database: InMemoryDB): m = database.get_node_handle("Concept", "monkey") s = database.get_link_handle("Similarity", [h, m]) atom = database.get_atom(handle=s) - assert atom["handle"] == s - assert atom["targets"] == [h, m] + assert atom.handle == s + assert atom.targets == [h, m] with pytest.raises(AtomDoesNotExist) as exc: database.get_atom(handle="test") - assert exc.value.message == "Nonexistent atom" - assert exc.value.details == "handle: test" + assert "Nonexistent atom" in str(exc.value) + assert "handle: test" in str(exc.value) def test_get_atom_as_dict(self, database: InMemoryDB): h = database.get_node_handle("Concept", "human") m = database.get_node_handle("Concept", "monkey") s = database.get_link_handle("Similarity", [h, m]) - atom = database.get_atom_as_dict(handle=s) + atom = database.get_atom(handle=s).to_dict() assert atom["handle"] == s assert atom["targets"] == [h, m] @@ -604,25 +606,19 @@ def test_get_incoming_links(self, database: InMemoryDB): m = database.get_node_handle("Concept", "monkey") s = database.get_link_handle("Similarity", [h, m]) - links = database.get_incoming_links(atom_handle=h, handles_only=False) + links = database.get_incoming_links_atoms(atom_handle=h) atom = database.get_atom(handle=s) assert atom in links - links = database.get_incoming_links( - atom_handle=h, handles_only=False, targets_document=True - ) + links = database.get_incoming_links_atoms(atom_handle=h, targets_document=True) for link in links: - for a, b in zip(link["targets"], link["targets_document"]): - assert a == b["handle"] + for a, b in zip(link.targets, link.targets_documents): + assert a == b.handle - links = database.get_incoming_links(atom_handle=h, handles_only=True) - assert links == list(database.db.incoming_set.get(h)) + links = database.get_incoming_links_handles(atom_handle=h) assert s in links - links = database.get_incoming_links(atom_handle=m, handles_only=True) - assert links == list(database.db.incoming_set.get(m)) - - links = database.get_incoming_links(atom_handle=s, handles_only=True) + links = database.get_incoming_links_handles(atom_handle=s) assert links == [] def test_get_atom_type(self, database: InMemoryDB): @@ -657,378 +653,157 @@ def test_delete_atom(self): assert db.count_atoms() == {"atom_count": 0, "node_count": 0, "link_count": 0} db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) assert db.count_atoms() == {"atom_count": 5, "node_count": 3, "link_count": 2} - assert db.db.incoming_set == { - dog_handle: {inheritance_dog_mammal_handle}, - cat_handle: {inheritance_cat_mammal_handle}, - mammal_handle: { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - } - assert db.db.outgoing_set == { - inheritance_dog_mammal_handle: [dog_handle, mammal_handle], - inheritance_cat_mammal_handle: [cat_handle, mammal_handle], - } - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - "e40489cd1e7102e35469c937e05c8bba": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - "5dd515aa7a451276feac4f8b9d84ae91": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - "a11d7cbf62bc544f75702b5fb6a514ff": { - inheritance_cat_mammal_handle, - }, - "f29daafee640d91aa7091e44551fc74a": { - inheritance_cat_mammal_handle, - }, - "7ead6cfa03894c62761162b7603aa885": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - "112002ff70ea491aad735f978e9d95f5": { - inheritance_cat_mammal_handle, - inheritance_dog_mammal_handle, - }, - "3ba42d45a50c89600d92fb3f1a46c1b5": { - inheritance_cat_mammal_handle, - }, - "e55007a8477a4e6bf4fec76e4ffd7e10": { - inheritance_dog_mammal_handle, - }, - "23dc149b3218d166a14730db55249126": { - inheritance_dog_mammal_handle, - }, - "399751d7319f9061d97cd1d75728b66b": { - inheritance_dog_mammal_handle, - }, - } db.delete_atom(inheritance_cat_mammal_handle) db.delete_atom(inheritance_dog_mammal_handle) assert db.count_atoms() == {"atom_count": 3, "node_count": 3, "link_count": 0} - assert db.db.incoming_set == { - dog_handle: set(), - cat_handle: set(), - mammal_handle: set(), - } - assert db.db.outgoing_set == {} - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": set(), - "e40489cd1e7102e35469c937e05c8bba": set(), - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": set(), - "5dd515aa7a451276feac4f8b9d84ae91": set(), - "a11d7cbf62bc544f75702b5fb6a514ff": set(), - "f29daafee640d91aa7091e44551fc74a": set(), - "7ead6cfa03894c62761162b7603aa885": set(), - "112002ff70ea491aad735f978e9d95f5": set(), - "3ba42d45a50c89600d92fb3f1a46c1b5": set(), - "e55007a8477a4e6bf4fec76e4ffd7e10": set(), - "23dc149b3218d166a14730db55249126": set(), - "399751d7319f9061d97cd1d75728b66b": set(), - } db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.delete_atom(mammal_handle) assert db.count_atoms() == {"atom_count": 2, "node_count": 2, "link_count": 0} - assert db.db.incoming_set == { - dog_handle: set(), - cat_handle: set(), - } - assert db.db.outgoing_set == {} - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": set(), - "e40489cd1e7102e35469c937e05c8bba": set(), - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": set(), - "5dd515aa7a451276feac4f8b9d84ae91": set(), - "a11d7cbf62bc544f75702b5fb6a514ff": set(), - "f29daafee640d91aa7091e44551fc74a": set(), - "7ead6cfa03894c62761162b7603aa885": set(), - "112002ff70ea491aad735f978e9d95f5": set(), - "3ba42d45a50c89600d92fb3f1a46c1b5": set(), - "e55007a8477a4e6bf4fec76e4ffd7e10": set(), - "23dc149b3218d166a14730db55249126": set(), - "399751d7319f9061d97cd1d75728b66b": set(), - } db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.delete_atom(cat_handle) assert db.count_atoms() == {"atom_count": 3, "node_count": 2, "link_count": 1} - assert db.db.incoming_set == { - dog_handle: {inheritance_dog_mammal_handle}, - mammal_handle: {inheritance_dog_mammal_handle}, - } - assert db.db.outgoing_set == {inheritance_dog_mammal_handle: [dog_handle, mammal_handle]} - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": { - inheritance_dog_mammal_handle, - }, - "e40489cd1e7102e35469c937e05c8bba": { - inheritance_dog_mammal_handle, - }, - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": { - inheritance_dog_mammal_handle, - }, - "5dd515aa7a451276feac4f8b9d84ae91": { - inheritance_dog_mammal_handle, - }, - "a11d7cbf62bc544f75702b5fb6a514ff": set(), - "f29daafee640d91aa7091e44551fc74a": set(), - "7ead6cfa03894c62761162b7603aa885": { - inheritance_dog_mammal_handle, - }, - "3ba42d45a50c89600d92fb3f1a46c1b5": set(), - "112002ff70ea491aad735f978e9d95f5": { - inheritance_dog_mammal_handle, - }, - "e55007a8477a4e6bf4fec76e4ffd7e10": { - inheritance_dog_mammal_handle, - }, - "23dc149b3218d166a14730db55249126": { - inheritance_dog_mammal_handle, - }, - "399751d7319f9061d97cd1d75728b66b": { - inheritance_dog_mammal_handle, - }, - } db.add_link( - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.delete_atom(dog_handle) assert db.count_atoms() == {"atom_count": 3, "node_count": 2, "link_count": 1} - assert db.db.incoming_set == { - cat_handle: {inheritance_cat_mammal_handle}, - mammal_handle: {inheritance_cat_mammal_handle}, - } - assert db.db.outgoing_set == {inheritance_cat_mammal_handle: [cat_handle, mammal_handle]} - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": { - inheritance_cat_mammal_handle, - }, - "e40489cd1e7102e35469c937e05c8bba": { - inheritance_cat_mammal_handle, - }, - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": { - inheritance_cat_mammal_handle, - }, - "5dd515aa7a451276feac4f8b9d84ae91": { - inheritance_cat_mammal_handle, - }, - "a11d7cbf62bc544f75702b5fb6a514ff": { - inheritance_cat_mammal_handle, - }, - "f29daafee640d91aa7091e44551fc74a": { - inheritance_cat_mammal_handle, - }, - "7ead6cfa03894c62761162b7603aa885": { - inheritance_cat_mammal_handle, - }, - "112002ff70ea491aad735f978e9d95f5": { - inheritance_cat_mammal_handle, - }, - "3ba42d45a50c89600d92fb3f1a46c1b5": { - inheritance_cat_mammal_handle, - }, - "e55007a8477a4e6bf4fec76e4ffd7e10": set(), - "23dc149b3218d166a14730db55249126": set(), - "399751d7319f9061d97cd1d75728b66b": set(), - } db.clear_database() db.add_link( - { - "type": "Inheritance", - "targets": [ - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "dog"}, - { - "type": "Inheritance", - "targets": [ - {"type": "Concept", "name": "cat"}, - {"type": "Concept", "name": "mammal"}, - ], - }, - ], - }, - {"type": "Concept", "name": "mammal"}, - ], - } + dict_to_link_params( + { + "type": "Inheritance", + "targets": [ + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "dog"}, + { + "type": "Inheritance", + "targets": [ + {"type": "Concept", "name": "cat"}, + {"type": "Concept", "name": "mammal"}, + ], + }, + ], + }, + {"type": "Concept", "name": "mammal"}, + ], + } + ) ) db.delete_atom(inheritance_cat_mammal_handle) assert db.count_atoms() == {"atom_count": 3, "node_count": 3, "link_count": 0} - assert db.db.incoming_set == { - dog_handle: set(), - cat_handle: set(), - mammal_handle: set(), - } - assert db.db.outgoing_set == {} - assert db.db.templates == { - "41c082428b28d7e9ea96160f7fd614ad": set(), - "e40489cd1e7102e35469c937e05c8bba": set(), - "62bcbcec7fdc1bf896c0c9c99fe2f6b6": set(), - "451c57cb0a3d43eb9ca208aebe11cf9e": set(), - } - assert db.db.patterns == { - "6e644e70a9fe3145c88b5b6261af5754": set(), - "5dd515aa7a451276feac4f8b9d84ae91": set(), - "a11d7cbf62bc544f75702b5fb6a514ff": set(), - "f29daafee640d91aa7091e44551fc74a": set(), - "7ead6cfa03894c62761162b7603aa885": set(), - "112002ff70ea491aad735f978e9d95f5": set(), - "3ba42d45a50c89600d92fb3f1a46c1b5": set(), - "1515eec36602aa53aa58a132cad99564": set(), - "e55007a8477a4e6bf4fec76e4ffd7e10": set(), - "1a81db4866eb3cc14dae6fd5a732a0b5": set(), - "113b45c48122d22790870abb1152f218": set(), - "399751d7319f9061d97cd1d75728b66b": set(), - "3b23b5e8ecf01bb53c1e531018ee3b2a": set(), - "1a8d5143240997c7179d99c846812ee1": set(), - "1be2f1be6e8a65d5ddd8a9efbfb93233": set(), - } def test_add_link_that_already_exists(self): db = InMemoryDB() db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Test", "name": "test_1"}, - {"type": "Test", "name": "test_2"}, - ], - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Test", "name": "test_1"}, + {"type": "Test", "name": "test_2"}, + ], + } + ) ) + assert db.count_atoms() == {"atom_count": 3, "node_count": 2, "link_count": 1} + db.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Test", "name": "test_1"}, - {"type": "Test", "name": "test_2"}, - ], - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Test", "name": "test_1"}, + {"type": "Test", "name": "test_2"}, + ], + } + ) ) - - assert db.db.incoming_set["167a378d17b1eda5587292814c8d0769"] == { - "4a7f5140c0017fe270c8693605fd000a" - } - assert db.db.incoming_set["e24c839b9ffaf295c5d9be05171cf5d1"] == { - "4a7f5140c0017fe270c8693605fd000a" - } - - assert db.db.patterns["6e644e70a9fe3145c88b5b6261af5754"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["dab80dcb22dc4b246e3f8642a4e99449"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["957e33112374129ee9a7afacc702fe33"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["7fc3951816751ca77e6e14efecff2529"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["c48b5236102ae75ba3e71729a6bfa2e5"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["699ac93da51eeb8d573f9a20d7e81010"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.patterns["7d277b5039fb500cbf51806d06dbdc78"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - - assert db.db.templates["4c201422342d157b2dded43181e7782d"] == { - "4a7f5140c0017fe270c8693605fd000a", - } - assert db.db.templates["a9dea78180588431ec64d6bc4872fdbc"] == { - "4a7f5140c0017fe270c8693605fd000a", - } + assert db.count_atoms() == {"atom_count": 3, "node_count": 2, "link_count": 1} def test_bulk_insert(self): db = InMemoryDB() @@ -1036,28 +811,30 @@ def test_bulk_insert(self): assert db.count_atoms() == {"atom_count": 0, "node_count": 0, "link_count": 0} documents = [ - { - "_id": "node1", - "composite_type_hash": "ConceptHash", - "name": "human", - "named_type": "Concept", - }, - { - "_id": "node2", - "composite_type_hash": "ConceptHash", - "name": "monkey", - "named_type": "Concept", - }, - { - "_id": "link1", - "composite_type_hash": "CompositeTypeHash", - "is_toplevel": True, - "composite_type": ["SimilarityHash", "ConceptHash", "ConceptHash"], - "named_type": "Similarity", - "named_type_hash": "SimilarityHash", - "key_0": "node1", - "key_1": "node2", - }, + NodeT( + _id="node1", + handle="node1", + composite_type_hash="ConceptHash", + name="human", + named_type="Concept", + ), + NodeT( + _id="node2", + handle="node2", + composite_type_hash="ConceptHash", + name="monkey", + named_type="Concept", + ), + LinkT( + _id="link1", + handle="link1", + composite_type_hash="CompositeTypeHash", + is_toplevel=True, + composite_type=["SimilarityHash", "ConceptHash", "ConceptHash"], + named_type="Similarity", + named_type_hash="SimilarityHash", + targets=["node1", "node2"], + ), ] db.bulk_insert(documents) @@ -1065,9 +842,17 @@ def test_bulk_insert(self): assert db.count_atoms() == {"atom_count": 3, "node_count": 2, "link_count": 1} def test_retrieve_all_atoms(self, database: InMemoryDB): - expected = list(database.db.node.values()) + list(database.db.link.values()) + expected = self.all_added_nodes + self.all_added_links + assert len(expected) == len(self.all_added_nodes + self.all_added_links) actual = database.retrieve_all_atoms() - database.clear_database() - atoms = database.retrieve_all_atoms() - assert len(atoms) == 0 - assert expected == actual + assert len(expected) == len(actual) + assert sorted([e.handle for e in expected]) == sorted([a.handle for a in actual]) + assert sorted([e.to_dict() for e in expected], key=lambda d: d["handle"]) == sorted( + [a.to_dict() for a in actual], key=lambda d: d["handle"] + ) + assert len(expected) == len(set([e.handle for e in expected])) + assert sorted([e.handle for e in expected]) == sorted( + list(set([e.handle for e in expected])) + ) + assert len(actual) == len(set([a.handle for a in actual])) + assert sorted([a.handle for a in actual]) == sorted(list(set([a.handle for a in actual]))) diff --git a/tests/unit/adapters/test_ram_only_extra.py b/tests/unit/adapters/test_ram_only_extra.py index 3be8add3..3e07530c 100644 --- a/tests/unit/adapters/test_ram_only_extra.py +++ b/tests/unit/adapters/test_ram_only_extra.py @@ -1,8 +1,11 @@ +import pytest + from hyperon_das_atomdb.adapters.ram_only import InMemoryDB from tests.unit.fixtures import in_memory_db # noqa: F401 from tests.unit.test_database_public_methods import check_handle +@pytest.mark.skip("testing protected members - must be moved to the C++ implementation.") class TestRamOnlyExtra: def test__build_atom_type_key_hash(self, in_memory_db): # noqa: F811 db: InMemoryDB = in_memory_db diff --git a/tests/unit/adapters/test_redis_mongo_db.py b/tests/unit/adapters/test_redis_mongo_db.py index ee9eef61..b33d67dc 100644 --- a/tests/unit/adapters/test_redis_mongo_db.py +++ b/tests/unit/adapters/test_redis_mongo_db.py @@ -7,9 +7,10 @@ from hyperon_das_atomdb.adapters import RedisMongoDB from hyperon_das_atomdb.adapters.redis_mongo_db import KeyPrefix -from hyperon_das_atomdb.database import FieldIndexType, FieldNames +from hyperon_das_atomdb.database import FieldIndexType, FieldNames, LinkT from hyperon_das_atomdb.exceptions import AtomDoesNotExist from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher +from tests.helpers import dict_to_link_params, dict_to_node_params from tests.unit.fixtures import redis_mongo_db # noqa: F401 FILE_CACHE = {} @@ -31,9 +32,10 @@ def database(self, redis_mongo_db: RedisMongoDB): # noqa: F811 atoms = loader("atom_mongo_redis.json") for atom in atoms: if "name" in atom: - redis_mongo_db.add_node(atom) + redis_mongo_db.add_node(dict_to_node_params(atom)) else: - redis_mongo_db.add_link(atom, toplevel=atom["is_toplevel"]) + is_toplevel = atom.pop("is_toplevel", True) + redis_mongo_db.add_link(dict_to_link_params(atom), toplevel=is_toplevel) redis_mongo_db.commit() yield redis_mongo_db @@ -228,7 +230,10 @@ def test_get_all_nodes(self, node_type, names, expected, database: RedisMongoDB) "Returning links, also break if it's a link and name is true" "https://github.com/singnet/das-atom-db/issues/210" ) - ret = database.get_all_nodes(node_type, names=names) + if names: + ret = database.get_all_nodes_names(node_type) + else: + ret = database.get_all_nodes_handles(node_type) assert len(ret) == expected def test_get_matched_type_template(self, database: RedisMongoDB): @@ -393,7 +398,7 @@ def test_get_atoms_by_index(self, atom_type, fields, query, expected, database: cursor, actual = database.get_atoms_by_index(result, query) assert cursor == 0 assert isinstance(actual, list) - assert all([a["handle"] in expected for a in actual]) + assert all([a.handle in expected for a in actual]) @pytest.mark.parametrize( "text_value,field,expected", @@ -410,12 +415,15 @@ def test_get_node_by_text_field(self, text_value, field, expected, database: Red [ (ExpressionHasher.terminal_hash("Concept", "monkey"), "Concept"), (ExpressionHasher.terminal_hash("Concept", "human"), "Concept"), - ("b5459e299a5c5e8662c427f7e01b3bf1", "Similarity"), # NOTE: Should break? + ("b5459e299a5c5e8662c427f7e01b3bf1", None), # Similarity handle ], ) def test_get_node_type(self, handle, expected, database: RedisMongoDB): resp_node = database.get_node_type(handle) - assert expected == resp_node + if expected is None: + assert resp_node is None + else: + assert expected == resp_node def test_get_node_type_without_cache(self, database: RedisMongoDB): from hyperon_das_atomdb.adapters import redis_mongo_db # noqa: F811 @@ -428,14 +436,17 @@ def test_get_node_type_without_cache(self, database: RedisMongoDB): @pytest.mark.parametrize( "handle,expected", [ - (ExpressionHasher.terminal_hash("Concept", "monkey"), "Concept"), # NOTE: Should break? - (ExpressionHasher.terminal_hash("Concept", "human"), "Concept"), # NOTE: Should break? + (ExpressionHasher.terminal_hash("Concept", "monkey"), None), + (ExpressionHasher.terminal_hash("Concept", "human"), None), ("b5459e299a5c5e8662c427f7e01b3bf1", "Similarity"), ], ) def test_get_link_type(self, handle, expected, database: RedisMongoDB): resp_link = database.get_link_type(handle) - assert expected == resp_link + if expected is None: + assert resp_link is None + else: + assert expected == resp_link def test_get_link_type_without_cache(self, database: RedisMongoDB): from hyperon_das_atomdb.adapters import redis_mongo_db # noqa: F811 @@ -457,15 +468,10 @@ def test_atom_count_fast(self, database: RedisMongoDB): def test_add_node(self, database: RedisMongoDB): assert {"atom_count": 42} == database.count_atoms() - all_nodes_before = database.get_all_nodes("Concept") - database.add_node( - { - "type": "Concept", - "name": "lion", - } - ) + all_nodes_before = database.get_all_nodes_handles("Concept") + database.add_node(dict_to_node_params({"type": "Concept", "name": "lion"})) database.commit() - all_nodes_after = database.get_all_nodes("Concept") + all_nodes_after = database.get_all_nodes_handles("Concept") assert len(all_nodes_before) == 14 assert len(all_nodes_after) == 15 assert { @@ -478,29 +484,45 @@ def test_add_node(self, database: RedisMongoDB): assert new_node_handle not in all_nodes_before assert new_node_handle in all_nodes_after new_node = database.get_atom(new_node_handle) - assert new_node["handle"] == new_node_handle - assert new_node["named_type"] == "Concept" - assert new_node["name"] == "lion" + assert new_node.handle == new_node_handle + assert new_node.named_type == "Concept" + assert new_node.name == "lion" def test_add_link(self, database: RedisMongoDB): assert {"atom_count": 42} == database.count_atoms() - all_nodes_before = database.get_all_nodes("Concept") + all_nodes_before = database.get_all_nodes_handles("Concept") similarity = database.get_all_links("Similarity") inheritance = database.get_all_links("Inheritance") evaluation = database.get_all_links("Evaluation") all_links_before = similarity.union(inheritance).union(evaluation) database.add_link( - { - "type": "Similarity", - "targets": [ - {"type": "Concept", "name": "lion"}, - {"type": "Concept", "name": "cat"}, - ], - } + dict_to_link_params( + { + "type": "Similarity", + "targets": [ + {"type": "Concept", "name": "lion"}, + {"type": "Concept", "name": "cat"}, + { + "type": "Dumminity", + "targets": [ + {"type": "Dummy", "name": "dummy1"}, + {"type": "Dummy", "name": "dummy2"}, + { + "type": "Anidity", + "targets": [ + {"type": "Any", "name": "any1"}, + {"type": "Any", "name": "any2"}, + ], + }, + ], + }, + ], + } + ) ) database.commit() - all_nodes_after = database.get_all_nodes("Concept") + all_nodes_after = database.get_all_nodes_handles("Concept") similarity = database.get_all_links("Similarity") inheritance = database.get_all_links("Inheritance") evaluation = database.get_all_links("Evaluation") @@ -510,9 +532,9 @@ def test_add_link(self, database: RedisMongoDB): assert len(all_links_before) == 28 assert len(all_links_after) == 29 assert { - "atom_count": 45, - "node_count": 16, - "link_count": 29, + "atom_count": 51, + "node_count": 20, + "link_count": 31, } == database.count_atoms({"precise": True}) new_node_handle = database.get_node_handle("Concept", "lion") @@ -520,18 +542,18 @@ def test_add_link(self, database: RedisMongoDB): assert new_node_handle not in all_nodes_before assert new_node_handle in all_nodes_after new_node = database.get_atom(new_node_handle) - assert new_node["handle"] == new_node_handle - assert new_node["named_type"] == "Concept" - assert new_node["name"] == "lion" + assert new_node.handle == new_node_handle + assert new_node.named_type == "Concept" + assert new_node.name == "lion" new_node_handle = database.get_node_handle("Concept", "cat") assert new_node_handle == ExpressionHasher.terminal_hash("Concept", "cat") assert new_node_handle not in all_nodes_before assert new_node_handle in all_nodes_after new_node = database.get_atom(new_node_handle) - assert new_node["handle"] == new_node_handle - assert new_node["named_type"] == "Concept" - assert new_node["name"] == "cat" + assert new_node.handle == new_node_handle + assert new_node.named_type == "Concept" + assert new_node.name == "cat" @pytest.mark.parametrize( "node,expected_count", @@ -544,14 +566,14 @@ def test_add_link(self, database: RedisMongoDB): ) def test_get_incoming_links_by_node(self, node, expected_count, database: RedisMongoDB): handle = database.get_node_handle(*node) - links = database.get_incoming_links(atom_handle=handle, handles_only=False) - link_handles = database.get_incoming_links(atom_handle=handle, handles_only=True) + links = database.get_incoming_links_atoms(atom_handle=handle) + link_handles = database.get_incoming_links_handles(atom_handle=handle) assert len(links) > 0 assert all(isinstance(link, str) for link in link_handles) answer = database.redis.smembers(f"{KeyPrefix.INCOMING_SET.value}:{handle}") assert len(links) == len(answer) == expected_count assert sorted(link_handles) == sorted(answer) - assert all([handle in link["targets"] for link in links]) + assert all([handle in link.targets for link in links]) @pytest.mark.parametrize( "key", @@ -594,23 +616,21 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re h = database.get_node_handle(*target) else: database.get_link_handle(*target) - links = database.get_incoming_links(atom_handle=h, handles_only=True) + links = database.get_incoming_links_handles(atom_handle=h) assert len(links) > 0 assert all(isinstance(link, str) for link in links) answer = database.redis.smembers(f"{KeyPrefix.INCOMING_SET.value}:{h}") assert sorted(links) == sorted(answer) assert handle in links - links = database.get_incoming_links(atom_handle=h, handles_only=False) + links = database.get_incoming_links_atoms(atom_handle=h) atom = database.get_atom(handle=handle) - assert atom in links - links = database.get_incoming_links( - atom_handle=h, handles_only=False, targets_document=True - ) + assert atom.handle in [link.handle for link in links] + links = database.get_incoming_links_atoms(atom_handle=h, targets_document=True) assert len(links) > 0 - assert all(isinstance(link, dict) for link in links) + assert all(isinstance(link, LinkT) for link in links) for link in links: - for a, b in zip(link["targets"], link["targets_document"]): - assert a == b["handle"] + for a, b in zip(link.targets, link.targets_documents): + assert a == b.handle @pytest.mark.parametrize( "link_type,link_targets,expected_count", @@ -676,7 +696,7 @@ def test_redis_templates(self, template_values, expected_count, database: RedisM ], ) def test_redis_names(self, node_type, expected_count, database: RedisMongoDB): - nodes = database.get_all_nodes(node_type) + nodes = database.get_all_nodes_handles(node_type) assert len(nodes) == expected_count assert all( [database.redis.smembers(f"{KeyPrefix.NAMED_ENTITIES.value}:{node}") for node in nodes] diff --git a/tests/unit/adapters/test_redis_mongo_extra.py b/tests/unit/adapters/test_redis_mongo_extra.py index d4edbb76..c619cb45 100644 --- a/tests/unit/adapters/test_redis_mongo_extra.py +++ b/tests/unit/adapters/test_redis_mongo_extra.py @@ -3,13 +3,14 @@ import pytest from hyperon_das_atomdb.adapters.redis_mongo_db import MongoDBIndex, RedisMongoDB, _HashableDocument +from tests.helpers import dict_to_node_params from tests.unit.fixtures import redis_mongo_db # noqa: F401 class TestRedisMongoExtra: def test_hashable_document_str(self, redis_mongo_db): # noqa: F811 db = redis_mongo_db - node = db._build_node({"type": "A", "name": "A"}) + node = db._build_node(dict_to_node_params({"type": "A", "name": "A"})) hashable = _HashableDocument(node) str_hashable = str(hashable) assert isinstance(str_hashable, str) diff --git a/tests/unit/test_database_private_methods.py b/tests/unit/test_database_private_methods.py index ab10b308..5fad1f32 100644 --- a/tests/unit/test_database_private_methods.py +++ b/tests/unit/test_database_private_methods.py @@ -1,9 +1,10 @@ import pytest -from hyperon_das_atomdb.database import AtomDB +from hyperon_das_atomdb.database import AtomDB, LinkT, NodeT +from hyperon_das_atomdb.exceptions import AddLinkException, AddNodeException, AtomDoesNotExist +from tests.helpers import add_link, add_node, check_handle from .fixtures import in_memory_db, redis_mongo_db # noqa: F401 -from .test_database_public_methods import add_link, add_node, check_handle class TestDatabasePrivateMethods: @@ -11,9 +12,9 @@ class TestDatabasePrivateMethods: def test__get_atom(self, database, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Aa", [], database) - node = db._get_atom(node_a["handle"]) - link = db._get_atom(link_a["handle"]) + link_a = add_link(db, "Aa", [node_a], database) + node = db._get_atom(node_a.handle) + link = db._get_atom(link_a.handle) assert node, link @pytest.mark.parametrize("database", ["redis_mongo_db", "in_memory_db"]) @@ -31,17 +32,26 @@ def test__get_atom_none(self, database, request): ) def test__reformat_document(self, database, kwlist, request): db: AtomDB = request.getfixturevalue(database) - node_handle = db.add_node({"name": "A", "type": "Test"}).get("handle") + link = db.add_link( + LinkT( + type="Relation", + targets=[ + NodeT(name="A", type="Test"), + NodeT(name="B", type="Test"), + ], + ), + ) if database != "in_memory_db": db.commit() - link = {"name": "A", "targets": [node_handle]} for kw in kwlist: answer = db._reformat_document(link, **{kw: True}) - assert set(answer.keys()) == {"name", "targets", "targets_document"} - assert len(answer["targets"]) == 1 - assert len(answer["targets_document"]) == 1 - assert answer["name"] == "A" - assert isinstance(answer["targets"][0], (str if kw == "targets_document" else dict)) + assert answer.targets_documents is not None + assert len(answer.targets) == 2 + assert len(answer.targets_documents) == 2 + assert answer.named_type == "Relation" + assert all( + isinstance(t, NodeT) for t in answer.targets_documents + ), answer.targets_documents @pytest.mark.parametrize( "database,kwlist", @@ -52,9 +62,18 @@ def test__reformat_document(self, database, kwlist, request): ) def test__reformat_document_exceptions(self, database, kwlist, request): db: AtomDB = request.getfixturevalue(database) - link = {"name": "A", "targets": ["test"]} + link = LinkT( + _id="dummy", + handle="dummy", + composite_type_hash="dummy", + composite_type=["dummy"], + named_type="dummy", + named_type_hash="dummy", + is_toplevel=True, + targets=["dummy"], + ) for kw in kwlist: - with pytest.raises(Exception, match="Nonexistent atom"): + with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): db._reformat_document(link, **{kw: True}) @pytest.mark.parametrize( @@ -74,19 +93,19 @@ def test__reformat_document_exceptions(self, database, kwlist, request): ) def test__build_node(self, database, expected_fields, expected_handle, request): db: AtomDB = request.getfixturevalue(database) - handle, node = db._build_node({"type": "Test", "name": "test"}) + node = db._build_node(NodeT(type="Test", name="test")) assert node - assert handle == expected_handle - assert all([k in node for k in expected_fields]) - assert isinstance(node, dict) - assert check_handle(handle) + assert node.handle == expected_handle + assert all([k in node.to_dict() for k in expected_fields]) + assert isinstance(node, NodeT) + assert check_handle(node.handle) # Test exception - with pytest.raises(Exception, match="The \"name\" and \"type\" fields must be sent"): - db._build_node({}) + with pytest.raises(AddNodeException, match="'type' and 'name' are required."): + db._build_node(NodeT(type="", name="")) @pytest.mark.parametrize( - "database,expected_fields, expected_handle,is_top_level", + "database,expected_fields, expected_handle,is_toplevel", [ ( "redis_mongo_db", @@ -146,24 +165,28 @@ def test__build_node(self, database, expected_fields, expected_handle, request): ), ], ) - def test__build_link(self, database, expected_fields, expected_handle, is_top_level, request): + def test__build_link(self, database, expected_fields, expected_handle, is_toplevel, request): db: AtomDB = request.getfixturevalue(database) - handle, link, targets = db._build_link( - {"type": "Test", "targets": [{"type": "Test", "name": "test"}]}, is_top_level + link = db._build_link( + LinkT( + type="Test", + targets=[ + NodeT(type="Test", name="test"), + ], + ), + is_toplevel, ) - assert expected_handle in targets - assert all([k in link for k in expected_fields]) - assert link["is_toplevel"] == is_top_level - assert check_handle(handle) - assert isinstance(link, dict) - assert isinstance(targets, list) + assert expected_handle in link.targets + assert all([k in link.to_dict() for k in expected_fields]) + assert link.is_toplevel == is_toplevel + assert check_handle(link.handle) + assert isinstance(link, LinkT) + assert isinstance(link.targets, list) @pytest.mark.parametrize("database", ["redis_mongo_db", "in_memory_db"]) def test__build_link_exceptions(self, database, request): db: AtomDB = request.getfixturevalue(database) - with pytest.raises(ValueError, match="The target must be a dictionary"): - db._build_link({"type": "Test", "targets": [""]}) - with pytest.raises(Exception, match="The \"type\" and \"targets\" fields must be sent"): - db._build_link({"type": "Test", "targets": None}) - with pytest.raises(Exception, match="The \"type\" and \"targets\" fields must be sent"): - db._build_link({"type": None, "targets": []}) + with pytest.raises(AddLinkException, match="'type' and 'targets' are required."): + db._build_link(LinkT(type="Test", targets=[])) + with pytest.raises(AddLinkException, match="'type' and 'targets' are required."): + db._build_link(LinkT(type="", targets=[NodeT(type="Test", name="test")])) diff --git a/tests/unit/test_database_public_methods.py b/tests/unit/test_database_public_methods.py index 1f28ccac..add1e1c8 100644 --- a/tests/unit/test_database_public_methods.py +++ b/tests/unit/test_database_public_methods.py @@ -1,34 +1,13 @@ +import functools from unittest import mock import pytest -from hyperon_das_atomdb.database import AtomDB -from hyperon_das_atomdb.exceptions import AtomDoesNotExist +from hyperon_das_atomdb.database import AtomDB, AtomT, LinkT, NodeT +from tests.helpers import add_link, add_node, check_handle, dict_to_link_params, dict_to_node_params from tests.unit.fixtures import in_memory_db, redis_mongo_db # noqa: F401 -def check_handle(handle): - return all((isinstance(handle, str), len(handle) == 32, int(handle, 16))) - - -def add_node(db: AtomDB, node_name, node_type, adapter, extra_fields=None): - node_dict = {"name": node_name, "type": node_type} - node_dict.update(extra_fields or {}) - node = db.add_node(node_dict) - if adapter == "redis_mongo_db": - db.commit() - return node - - -def add_link(db: AtomDB, link_type, dict_targets, adapter, is_top_level=True, extra_fields=None): - link_params = {"type": link_type, "targets": dict_targets} - link_params.update(extra_fields or {}) - link = db.add_link(link_params, toplevel=is_top_level) - if adapter != "in_memory_db": - db.commit() - return link - - def pytest_generate_tests(metafunc): idlist = [] argvalues = [] @@ -57,9 +36,9 @@ def _load_db(self, db): with open(f"{path}/adapters/data/ram_only_links.json") as f: all_links = json.load(f) for node in all_nodes: - db.add_node(node) + db.add_node(dict_to_node_params(node)) for link in all_links: - db.add_link(link) + db.add_link(dict_to_link_params(link)) @pytest.mark.parametrize( "expected", @@ -100,7 +79,8 @@ def test_node_handle_exceptions(self, database, expected, request): def test_link_handle(self, database, expected, request): db: AtomDB = request.getfixturevalue(database) handle = db.link_handle("Similarity", []) - assert len(set([db.link_handle("Similarity", f) for f in [[], [], ""]])) == 1 + handles = set([db.link_handle("Similarity", f) for f in [[], [], (), list(), tuple()]]) + assert len(handles) == 1, handles assert handle assert check_handle(handle) @@ -122,7 +102,7 @@ def test_link_handle_exceptions(self, database, expected, request): def test_node_exists(self, database, request): db: AtomDB = request.getfixturevalue(database) - db.add_node({"name": "A", "type": "Test"}) + db.add_node(NodeT(name="A", type="Test")) if database != "in_memory_db": db.commit() no_exists = db.node_exists("Test", "B") @@ -136,13 +116,12 @@ def test_node_exists(self, database, request): "targets", [ (["180fed764dbd593f1ea45b63b13d7e69"]), - ([]), ], ) def test_link_exists(self, database, targets, request): db: AtomDB = request.getfixturevalue(database) - dict_targets = [{"type": "Test", "name": "test"}] if targets else [] - link = {"type": "Test", "targets": dict_targets} + targets_params = [NodeT(type="Test", name="test")] + link = LinkT("Test", targets_params) db.add_link(link) if database != "in_memory_db": db.commit() @@ -158,7 +137,7 @@ def test_get_node_handle(self, database, request): db: AtomDB = request.getfixturevalue(database) expected_node = add_node(db, "A", "Test", database) node = db.get_node_handle("Test", "A") - assert node == expected_node["handle"] + assert node == expected_node.handle assert check_handle(node) @pytest.mark.parametrize( @@ -197,10 +176,10 @@ def test_get_node_handle_exceptions(self, database, request): def test_get_node_name(self, database, request): db: AtomDB = request.getfixturevalue(database) expected_node = add_node(db, "A", "Test", database) - name = db.get_node_name(expected_node["handle"]) + name = db.get_node_name(expected_node.handle) # NOTE all adapters must return the same type assert isinstance(name, str) - assert name == expected_node["name"] + assert name == expected_node.name def test_get_node_name_exceptions(self, database, request): if database == "redis_mongo_db": @@ -218,9 +197,9 @@ def test_get_node_name_exceptions(self, database, request): def test_get_node_type(self, database, request): db: AtomDB = request.getfixturevalue(database) expected_node = add_node(db, "A", "Test", database) - node_type = db.get_node_type(expected_node["handle"]) + node_type = db.get_node_type(expected_node.handle) assert isinstance(node_type, str) - assert node_type == expected_node["named_type"] + assert node_type == expected_node.named_type def test_get_node_type_exceptions(self, database, request): db: AtomDB = request.getfixturevalue(database) @@ -237,23 +216,23 @@ def test_get_node_by_name(self, database, request): assert isinstance(nodes, list) assert len(nodes) == 3 assert all(check_handle(node) for node in nodes) - assert all(n["handle"] in nodes for n in expected_nodes) - assert not any(n["handle"] in nodes for n in not_expected_nodes) + assert all(n.handle in nodes for n in expected_nodes) + assert not any(n.handle in nodes for n in not_expected_nodes) @pytest.mark.parametrize( "atom_type,atom_values,query_values,expected", [ ( - "node", - {"node_name": "Ac", "node_type": "Test"}, - [{"field": "name", "value": "Ac"}], - "785a4a9c6a986f8b1ba35d0de70e8fd8", + "node", # atom_type + {"node_name": "Ac", "node_type": "Test"}, # atom_values + [{"field": "name", "value": "Ac"}], # query_values + "785a4a9c6a986f8b1ba35d0de70e8fd8", # expected ), ( - "link", - {"link_type": "Ac", "dict_targets": []}, - [{"field": "named_type", "value": "Ac"}], - "8819a837186918b90b59cc316f36b1e1", + "link", # atom_type + {"link_type": "Ac", "dict_targets": [NodeT("A", "A")]}, # atom_values + [{"field": "named_type", "value": "Ac"}], # query_values + "8ec320f9ffe82c28fcefd256a20b5c60", # expected ), ], ) @@ -267,9 +246,19 @@ def test_get_atoms_by_field( ) db: AtomDB = request.getfixturevalue(database) if atom_type == "link": - add_link(db, atom_values["link_type"], atom_values["dict_targets"], database) + add_link( + db, + link_type=atom_values["link_type"], + targets=atom_values["dict_targets"], + adapter=database, + ) else: - add_node(db, atom_values["node_name"], atom_values["node_type"], database) + add_node( + db, + node_name=atom_values["node_name"], + node_type=atom_values["node_type"], + adapter=database, + ) atoms = db.get_atoms_by_field(query_values) assert isinstance(atoms, list) assert all(check_handle(atom) for atom in atoms) @@ -279,8 +268,11 @@ def test_get_atoms_by_field( "index_params,query_params,expected", [ ( + # index_params {"atom_type": "node", "fields": ["value"], "named_type": "Test"}, + # query_params / custom attributes [{"field": "value", "value": 3}], + # expected "815212e3d7ac246e70c1744d14a8c402", ), ( @@ -306,6 +298,7 @@ def test_get_atoms_by_field( ], ) def test_get_atoms_by_index(self, database, index_params, query_params, expected, request): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") if database == "in_memory_db": pytest.skip( "ERROR Not implemented. See https://github.com/singnet/das-atom-db/issues/210" @@ -337,9 +330,10 @@ def test_get_atoms_by_index(self, database, index_params, query_params, expected assert isinstance(atoms, list) assert cursor == 0 assert len(atoms) == 1 - assert atom["handle"] in [atom["handle"] for atom in atoms] - assert expected in [atom["handle"] for atom in atoms] - assert all(isinstance(a, dict) for a in atoms) + handles = [atom.handle for atom in atoms] + assert atom.handle in handles + assert expected in handles + assert all(isinstance(a, AtomT) for a in atoms) def test_get_atoms_by_index_exceptions(self, database, request): if database == "in_memory_db": @@ -352,6 +346,7 @@ def test_get_atoms_by_index_exceptions(self, database, request): db.get_atoms_by_index("", []) def test_get_atoms_by_text_field_regex(self, database, request): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") if database == "in_memory_db": # TODO: fix this pytest.skip( @@ -370,6 +365,7 @@ def test_get_atoms_by_text_field_regex(self, database, request): assert len(atoms) == 1 def test_get_atoms_by_text_field_text(self, database, request): + pytest.skip("Requires new implementation since the new custom attributes were introduced.") if database == "in_memory_db": # TODO: fix this pytest.skip( @@ -403,7 +399,7 @@ def test_get_node_by_name_starting_with(self, database, request): assert isinstance(nodes, list) assert all(check_handle(n) for n in nodes) assert all(isinstance(n, str) for n in nodes) - assert all(handle in nodes for handle in [node_a["handle"], node_b["handle"]]) + assert all(handle in nodes for handle in [node_a.handle, node_b.handle]) assert len(nodes) == 2 @pytest.mark.parametrize( @@ -419,12 +415,18 @@ def test_get_node_by_name_starting_with(self, database, request): def test_get_all_nodes(self, database, params, nodes_len, request): db: AtomDB = request.getfixturevalue(database) values = {"Test": ["Aaa", "Abb", "Bbb"], "Test2": ["Bcc", "Ccc"]} - _ = [add_node(db, vv, k, database) for k, v in values.items() for vv in v] - nodes = db.get_all_nodes(**params) - assert isinstance(nodes, list) - if params.get("names"): - assert all([n in values[params["node_type"]] for n in nodes]) + _ = [ + add_node(db, node_name, node_type, database) + for node_type, node_names in values.items() + for node_name in node_names + ] + names: bool = params.pop("names", False) + if names: + nodes = db.get_all_nodes_names(**params) else: + nodes = db.get_all_nodes_handles(**params) + assert isinstance(nodes, list) + if not names: assert all(check_handle(n) for n in nodes) assert all(isinstance(n, str) for n in nodes) assert len(nodes) == nodes_len @@ -443,9 +445,9 @@ def test_get_all_nodes(self, database, params, nodes_len, request): ) def test_get_all_links(self, database, params, links_len, request): db: AtomDB = request.getfixturevalue(database) - add_link(db, "Ac", [{"name": "A", "type": "A"}], database) - add_link(db, "Ac", [{"name": "B", "type": "B"}], database) - add_link(db, "Ac", [{"name": "C", "type": "C"}], database) + add_link_ = functools.partial(add_link, link_type="Ac", db=db, adapter=database) + for node_name, node_type in (("A", "A"), ("B", "B"), ("C", "C")): + add_link_(targets=[NodeT(name=node_name, type=node_type)]) links = db.get_all_links(**params) assert isinstance(links, set) assert all(check_handle(link) for link in links) @@ -454,8 +456,8 @@ def test_get_all_links(self, database, params, links_len, request): def test_get_link_handle(self, database, request): db: AtomDB = request.getfixturevalue(database) - link = add_link(db, "Ac", [{"name": "A", "type": "A"}], database) - handle = db.get_link_handle(link["type"], link["targets"]) + link = add_link(db, "Ac", [NodeT(name="A", type="A")], database) + handle = db.get_link_handle(link.named_type, link.targets) assert check_handle(handle) def test_get_link_handle_exceptions(self, database, request): @@ -465,12 +467,13 @@ def test_get_link_handle_exceptions(self, database, request): def test_get_link_type(self, database, request): db: AtomDB = request.getfixturevalue(database) - link_a = add_link(db, "Ac", [{"name": "A", "type": "A"}], database) - add_link(db, "Bc", [{"name": "A", "type": "A"}], database) - link_type = db.get_link_type(link_a["handle"]) + node_params = NodeT(name="A", type="A") + link_a = add_link(db, "Ac", [node_params], database) + add_link(db, "Bc", [node_params], database) + link_type = db.get_link_type(link_a.handle) assert link_type assert isinstance(link_type, str) - assert link_type == link_a["type"] + assert link_type == link_a.named_type def test_get_link_type_exceptions(self, database, request): db: AtomDB = request.getfixturevalue(database) @@ -479,22 +482,21 @@ def test_get_link_type_exceptions(self, database, request): def test_get_link_targets(self, database, request): db: AtomDB = request.getfixturevalue(database) - link_a = add_link(db, "Ac", [{"name": "A", "type": "A"}], database) - targets = db.get_link_targets(link_a["handle"]) + link_a = add_link(db, "Ac", [NodeT(name="A", type="A")], database) + targets = db.get_link_targets(link_a.handle) assert isinstance(targets, list) assert len(targets) == 1 assert all(check_handle(t) for t in targets) assert all(isinstance(t, str) for t in targets) - assert targets == link_a["targets"] + assert targets == link_a.targets @pytest.mark.parametrize( "params,links_len", [ # TODO: differences here must be fixed if possible ({}, 3), ({"handles_only": True}, 3), - # NOTE should return the same value, returning 4 - # ({"handles_only": False}, 3), - # ({"no_target_format": True}, 3), + ({"handles_only": False}, 3), + ({"no_target_format": True}, 3), ], ) def test_get_incoming_links(self, database, params, links_len, request): @@ -503,10 +505,13 @@ def test_get_incoming_links(self, database, params, links_len, request): add_link(db, "Aa", [node_a], database) add_link(db, "Ab", [node_a], database) add_link(db, "Ac", [node_a], database) - links = db.get_incoming_links(node_a["handle"], **params) + if params.get("handles_only"): + links = db.get_incoming_links_handles(node_a.handle, **params) + else: + links = db.get_incoming_links_atoms(node_a.handle, **params) assert len(links) == links_len assert all( - [check_handle(link if params.get("handles_only") else link["handle"]) for link in links] + [check_handle(link if params.get("handles_only") else link.handle) for link in links] ) @pytest.mark.parametrize( @@ -535,12 +540,12 @@ def test_get_matched_links(self, database, params, links_len, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) link_a = add_link(db, "Aa", [node_a], database) - _ = add_link(db, "NoTopLevel", [node_a], database, is_top_level=False) - _ = add_link(db, "Ac", [node_a], database) - params["link_type"] = link_a["type"] if not params.get("link_type") else params["link_type"] - params["target_handles"] = ( - link_a["targets"] if not params.get("target_handles") else params["target_handles"] - ) + add_link(db, "NoTopLevel", [node_a], database, is_toplevel=False) + add_link(db, "Ac", [node_a], database) + if not params.get("link_type"): + params["link_type"] = link_a.named_type + if not params.get("target_handles"): + params["target_handles"] = link_a.targets links = db.get_matched_links(**params) assert len(links) == links_len if all(isinstance(link, tuple) for link in links): @@ -648,7 +653,7 @@ def test_get_matched_no_links(self, database, params, links_len, request): assert len(links) == links_len @pytest.mark.parametrize( - "params,links_len,is_top_level", + "params,links_len,is_toplevel", [ # TODO: differences here must be fixed if possible ({}, 1, True), ({}, 1, False), @@ -656,18 +661,18 @@ def test_get_matched_no_links(self, database, params, links_len, request): # NOTE should return None or same as redis_mongo ], ) - def test_get_matched_type_template(self, database, params, links_len, is_top_level, request): + def test_get_matched_type_template(self, database, params, links_len, is_toplevel, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) node_b = add_node(db, "Bbb", "Test", database) - link_a = add_link(db, "Aa", [node_a, node_b], database, is_top_level=is_top_level) + link_a = add_link(db, "Aa", [node_a, node_b], database, is_toplevel=is_toplevel) links = db.get_matched_type_template(["Aa", *["Test", "Test"]], **params) assert len(links) == links_len if len(links) > 0: for link in links: assert check_handle(link) - assert link == link_a["handle"] - assert sorted(db.get_atom(link)["targets"]) == sorted(link_a["targets"]) + assert link == link_a.handle + assert sorted(db.get_atom(link).targets) == sorted(link_a.targets) def test_get_matched_type(self, database, request): if database == "redis_mongo_db": @@ -677,15 +682,16 @@ def test_get_matched_type(self, database, request): "See https://github.com/singnet/das-atom-db/issues/210" ) db: AtomDB = request.getfixturevalue(database) - link_a = add_link(db, "Aa", [], database) - add_link(db, "Ab", [], database) - links = db.get_matched_type(link_a["type"]) + targets_params = [NodeT(type="Test", name="test")] + link_a = add_link(db, "Aa", targets_params, database) + add_link(db, "Ab", targets_params, database) + links = db.get_matched_type(link_a.named_type) assert len(links) == 1 if len(links) > 0: for link in links: assert check_handle(link) - assert link == link_a["handle"] - assert sorted(db.get_atom(link)["targets"]) == sorted(link_a["targets"]) + assert link == link_a.handle + assert sorted(db.get_atom(link).targets) == sorted(link_a.targets) @pytest.mark.parametrize( "params,top_level,n_links,n_nodes", @@ -693,8 +699,7 @@ def test_get_matched_type(self, database, request): ({}, True, 1, 1), ({"no_target_format": True}, False, 1, 1), ({"no_target_format": False}, False, 1, 1), - # NOTE breaks when is a node - # ({"targets_document": True}, False, 1, 1), + ({"targets_document": True}, False, 1, 1), ({"deep_representation": True}, False, 1, 1), ({"deep_representation": False}, False, 1, 1), ], @@ -702,10 +707,10 @@ def test_get_matched_type(self, database, request): def test_get_atom_node(self, database, params, top_level, n_links, n_nodes, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - atom_n = db.get_atom(node_a["handle"], **params) + atom_n = db.get_atom(node_a.handle, **params) assert atom_n - assert atom_n["handle"] == node_a["handle"] - assert check_handle(atom_n["handle"]) + assert atom_n.handle == node_a.handle + assert check_handle(atom_n.handle) @pytest.mark.parametrize( "params,top_level,n_links,n_nodes", @@ -718,18 +723,20 @@ def test_get_atom_node(self, database, params, top_level, n_links, n_nodes, requ ) def test_get_atom_link(self, database, params, top_level, n_links, n_nodes, request): db: AtomDB = request.getfixturevalue(database) - link_a = add_link(db, "Aa", [], database, is_top_level=top_level) - atom_l = db.get_atom(link_a["handle"], **params) + link_a = add_link( + db, "Aa", [NodeT(type="Test", name="test")], database, is_toplevel=top_level + ) + atom_l = db.get_atom(link_a.handle, **params) assert atom_l - assert atom_l["handle"] == link_a["handle"] - assert check_handle(atom_l["handle"]) + assert atom_l.handle == link_a.handle + assert check_handle(atom_l.handle) def test_get_atom_type(self, database, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [], database) - atom_type_node = db.get_atom_type(node_a["handle"]) - atom_type_link = db.get_atom_type(link_a["handle"]) + link_a = add_link(db, "Test", [node_a], database) + atom_type_node = db.get_atom_type(node_a.handle) + atom_type_link = db.get_atom_type(link_a.handle) assert isinstance(atom_type_node, str) assert isinstance(atom_type_link, str) assert atom_type_node == atom_type_link @@ -741,21 +748,29 @@ def test_get_atom_type_none(self, database, request): assert atom_type_node is None assert atom_type_link is None - def test_get_atom_as_dict(self, database, request): - db: AtomDB = request.getfixturevalue(database) - node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [], database) - atom_node = db.get_atom_as_dict(node_a["handle"]) - atom_link = db.get_atom_as_dict(link_a["handle"]) - assert isinstance(atom_node, dict) - assert isinstance(atom_link, dict) - - def test_get_atom_as_dict_exception(self, database, request): - db: AtomDB = request.getfixturevalue(database) - with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): - db.get_atom_as_dict("handle") + # NOTE: not needed - Atom class has a method to get the atom as a dict (`as_dict`) + # def test_get_atom_as_dict(self, database, request): + # if database == "in_memory_db": + # pytest.skip("in_memory_db doesn't implement this `get_atom_as_dict`") + # db: AtomDB = request.getfixturevalue(database) + # node_a = add_node(db, "Aaa", "Test", database) + # link_a = add_link(db, "Test", [node_a], database) + # atom_node = db.get_atom_as_dict(node_a.handle) + # atom_link = db.get_atom_as_dict(link_a.handle) + # assert isinstance(atom_node, dict) + # assert isinstance(atom_link, dict) + + # NOTE: not needed - Atom class has a method to get the atom as a dict (`as_dict`) + # def test_get_atom_as_dict_exception(self, database, request): + # if database == "in_memory_db": + # pytest.skip("in_memory_db doesn't implement this `get_atom_as_dict`") + # db: AtomDB = request.getfixturevalue(database) + # with pytest.raises(AtomDoesNotExist, match="Nonexistent atom"): + # db.get_atom_as_dict("handle") def test_get_atom_as_dict_exceptions(self, database, request): + if database == "in_memory_db": + pytest.skip("in_memory_db doesn't implement this `get_atom_as_dict`") if database == "redis_mongo_db": # TODO: fix this pytest.skip( @@ -779,9 +794,9 @@ def test_get_atom_as_dict_exceptions(self, database, request): ) def test_count_atoms(self, database, params, request): db: AtomDB = request.getfixturevalue(database) - add_node(db, "Aaa", "Test", database) - add_link(db, "Test", [], database) - atoms_count = db.count_atoms(params) + node_a = add_node(db, "Aaa", "Test", database) + add_link(db, "Test", [node_a], database) + atoms_count = db.count_atoms(params) # InMemoryDB ignores params assert atoms_count assert isinstance(atoms_count, dict) assert isinstance(atoms_count["atom_count"], int) @@ -794,8 +809,8 @@ def test_count_atoms(self, database, params, request): def test_clear_database(self, database, request): db: AtomDB = request.getfixturevalue(database) - add_node(db, "Aaa", "Test", database) - add_link(db, "Test", [], database) + node_a = add_node(db, "Aaa", "Test", database) + add_link(db, "Test", [node_a], database) assert db.count_atoms()["atom_count"] == 2 db.clear_database() assert db.count_atoms()["atom_count"] == 0 @@ -811,11 +826,11 @@ def testadd_node(self, database, node, request): db: AtomDB = request.getfixturevalue(database) if database == "redis_mongo_db": db.mongo_bulk_insertion_limit = 1 - node = db.add_node(node) + node = db.add_node(NodeT(name="A", type="A")) count = db.count_atoms() assert node assert count["atom_count"] == 1 - assert isinstance(node, dict) + assert isinstance(node, NodeT) @pytest.mark.parametrize( "node", @@ -829,7 +844,8 @@ def test_add_node_discard(self, database, node, request): db: AtomDB = request.getfixturevalue(database) db.mongo_bulk_insertion_limit = 1 db.max_mongo_db_document_size = 1 - node = db.add_node(node) + node_params = NodeT(name="AAAA", type="A") + node = db.add_node(node_params) count = db.count_atoms() assert node is None assert count["atom_count"] == 0 @@ -859,18 +875,21 @@ def testadd_node_exceptions(self, database, node, request): "params,expected_count,top_level", [ ({"type": "A", "targets": [{"name": "A", "type": "A"}]}, 2, True), - ({"type": "A", "targets": []}, 1, True), ], ) def testadd_link(self, database, params, expected_count, top_level, request): db: AtomDB = request.getfixturevalue(database) if database == "redis_mongo_db": db.mongo_bulk_insertion_limit = 1 - link = db.add_link(params, top_level) + targets = [dict_to_node_params(t) for t in params["targets"]] + link = db.add_link( + LinkT(type=params["type"], targets=targets), + top_level, + ) count = db.count_atoms() assert link assert count["atom_count"] == expected_count - assert isinstance(link, dict) + assert isinstance(link, LinkT) @pytest.mark.parametrize( "params", @@ -885,37 +904,34 @@ def test_reindex(self, database, params, request): "ERROR Not implemented. See https://github.com/singnet/das-atom-db/issues/210" ) db: AtomDB = request.getfixturevalue(database) - add_node(db, "Aaa", "Test", database) - add_link(db, "Test", [], database) + node_a = add_node(db, "Aaa", "Test", database) + add_link(db, "Test", [node_a], database) db.reindex(params) def test_delete_atom(self, database, request): - if database == "in_memory_db": - # TODO: fix this - pytest.skip( - "ERROR Atom not in incoming_set. See https://github.com/singnet/das-atom-db/issues/210" - ) db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [], database) + node_b = add_node(db, "Bbb", "Test", database) + link_a = add_link(db, "Test", [node_b], database) + count = db.count_atoms({"precise": True}) + assert count["atom_count"] == 3 + assert count["node_count"] == 2 + assert count["link_count"] == 1 + db.delete_atom(node_a.handle) count = db.count_atoms({"precise": True}) assert count["atom_count"] == 2 assert count["node_count"] == 1 assert count["link_count"] == 1 - db.delete_atom(node_a["handle"]) + db.delete_atom(link_a.handle) count = db.count_atoms({"precise": True}) assert count["atom_count"] == 1 - assert count["node_count"] == 0 - assert count["link_count"] == 1 - db.delete_atom(link_a["handle"]) - count = db.count_atoms({"precise": True}) - assert count["atom_count"] == 0 - assert count["node_count"] == 0 + assert count["node_count"] == 1 assert count["link_count"] == 0 def test_delete_atom_exceptions(self, database, request): if database == "in_memory_db": # TODO: fix this + # TODO: C++ implementation does not raise any exception when atom does not exist pytest.skip( "ERROR Atom not in incoming_set. See https://github.com/singnet/das-atom-db/issues/210" ) @@ -1032,24 +1048,44 @@ def test_bulk_insert(self, database, request): ) db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [{"name": "A", "type": "A"}], database) - node_a["name"] = "B" - link_a["targets"] = [node_a["handle"]] - db.bulk_insert([node_a, link_a]) + link_a = add_link( + db, + "Test", + [NodeT(name="A", type="A")], + database, + ) + node_a_copy = node_a.__class__( + name="B", # different name + _id=node_a._id, + handle=node_a.handle, + composite_type_hash=node_a.composite_type_hash, + named_type=node_a.named_type, + ) + link_a_copy = link_a.__class__( + targets=[node_a_copy.handle], # different targets + _id=link_a._id, + handle=link_a.handle, + composite_type_hash=link_a.composite_type_hash, + named_type=link_a.named_type, + composite_type=link_a.composite_type, + named_type_hash=link_a.named_type_hash, + is_toplevel=link_a.is_toplevel, + ) + db.bulk_insert([node_a_copy, link_a_copy]) count = db.count_atoms({"precise": True}) - node = db.get_atom(node_a["handle"]) - link = db.get_atom(link_a["handle"]) + node = db.get_atom(node_a.handle) + link = db.get_atom(link_a.handle) assert count["atom_count"] == 3 assert count["node_count"] == 2 assert count["link_count"] == 1 - assert node["name"] == "B" - assert link["targets"] == [node_a["handle"]] + assert node.name == "B" + assert link.targets == [node_a.handle] # Note no exception is raised if error def test_bulk_insert_exceptions(self, database, request): db: AtomDB = request.getfixturevalue(database) - node_a = db._build_node({"name": "A", "type": "A"}) - link_a = db._build_link({"targets": [], "type": "A"}) + node_a = db._build_node(NodeT(name="A", type="A")) + link_a = db._build_link(LinkT(targets=[node_a], type="A")) with pytest.raises(Exception): db.bulk_insert([node_a, link_a]) # TODO: fix this @@ -1060,12 +1096,19 @@ def test_bulk_insert_exceptions(self, database, request): def test_retrieve_all_atoms(self, database, request): db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [{"name": "A", "type": "A"}], database) + link_a = add_link( + db, + "Test", + [NodeT(name="A", type="A")], + database, + ) node_b = db.get_atom(db.get_node_handle(node_type="A", node_name="A")) atoms = db.retrieve_all_atoms() assert isinstance(atoms, list) assert len(atoms) == 3 - assert all(a in atoms for a in [node_a, link_a, node_b]) + all_atoms_handles = [a.handle for a in atoms] + for atom in [node_a, link_a, node_b]: + assert atom.handle in all_atoms_handles, f"{atom=}, {atoms=}" def test_commit(self, database, request): if database == "in_memory_db": @@ -1074,8 +1117,10 @@ def test_commit(self, database, request): "ERROR Not implemented on in_memory_db. See https://github.com/singnet/das-atom-db/issues/210" ) db: AtomDB = request.getfixturevalue(database) - db.add_node({"name": "A", "type": "Test"}) - db.add_link({"type": "Test", "targets": []}) + node_a = db.add_node(NodeT(name="A", type="Test")) + db.add_link(LinkT(type="Test", targets=[node_a])) + count = db.count_atoms({"precise": True}) + assert count["atom_count"] == 0 db.commit() count = db.count_atoms({"precise": True}) assert count["atom_count"] == 2 @@ -1090,10 +1135,32 @@ def test_commit_buffer(self, database, request): ) db: AtomDB = request.getfixturevalue(database) node_a = add_node(db, "Aaa", "Test", database) - link_a = add_link(db, "Test", [{"name": "A", "type": "A"}], database) - node_a["name"] = "B" - link_a["targets"] = [node_a["handle"]] - db.commit(buffer=[node_a, link_a]) + link_a = add_link( + db, + "Test", + [NodeT(name="A", type="A")], + database, + ) + + node_a_dict = dict( + name="B", # different name + _id=node_a._id, + handle=node_a.handle, + composite_type_hash=node_a.composite_type_hash, + named_type=node_a.named_type, + ) + link_a_dict = dict( + targets=[node_a_dict["handle"]], # different targets + _id=link_a._id, + handle=link_a.handle, + composite_type_hash=link_a.composite_type_hash, + named_type=link_a.named_type, + composite_type=link_a.composite_type, + named_type_hash=link_a.named_type_hash, + is_toplevel=link_a.is_toplevel, + ) + + db.commit(buffer=[node_a_dict, link_a_dict]) count = db.count_atoms({"precise": True}) assert count["atom_count"] == 3 assert count["node_count"] == 2 From 8b288e21d6133e401814ba25e1e945245843f463 Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Tue, 22 Oct 2024 13:56:12 -0300 Subject: [PATCH 2/6] removing `unbuffer` from mypy --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index e3438500..2e9ad58f 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ pylint: @pylint ./hyperon_das_atomdb --rcfile=.pylintrc mypy: - @unbuffer mypy --color-output --config-file mypy.ini ./hyperon_das_atomdb + @mypy --color-output --config-file mypy.ini ./hyperon_das_atomdb lint: isort black flake8 pylint mypy From d90e84cd8a3ef02e848ae93a19689f1e3d96f062 Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Wed, 23 Oct 2024 15:57:03 -0300 Subject: [PATCH 3/6] addressing reviewers feedback --- hyperon_das_atomdb/adapters/redis_mongo_db.py | 2 +- hyperon_das_atomdb/logger.py | 5 +-- hyperon_das_atomdb/utils/expression_hasher.py | 31 ------------------- pyproject.toml | 4 +-- 4 files changed, 4 insertions(+), 38 deletions(-) diff --git a/hyperon_das_atomdb/adapters/redis_mongo_db.py b/hyperon_das_atomdb/adapters/redis_mongo_db.py index 928d14c0..81190e2e 100644 --- a/hyperon_das_atomdb/adapters/redis_mongo_db.py +++ b/hyperon_das_atomdb/adapters/redis_mongo_db.py @@ -426,7 +426,7 @@ def _retrieve_document(self, handle: str) -> DocumentT | None: mongo_filter = {FieldNames.ID_HASH: handle} if document := self.mongo_atoms_collection.find_one(mongo_filter): if self._is_document_link(document): - document["targets"] = self._get_document_keys(document) + document[FieldNames.TARGETS] = self._get_document_keys(document) return document return None diff --git a/hyperon_das_atomdb/logger.py b/hyperon_das_atomdb/logger.py index a2701cdc..5edd098b 100644 --- a/hyperon_das_atomdb/logger.py +++ b/hyperon_das_atomdb/logger.py @@ -35,10 +35,7 @@ def __init__(self): """ if Logger.__instance is not None: - # TODO(angelo,andre): raise a more specific type of exception? - raise Exception( # pylint: disable=broad-exception-raised - "Invalid re-instantiation of Logger" - ) + raise RuntimeError("Invalid re-instantiation of Logger") logging.basicConfig( filename=LOG_FILE_NAME, diff --git a/hyperon_das_atomdb/utils/expression_hasher.py b/hyperon_das_atomdb/utils/expression_hasher.py index 8e885668..f3f4458a 100644 --- a/hyperon_das_atomdb/utils/expression_hasher.py +++ b/hyperon_das_atomdb/utils/expression_hasher.py @@ -117,34 +117,3 @@ def composite_hash(hash_base: str | list[str]) -> str: raise ValueError( "Invalid base to compute composite hash: " f"{type(hash_base)}: {hash_base}" ) - - -class StringExpressionHasher: # TODO(angelo,andre): remove this class? it's not used anywhere - """Utility class for generating string representations of expression hashes.""" - - @staticmethod - def compute_hash(text: str) -> str: - """Compute the MD5 hash of the given text.""" - return str() # TODO(angelo,andre): this seems right? - - @staticmethod - def named_type_hash(name: str) -> str: - """Compute the hash for a named type.""" - return f"" - - @staticmethod - def terminal_hash(named_type: str, terminal_name: str) -> str: - """Compute the hash for a terminal expression.""" - return f"<{named_type}: {terminal_name}>" - - @staticmethod - def expression_hash(named_type_hash: str, elements: list[str]) -> str: - """Compute the hash for a composite expression.""" - return f"<{named_type_hash}: {elements}>" - - @staticmethod - def composite_hash(hash_list: list[str]) -> str: - """Compute the composite hash from a list of hashes.""" - if len(hash_list) == 1: - return hash_list[0] - return f"{hash_list}" diff --git a/pyproject.toml b/pyproject.toml index 3cfa6656..6942ef5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ pymongo = "^4.5.0" python-dotenv = "^1.0.0" mongomock = "^4.1.2" setuptools = "^70.2.0" -hyperon-das-atomdb-cpp = "0.0.1" +hyperon-das-atomdb-cpp = "0.0.2" [tool.poetry.group.dev.dependencies] flake8 = "^6.1.0" @@ -30,7 +30,7 @@ pytest = "^7.4.2" pytest-cov = "^4.1.0" flake8-pyproject = "^1.2.3" pre-commit = "^3.5.0" -hyperon-das-atomdb-cpp = "0.0.1" +hyperon-das-atomdb-cpp = "0.0.2" [build-system] requires = ["poetry-core"] From 7d7ce79f9af359a5ecae231f7c4b30ac2ada7c6f Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Wed, 23 Oct 2024 16:00:46 -0300 Subject: [PATCH 4/6] updating `hyperon-das-atomdb-cpp` version --- hyperon_das_atomdb_cpp/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hyperon_das_atomdb_cpp/pyproject.toml b/hyperon_das_atomdb_cpp/pyproject.toml index 70e4e3dc..25e2f052 100644 --- a/hyperon_das_atomdb_cpp/pyproject.toml +++ b/hyperon_das_atomdb_cpp/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hyperon_das_atomdb_cpp" -version = "0.0.1" +version = "0.0.2" description = "Atom Space DB for Hyperon DAS" readme = "README.md" requires-python = ">=3.10" From df2f287b5870e9a4b336386f967d9cdd6b3532d8 Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Thu, 24 Oct 2024 10:55:48 -0300 Subject: [PATCH 5/6] small fix --- hyperon_das_atomdb/adapters/redis_mongo_db.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/hyperon_das_atomdb/adapters/redis_mongo_db.py b/hyperon_das_atomdb/adapters/redis_mongo_db.py index 81190e2e..7f1cafb7 100644 --- a/hyperon_das_atomdb/adapters/redis_mongo_db.py +++ b/hyperon_das_atomdb/adapters/redis_mongo_db.py @@ -425,7 +425,7 @@ def _retrieve_document(self, handle: str) -> DocumentT | None: """ mongo_filter = {FieldNames.ID_HASH: handle} if document := self.mongo_atoms_collection.find_one(mongo_filter): - if self._is_document_link(document): + if self._is_document_link(document) and FieldNames.TARGETS not in document: document[FieldNames.TARGETS] = self._get_document_keys(document) return document return None @@ -719,7 +719,7 @@ def _build_atom_from_dict(self, document: DocumentT) -> AtomT: custom_attributes=document.get(FieldNames.CUSTOM_ATTRIBUTES, dict()), ) return link - else: + elif "name" in document: node = NodeT( handle=document[FieldNames.ID_HASH], _id=document[FieldNames.ID_HASH], @@ -729,6 +729,8 @@ def _build_atom_from_dict(self, document: DocumentT) -> AtomT: custom_attributes=document.get(FieldNames.CUSTOM_ATTRIBUTES, dict()), ) return node + else: + raise ValueError("Invalid atom type") def _get_atom(self, handle: str) -> AtomT | None: document = self._retrieve_document(handle) @@ -1398,12 +1400,16 @@ def _get_atoms_by_index(self, index_id: str, **kwargs) -> tuple[int, list[AtomT] def retrieve_all_atoms(self) -> list[AtomT]: try: - return [ - self._build_atom_from_dict(document) - for document in self.mongo_atoms_collection.find() - ] + all_atoms: list[AtomT] = [] + document: DocumentT = {} + for document in self.mongo_atoms_collection.find(): + if self._is_document_link(document) and FieldNames.TARGETS not in document: + document[FieldNames.TARGETS] = self._get_document_keys(document) + atom = self._build_atom_from_dict(document) + all_atoms.append(atom) + return all_atoms except Exception as e: - logger().error(f"Error retrieving all atoms: {str(e)}") + logger().error(f"Error retrieving all atoms: {type(e)}: {str(e)}, {document=}") raise e def bulk_insert(self, documents: list[AtomT]) -> None: From 92f88e368395f4e81cfe4199b40214602866a39f Mon Sep 17 00:00:00 2001 From: Angelo Probst Date: Thu, 24 Oct 2024 12:14:12 -0300 Subject: [PATCH 6/6] adding ticket URL to skipped tests --- tests/integration/adapters/test_redis_mongo.py | 16 ++++++++++++---- tests/unit/test_database_public_methods.py | 12 +++++++++--- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/integration/adapters/test_redis_mongo.py b/tests/integration/adapters/test_redis_mongo.py index 3f986b1b..4adec953 100644 --- a/tests/integration/adapters/test_redis_mongo.py +++ b/tests/integration/adapters/test_redis_mongo.py @@ -803,7 +803,9 @@ def test_get_matched_with_pagination(self, _cleanup, _db: RedisMongoDB): ) def test_create_field_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) db = _db self._add_atoms(db) db.commit() @@ -919,7 +921,9 @@ def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): assert my_index in collection_index_names def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) db: RedisMongoDB = _db self._add_atoms(db) db.add_link( @@ -944,7 +948,9 @@ def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): assert explain[0]["executionStats"]["totalKeysExamined"] == 0 def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) db: RedisMongoDB = _db self._add_atoms(db) db.add_link( @@ -980,7 +986,9 @@ def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): ) def test_get_atoms_by_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) db: RedisMongoDB = _db db.add_link( dict_to_link_params( diff --git a/tests/unit/test_database_public_methods.py b/tests/unit/test_database_public_methods.py index add1e1c8..57170021 100644 --- a/tests/unit/test_database_public_methods.py +++ b/tests/unit/test_database_public_methods.py @@ -298,7 +298,9 @@ def test_get_atoms_by_field( ], ) def test_get_atoms_by_index(self, database, index_params, query_params, expected, request): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) if database == "in_memory_db": pytest.skip( "ERROR Not implemented. See https://github.com/singnet/das-atom-db/issues/210" @@ -346,7 +348,9 @@ def test_get_atoms_by_index_exceptions(self, database, request): db.get_atoms_by_index("", []) def test_get_atoms_by_text_field_regex(self, database, request): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) if database == "in_memory_db": # TODO: fix this pytest.skip( @@ -365,7 +369,9 @@ def test_get_atoms_by_text_field_regex(self, database, request): assert len(atoms) == 1 def test_get_atoms_by_text_field_text(self, database, request): - pytest.skip("Requires new implementation since the new custom attributes were introduced.") + pytest.skip( + "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + ) if database == "in_memory_db": # TODO: fix this pytest.skip(