Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#204] Redis indexes review #249

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# pylint: enable=invalid-name


def _build_redis_key(prefix: str, key: str) -> str:
def _build_redis_key(prefix: str, key: str | list[Any]) -> str:
"""
Build a Redis key by concatenating the given prefix and key with a colon separator.

Expand All @@ -61,7 +61,7 @@ def _build_redis_key(prefix: str, key: str) -> str:
Returns:
str: The concatenated Redis key.
"""
return prefix + ":" + key
return prefix + ":" + str(key)


class MongoCollectionNames(str, Enum):
Expand Down Expand Up @@ -547,15 +547,15 @@ def _build_named_type_hash_template(self, template: str | list[Any]) -> str | li
str | list[Any]: The processed hash template corresponding to the provided template.

Raises:
AssertionError: If the template is not a string or an iterable of strings.
ValueError: If the template is not a string or a list of strings.
"""
if isinstance(template, str):
return ExpressionHasher.named_type_hash(template)
else:
assert isinstance(
template, collections.abc.Iterable
), "template must be a string or an iterable of anything"
return [self._build_named_type_hash_template(element) for element in template]
if isinstance(template, list):
return ExpressionHasher.composite_hash(
[self._build_named_type_hash_template(element) for element in template]
)
raise ValueError("Template must be a string or an iterable of anything")

@staticmethod
def _get_document_keys(document: DocumentT) -> HandleListT:
Expand Down Expand Up @@ -770,8 +770,7 @@ def get_incoming_links_atoms(self, atom_handle: str, **kwargs) -> list[AtomT]:

def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
try:
hash_base: HandleListT = self._build_named_type_hash_template(template) # type: ignore
template_hash = ExpressionHasher.composite_hash(hash_base)
template_hash = self._build_named_type_hash_template(template)
templates_matched = self._retrieve_hash_targets_value(
KeyPrefix.TEMPLATES, template_hash
)
Expand Down Expand Up @@ -1087,7 +1086,7 @@ def _retrieve_name(self, handle: str) -> str | None:
else:
return None

def _retrieve_hash_targets_value(self, key_prefix: str, handle: str) -> HandleSetT:
def _retrieve_hash_targets_value(self, key_prefix: str, handle: str | list[Any]) -> HandleSetT:
"""
Retrieve the hash targets value for the given handle from Redis.

Expand Down
3 changes: 2 additions & 1 deletion hyperon_das_atomdb/utils/expression_hasher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from hashlib import md5
from typing import Any


class ExpressionHasher:
Expand Down Expand Up @@ -88,7 +89,7 @@ def expression_hash(named_type_hash: str, elements: list[str]) -> str:
return ExpressionHasher.composite_hash([named_type_hash, *elements])

@staticmethod
def composite_hash(hash_base: str | list[str]) -> str:
def composite_hash(hash_base: str | list[Any]) -> str:
"""
Compute the composite hash for the given base.

Expand Down
22 changes: 22 additions & 0 deletions tests/unit/adapters/data/atom_mongo_redis.json
Original file line number Diff line number Diff line change
Expand Up @@ -481,5 +481,27 @@
}
],
"is_toplevel": true
},
{
"type": "LinkTest",
"targets": [
{
"name": "triceratops",
"type": "Concept"
},
{
"name": "rhino",
"type": "Concept"
},
{
"name": "ent",
"type": "Concept"
},
{
"name": "reptile",
"type": "Concept"
}
],
"is_toplevel": false
}
]
80 changes: 55 additions & 25 deletions tests/unit/adapters/test_redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def loader(file_name):
class TestRedisMongoDB:
def _load_database(self, db):
atoms = loader("atom_mongo_redis.json")
self.atom_count = 43
self.atom_count = 44
self.node_count = 14
self.link_count = 29
self.link_count = 30
for atom in atoms:
if "name" in atom:
db.add_node(dict_to_node_params(atom))
Expand Down Expand Up @@ -154,6 +154,16 @@ def test_get_link_handle_link_does_not_exist(self, link_type, targets, database:
("Similarity", [("Concept", "human"), ("Concept", "chimp")], 2),
("Inheritance", [("Concept", "human"), ("Concept", "mammal")], 2),
("Evaluation", [("Concept", "triceratops"), ("Concept", "rhino")], 2),
(
"LinkTest",
[
("Concept", "triceratops"),
("Concept", "rhino"),
("Concept", "ent"),
("Concept", "reptile"),
],
4,
),
],
)
def test_get_link_targets(self, link_type, targets, expected_count, database: RedisMongoDB):
Expand Down Expand Up @@ -253,30 +263,44 @@ def test_get_all_nodes(self, node_type, names, expected, database: RedisMongoDB)
ret = database.get_all_nodes_handles(node_type)
assert len(ret) == expected

def test_get_matched_type_template(self, database: RedisMongoDB):
v1 = database.get_matched_type_template(["Inheritance", "Concept", "Concept"])
v2 = database.get_matched_type_template(["Similarity", "Concept", "Concept"])
v3 = database.get_matched_type_template(["Inheritance", "Concept", "blah"])
v4 = database.get_matched_type_template(["Similarity", "blah", "Concept"])
v5 = database.get_matched_links("Inheritance", ["*", "*"])
v6 = database.get_matched_links("Similarity", ["*", "*"])
assert len(v1) == 12
assert len(v2) == 14
assert len(v3) == 0
assert len(v4) == 0
assert v1 == v5
assert v2 == v6
@pytest.mark.parametrize(
"template,expected",
[
(["Inheritance", "Concept", "Concept"], 12),
(["Similarity", "Concept", "Concept"], 14),
(["Inheritanc", "Concept", "blah"], 0),
],
)
def test_get_matched_type_template(self, template, expected, database: RedisMongoDB):
matched = database.get_matched_type_template(template)
matched_links = database.get_matched_links(template[0], ["*", "*"])
assert len(matched) == expected
assert matched_links == matched

@pytest.mark.parametrize(
"template",
"template,template_equal,expected",
[
(["Inheritance", "Concept", "Concept"], ["Inheritance", ["*", "*"]], 12),
(["Similarity", "Concept", "Concept"], ["Similarity", ["*", "*"]], 14),
],
)
def test_get_matched_type_template_equal(
self, template, template_equal, expected, database: RedisMongoDB
):
matched = database.get_matched_type_template(template)
to_match = database.get_matched_links(*template_equal)
assert len(matched) == expected
assert to_match == matched

@pytest.mark.parametrize(
"template_list",
[
["Inheritance", "Concept", "Concept", {"aaa": "bbb"}],
["Inheritance", "Concept", "Concept", ["aaa", "bbb"]],
],
)
def test_get_matched_type_template_error(self, template, database: RedisMongoDB):
def test_get_matched_type_template_error(self, template_list, database: RedisMongoDB):
with pytest.raises(ValueError) as exc_info:
database.get_matched_type_template(template)
database.get_matched_type_template(template_list)
assert exc_info.type is ValueError

@pytest.mark.parametrize(
Expand Down Expand Up @@ -565,9 +589,9 @@ def test_add_link(self, database: RedisMongoDB):
assert len(all_links_before) == 28
assert len(all_links_after) == 29
assert {
"atom_count": 52,
"atom_count": 53,
"node_count": 20,
"link_count": 32,
"link_count": 33,
} == database.count_atoms({"precise": True})

new_node_handle = database.get_node_handle("Concept", "lion")
Expand All @@ -593,8 +617,8 @@ def test_add_link(self, database: RedisMongoDB):
[
(("Concept", "human"), 8),
(("Concept", "monkey"), 6),
(("Concept", "rhino"), 4),
(("Concept", "reptile"), 3),
(("Concept", "rhino"), 5),
(("Concept", "reptile"), 4),
],
)
def test_get_incoming_links_by_node(self, node, expected_count, database: RedisMongoDB):
Expand Down Expand Up @@ -653,6 +677,7 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
assert len(links) > 0
assert all(isinstance(link, str) for link in links)
answer = database.redis.smembers(f"{KeyPrefix.INCOMING_SET.value}:{h}")
assert isinstance(answer, set)
assert sorted(links) == sorted(answer)
assert handle in links
links = database.get_incoming_links_atoms(atom_handle=h)
Expand All @@ -668,6 +693,8 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
@pytest.mark.parametrize(
"link_type,link_targets,expected_count",
[
("*", ["*", "*"], 28),
# ("LinkTest", ["*", "*", "*", "*"], 1),
("Similarity", ["*", "af12f10f9ae2002a1607ba0b47ba8407"], 3),
("Similarity", ["af12f10f9ae2002a1607ba0b47ba8407", "*"], 3),
(
Expand All @@ -691,7 +718,10 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
def test_redis_patterns(self, link_type, link_targets, expected_count, database: RedisMongoDB):
links = database.get_matched_links(link_type, link_targets)
pattern_hash = ExpressionHasher.composite_hash(
[ExpressionHasher.named_type_hash(link_type), *link_targets]
[
ExpressionHasher.named_type_hash(link_type) if link_type != "*" else "*",
*link_targets,
]
)
answer = database.redis.smembers(f"{KeyPrefix.PATTERNS.value}:{pattern_hash}")
assert len(answer) == len(links) == expected_count
Expand Down Expand Up @@ -1206,7 +1236,7 @@ def test_custom_index_templates_find(
(["Similarity"], 14),
(["Evaluation", "Concept", "Concept"], 1),
(["Evaluation"], 2),
# (["Evaluation", "Concept", ["Evaluation", "Concept", ["Evaluation", "Concept", "Concept"]]], 1)
(["Evaluation", "Concept", ["Evaluation", "Concept", "Concept"]], 1),
],
)
def test_redis_templates(self, template_values, expected_count, database: RedisMongoDB):
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/test_database_public_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ def _load_db(self, db):
for link in all_links:
db.add_link(dict_to_link_params(link))

def _load_db_redis_mongo(self, db):
    """Populate ``db`` with the atoms from the shared Redis/Mongo JSON fixture.

    Reads ``adapters/data/atom_mongo_redis.json`` (resolved relative to this
    test module). Entries that carry a ``"name"`` key are added as nodes; every
    other entry is added as a link, with its ``"is_toplevel"`` value removed
    from the dict and forwarded as the ``toplevel`` argument instead.

    Args:
        db: Database under test; must provide ``add_node``, ``add_link`` and
            ``commit``.

    Raises:
        KeyError: If a link entry has no ``"is_toplevel"`` key.
    """
    import json
    import pathlib

    fixture_path = (
        pathlib.Path(__file__).parent.resolve() / "adapters" / "data" / "atom_mongo_redis.json"
    )
    # Explicit encoding so the fixture parses the same way regardless of the
    # platform's locale default.
    with fixture_path.open(encoding="utf-8") as f:
        atoms = json.load(f)
    for atom in atoms:
        if "name" in atom:
            db.add_node(dict_to_node_params(atom))
        else:
            top_level = atom.pop("is_toplevel")
            db.add_link(dict_to_link_params(atom), toplevel=top_level)
    try:
        db.commit()
    except Exception:
        # Deliberately best-effort: a failing commit() must not abort the load.
        # NOTE(review): presumably some adapters under test cannot commit —
        # confirm. Narrowed from a bare ``except`` so that KeyboardInterrupt
        # and SystemExit still propagate.
        pass

@pytest.mark.parametrize(
"expected",
[
Expand Down Expand Up @@ -562,6 +582,93 @@ def test_get_matched_links(self, database, params, links_len, request):
else:
assert all([check_handle(link) for link in links])

@pytest.mark.parametrize(
    "params,links_len",
    [
        ({"link_type": "*", "target_handles": ["*", "*"]}, 2),
        ({"link_type": "*", "target_handles": ["*"]}, 1),
        ({"link_type": "Aa", "target_handles": ["123123"]}, 0),
        ({"link_type": "Aa", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7"]}, 1),
        ({"link_type": "*", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7"]}, 1),
        (
            {
                "link_type": "Bab",
                "target_handles": [
                    "afdb1c23e7f2da1f33c2a3a91d7959a7",
                    "762745ca7757082780f428ba4116ea46",
                ],
            },
            1,
        ),
        (
            {
                "link_type": "*",
                "target_handles": [
                    "afdb1c23e7f2da1f33c2a3a91d7959a7",
                    "762745ca7757082780f428ba4116ea46",
                ],
            },
            1,
        ),
        ({"link_type": "*", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7", "*"]}, 2),
        ({"link_type": "*", "target_handles": ["*", "762745ca7757082780f428ba4116ea46"]}, 1),
        ({"link_type": "*", "target_handles": ["*", "47a0059c63c6943615c232a29a315018"]}, 1),
        ({"link_type": "CaA", "target_handles": ["*", "47a0059c63c6943615c232a29a315018"]}, 1),
    ],
)
def test_get_matched_links_more(self, database, params, links_len, request):
    """Check how many links ``get_matched_links`` returns on a tiny graph.

    The graph is built from two ``Test`` nodes ("a" and "b") and three links:
    ``Aa(a)``, ``Bab(a, b)`` and ``CaA(a, Aa(a))``. ``params`` is passed to
    ``get_matched_links`` as keyword arguments and the size of the result is
    compared with ``links_len``.
    """
    atom_db: AtomDB = request.getfixturevalue(database)

    node_a = {"name": "a", "type": "Test"}
    node_b = {"name": "b", "type": "Test"}
    link_aa = {"type": "Aa", "targets": [node_a]}

    # Handle annotations below are the ones recorded next to each insertion.
    # node "a" -> afdb1c23e7f2da1f33c2a3a91d7959a7
    add_node(atom_db, "a", "Test", database)
    # node "b" -> 762745ca7757082780f428ba4116ea46
    add_node(atom_db, "b", "Test", database)

    # link Aa(a) -> 47a0059c63c6943615c232a29a315018
    aa_targets = [dict_to_node_params(node_a)]
    add_link(atom_db, "Aa", aa_targets, database)

    # link Bab(a, b) -> 51255240d91ea1e045260355cf19d3b2
    bab_targets = [dict_to_node_params(node_a), dict_to_node_params(node_b)]
    add_link(atom_db, "Bab", bab_targets, database)

    # link CaA(a, Aa(a)) -> 2b9c92b0b219881f4b6121f08f4850ba
    caa_targets = [dict_to_node_params(node_a), dict_to_link_params(link_aa)]
    add_link(atom_db, "CaA", caa_targets, database)

    matched = atom_db.get_matched_links(**params)
    assert len(matched) == links_len

@pytest.mark.parametrize(
    "link_type,link_targets,expected_count",
    [
        ("*", ["*", "*"], 28),
        # NOTE(review): the four-target LinkTest wildcard case is disabled;
        # the reason is not visible here.
        # ("LinkTest", ["*", "*", "*", "*"], 1),
        ("Similarity", ["*", "af12f10f9ae2002a1607ba0b47ba8407"], 3),
        ("Similarity", ["af12f10f9ae2002a1607ba0b47ba8407", "*"], 3),
        (
            "Inheritance",
            ["c1db9b517073e51eb7ef6fed608ec204", "b99ae727c787f1b13b452fd4c9ce1b9a"],
            1,
        ),
        (
            "Evaluation",
            ["d03e59654221c1e8fcda404fd5c8d6cb", "99d18c702e813b07260baf577c60c455"],
            1,
        ),
        # NOTE(review): identical to the previous case — confirm whether a
        # different pattern was intended.
        (
            "Evaluation",
            ["d03e59654221c1e8fcda404fd5c8d6cb", "99d18c702e813b07260baf577c60c455"],
            1,
        ),
        ("Evaluation", ["*", "99d18c702e813b07260baf577c60c455"], 1),
    ],
)
def test_patterns(self, link_type, link_targets, expected_count, database, request):
    """Check ``get_matched_links`` result sizes on the Redis/Mongo fixture data.

    The database named by ``database`` is resolved through the pytest
    ``request`` fixture and loaded with ``_load_db_redis_mongo``. It is then
    queried with the given link type and target pattern.
    """
    atom_db: AtomDB = request.getfixturevalue(database)
    self._load_db_redis_mongo(atom_db)
    matched = atom_db.get_matched_links(link_type, link_targets)
    assert len(matched) == expected_count

@pytest.mark.parametrize(
"targets,link_type,expected",
[
Expand Down