Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#203] Review AtomDB tests #212

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions hyperon_das_atomdb/adapters/ram_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,8 @@ def get_matched_links(
link_type_hash = (
WILDCARD if link_type == WILDCARD else ExpressionHasher.named_type_hash(link_type)
)

if link_type in UNORDERED_LINK_TYPES:
# NOTE unreachable
if link_type in UNORDERED_LINK_TYPES: # pragma: no cover
logger().error(
"Failed to get matched links: Queries with unordered links are not implemented. "
f"link_type: {link_type}"
Expand Down Expand Up @@ -598,7 +598,9 @@ def get_matched_type(self, link_type: str, **kwargs) -> MatchedTypesResultT:
return kwargs.get("cursor"), self._filter_non_toplevel(templates_matched)
return kwargs.get("cursor"), templates_matched

def get_atoms_by_field(self, query: list[OrderedDict[str, str]]) -> list[str]:
def get_atoms_by_field(
self, query: list[OrderedDict[str, str]]
) -> list[str]: # pragma: no cover
raise NotImplementedError()

def get_atoms_by_index(
Expand All @@ -607,18 +609,20 @@ def get_atoms_by_index(
query: list[OrderedDict[str, str]],
cursor: int = 0,
chunk_size: int = 500,
) -> tuple[int, list[AtomT]]:
) -> tuple[int, list[AtomT]]: # pragma: no cover
raise NotImplementedError()

def get_atoms_by_text_field(
self,
text_value: str,
field: str | None = None,
text_index_id: str | None = None,
) -> list[str]:
) -> list[str]: # pragma: no cover
raise NotImplementedError()

def get_node_by_name_starting_with(self, node_type: str, startswith: str) -> list[str]:
def get_node_by_name_starting_with(
self, node_type: str, startswith: str
) -> list[str]: # pragma: no cover
raise NotImplementedError()

def _get_atom(self, handle: str) -> AtomT | None:
Expand Down Expand Up @@ -672,7 +676,8 @@ def add_node(self, node_params: NodeParamsT) -> NodeT | None:

def add_link(self, link_params: LinkParamsT, toplevel: bool = True) -> LinkT | None:
result = self._build_link(link_params, toplevel)
if result is None:
# NOTE unreachable
if result is None: # pragma: no cover
return None
handle, link, _ = result
self.db.link[handle] = link
Expand All @@ -681,7 +686,7 @@ def add_link(self, link_params: LinkParamsT, toplevel: bool = True) -> LinkT | N

def reindex(
self, pattern_index_templates: dict[str, list[dict[str, Any]]] | None = None
) -> None:
) -> None: # pragma: no cover
raise NotImplementedError()

def delete_atom(self, handle: str, **kwargs) -> None:
Expand Down Expand Up @@ -710,7 +715,7 @@ def create_field_index(
named_type: str | None = None,
composite_type: list[Any] | None = None,
index_type: FieldIndexType | None = None,
) -> str:
) -> str: # pragma: no cover
raise NotImplementedError()

def bulk_insert(self, documents: list[AtomT]) -> None:
Expand All @@ -732,5 +737,5 @@ def retrieve_all_atoms(self) -> list[AtomT]:
logger().error(f"Error retrieving all atoms: {str(e)}")
raise e

def commit(self, **kwargs) -> None:
def commit(self, **kwargs) -> None: # pragma: no cover
raise NotImplementedError()
56 changes: 2 additions & 54 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,60 +89,6 @@ class MongoIndexType(str, Enum):
TEXT = "text"


class NodeDocuments:
    """Thin wrapper around a MongoDB collection that stores node documents."""

    def __init__(self, collection: Collection) -> None:
        """
        Set up the wrapper over a MongoDB collection.

        Args:
            collection (Collection): MongoDB collection holding node documents.
        """
        self.mongo_collection = collection
        self.cached_nodes: dict[str, Any] = {}  # reserved for caching; not populated here
        self.count = 0  # number of documents registered via add()

    def add(self) -> None:
        """Register one more node document in the local counter."""
        self.count = self.count + 1

    def get(self, handle: str, default_value: Any = None) -> Any:
        """
        Look up a node document by its handle.

        Args:
            handle (str): Unique identifier of the node document.
            default_value (Any): Fallback returned when no document matches.
                Defaults to None.

        Returns:
            The matching document if one is found (truthy), otherwise the
            default value.
        """
        found = self.mongo_collection.find_one({FieldNames.ID_HASH: handle})
        if found:
            return found
        return default_value

    def size(self) -> int:
        """
        Return the number of node documents counted via add().

        Returns:
            int: The count of node documents.
        """
        return self.count

    def values(self) -> Iterable[dict[str, Any]]:
        """
        Yield every document stored in the underlying collection.

        Returns:
            generator: A generator over all documents in the collection.
        """
        yield from self.mongo_collection.find()


class _HashableDocument:
"""Class for making documents hashable."""

Expand Down Expand Up @@ -763,6 +709,7 @@ def get_matched_links(

link_type_hash = WILDCARD if link_type == WILDCARD else self._get_atom_type_hash(link_type)

# NOTE unreachable
if link_type in UNORDERED_LINK_TYPES:
target_handles = sorted(target_handles)

Expand Down Expand Up @@ -861,6 +808,7 @@ def commit(self, **kwargs) -> None:
{id_tag: document[id_tag]}, document, upsert=True
)
self._update_atom_indexes([document])

except Exception as e:
logger().error(f"Failed to commit buffer - Details: {str(e)}")
raise e
Expand Down
4 changes: 0 additions & 4 deletions hyperon_das_atomdb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,10 @@ def _build_link(
raise ValueError("The target must be a dictionary")
if "targets" not in target:
atom = self.add_node(target)
if atom is None:
return None
atom_hash = atom["composite_type_hash"]
composite_type.append(atom_hash)
else:
atom = self.add_link(target, toplevel=False)
if atom is None:
return None
atom_hash = atom["composite_type_hash"]
composite_type.append(atom["composite_type"])
composite_type_hash.append(atom_hash)
Expand Down
1 change: 1 addition & 0 deletions hyperon_das_atomdb/utils/expression_hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def composite_hash(hash_base: str | list[str]) -> str:
return ExpressionHasher._compute_hash(
ExpressionHasher.compound_separator.join(hash_base)
)
# TODO unreachable
else:
raise ValueError(
"Invalid base to compute composite hash: " f"{type(hash_base)}: {hash_base}"
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/adapters/test_ram_only_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from hyperon_das_atomdb.adapters.ram_only import InMemoryDB
from tests.unit.fixtures import in_memory_db # noqa: F401
from tests.unit.test_database import _check_handle


class TestRamOnlyExtra:
    """Extra coverage for InMemoryDB internals not exercised by the main suite."""

    def test__build_atom_type_key_hash(self, in_memory_db):  # noqa: F811
        db: InMemoryDB = in_memory_db
        key_hash = db._build_atom_type_key_hash("A")
        assert _check_handle(key_hash)
        assert key_hash == "2c832bdcd9d74bf961205676d861540a"

    def test__delete_atom_type(self, in_memory_db):  # noqa: F811
        db: InMemoryDB = in_memory_db
        node = db.add_node({"name": "A", "type": "A"})
        named_type = node["named_type"]
        assert len(db.all_named_types) == 1
        assert named_type in db.all_named_types
        db._delete_atom_type("A")
        assert len(db.all_named_types) == 0
        assert named_type not in db.all_named_types

    def test__update_atom_indexes(self, in_memory_db):  # noqa: F811
        db: InMemoryDB = in_memory_db
        db._update_atom_indexes([db.add_node({"name": "A", "type": "A"})])
        assert len(db.all_named_types) == 1
41 changes: 41 additions & 0 deletions tests/unit/adapters/test_redis_mongo_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from unittest import mock

import pytest

from hyperon_das_atomdb.adapters.redis_mongo_db import MongoDBIndex, RedisMongoDB, _HashableDocument
from tests.unit.fixtures import redis_mongo_db # noqa: F401


class TestRedisMongoExtra:
    """Extra coverage for redis_mongo_db helpers and connection setup."""

    def test_hashable_document_str(self, redis_mongo_db):  # noqa: F811
        db = redis_mongo_db
        node = db._build_node({"type": "A", "name": "A"})
        hashable = _HashableDocument(node)
        str_hashable = str(hashable)
        assert isinstance(str_hashable, str)
        assert hashable
        assert str(node) == str_hashable

    @pytest.mark.parametrize(
        "params",
        [
            {"atom_type": "A", "fields": []},
            {"atom_type": "A", "fields": None},
        ],
    )
    def test_index_create_exceptions(self, params, request):
        # Empty or missing `fields` must be rejected by MongoDBIndex.create.
        db = request.getfixturevalue("redis_mongo_db")
        mi = MongoDBIndex(db.mongo_db)
        with pytest.raises(ValueError):
            mi.create(**params)

    @mock.patch(
        "hyperon_das_atomdb.adapters.redis_mongo_db.MongoClient", return_value=mock.MagicMock()
    )
    @mock.patch("hyperon_das_atomdb.adapters.redis_mongo_db.Redis", return_value=mock.MagicMock())
    @mock.patch(
        "hyperon_das_atomdb.adapters.redis_mongo_db.RedisCluster", return_value=mock.MagicMock()
    )
    def test_create_db_connection_mongo(self, mock_redis_cluster, mock_redis, mock_mongo):
        # NOTE: mock.patch decorators inject bottom-up, so the closest
        # decorator (RedisCluster) supplies the FIRST argument. The previous
        # parameter order named the mocks in reverse of what they received.
        # Exercise both the TLS/cluster path and the plain-connection path.
        RedisMongoDB(mongo_tls_ca_file="/tmp/mock", redis_password="12", redis_username="A")
        RedisMongoDB(redis_cluster=False)
128 changes: 128 additions & 0 deletions tests/unit/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from unittest import mock

import mongomock
import pytest

from hyperon_das_atomdb.adapters.ram_only import InMemoryDB
from hyperon_das_atomdb.adapters.redis_mongo_db import MongoCollectionNames, RedisMongoDB


class MockRedis:
    """Minimal in-memory stand-in for the Redis client used by RedisMongoDB.

    Implements only the subset of commands the adapter exercises; return
    values mirror redis-py conventions (``"OK"`` for SET, 1/0 for EXISTS,
    counts for SADD/SREM/DELETE).
    """

    def __init__(self, cache=None):
        # BUG FIX: the previous default `cache=dict()` was a shared mutable
        # default argument — every instance created without an explicit
        # cache silently shared the same dict.
        self.cache = {} if cache is None else cache

    def get(self, key):
        """Return the value stored at `key`, or None when absent."""
        return self.cache.get(key)

    def set(self, key, value, *args, **kwargs):
        """Store `value` under `key` and return "OK", like real Redis SET."""
        # BUG FIX: the previous guard `if self.cache:` skipped the write
        # whenever the cache was empty, so the very first SET on a fresh
        # instance was silently dropped.
        self.cache[key] = value
        return "OK"

    def hget(self, hash, key):
        """Return field `key` of hash `hash`, or None when either is absent."""
        if hash in self.cache:
            return self.cache[hash].get(key)
        return None

    def hset(self, hash, key, value, *args, **kwargs):
        """Set field `key` of hash `hash` to `value`; return 1."""
        # BUG FIX: previously failed (no-op or KeyError) when the cache was
        # empty or the hash did not exist yet.
        self.cache.setdefault(hash, {})[key] = value
        return 1

    def exists(self, key):
        """Return 1 if `key` exists, else 0 (redis-py convention)."""
        return 1 if key in self.cache else 0

    def cache_overwrite(self, cache=None):
        """Replace the whole backing cache (test helper, not a Redis command)."""
        # Same shared-mutable-default fix as __init__.
        self.cache = {} if cache is None else cache

    def sadd(self, key, *members):
        """Add `members` to the set at `key`; return how many were new."""
        target = self.cache.setdefault(key, set())
        before_count = len(target)
        target.update(members)
        return len(target) - before_count

    def smembers(self, key):
        """Return the set stored at `key`, or an empty set."""
        return self.cache.get(key, set())

    def flushall(self):
        """Remove every key from the cache."""
        self.cache.clear()

    def delete(self, *keys):
        """Delete the given keys; return how many actually existed."""
        deleted_count = 0
        for key in keys:
            if key in self.cache:
                del self.cache[key]
                deleted_count += 1
        return deleted_count

    def getdel(self, key):
        """Return the value at `key` (or None) and remove the key."""
        return self.cache.pop(key, None)

    def srem(self, key, *members):
        """Remove `members` from the set at `key`; return how many were removed."""
        if key not in self.cache:
            return 0
        target = self.cache[key]
        initial_count = len(target)
        target.difference_update(members)
        return initial_count - len(target)

    def sscan(self, name, cursor=0, match=None, count=None):
        """Incrementally scan the set at `name`.

        Returns a (new_cursor, elements) pair; new_cursor is 0 when the
        scan is complete. `match` is a plain substring filter here (the
        real command uses glob patterns) — sufficient for these tests.
        """
        if name not in self.cache:
            return (0, [])
        elements = list(self.cache[name])
        if match:
            elements = [e for e in elements if match in e]
        start = cursor
        end = min(start + (count if count else len(elements)), len(elements))
        new_cursor = end if end < len(elements) else 0
        return (new_cursor, elements[start:end])


@pytest.fixture
def redis_mongo_db():
    """Yield a RedisMongoDB instance wired to mongomock and MockRedis back ends."""
    mongo_db = mongomock.MongoClient().db
    redis_db = MockRedis()
    patch_mongo = mock.patch(
        "hyperon_das_atomdb.adapters.redis_mongo_db.RedisMongoDB._connection_mongo_db",
        return_value=mongo_db,
    )
    patch_redis = mock.patch(
        "hyperon_das_atomdb.adapters.redis_mongo_db.RedisMongoDB._connection_redis",
        return_value=redis_db,
    )
    with patch_mongo, patch_redis:
        db = RedisMongoDB()
        collection = mongo_db.collection
        # Point every collection the adapter knows about at the same
        # mongomock collection.
        db.mongo_atoms_collection = collection
        db.mongo_types_collection = collection
        db.all_mongo_collections = [
            (MongoCollectionNames.ATOMS, collection),
            (MongoCollectionNames.ATOM_TYPES, collection),
        ]
        db.mongo_bulk_insertion_buffer = {
            MongoCollectionNames.ATOMS: (collection, set()),
            MongoCollectionNames.ATOM_TYPES: (collection, set()),
        }

    # The patches are only needed while constructing the instance; tests run
    # against the already-wired db after they are stopped (as before).
    yield db


@pytest.fixture
def in_memory_db():
    """Yield a fresh InMemoryDB (ram_only adapter) for each test."""
    yield InMemoryDB()
Loading