Skip to content

Commit

Permalink
Redis index review ok, fixed retrieve link from template
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebrissow committed Oct 22, 2024
1 parent 266731c commit 22fcbb8
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 37 deletions.
15 changes: 7 additions & 8 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,15 @@ def _build_named_type_hash_template(self, template: str | list[Any]) -> str | li
str | list[Any]: The processed hash template corresponding to the provided template.
Raises:
AssertionError: If the template is not a string or an iterable of strings.
ValueError: If the template is not a string or a list of strings.
"""
if isinstance(template, str):
return ExpressionHasher.named_type_hash(template)
else:
assert isinstance(
template, collections.abc.Iterable
), "template must be a string or an iterable of anything"
return [self._build_named_type_hash_template(element) for element in template]
if isinstance(template, list):
return ExpressionHasher.composite_hash(
[self._build_named_type_hash_template(element) for element in template]
)
raise ValueError("Template must be a string or an iterable of anything")

@staticmethod
def _get_document_keys(document: dict[str, Any]) -> HandleListT:
Expand Down Expand Up @@ -678,8 +678,7 @@ def get_incoming_links(self, atom_handle: str, **kwargs) -> IncomingLinksT:

def get_matched_type_template(self, template: list[Any], **kwargs) -> HandleSetT:
try:
hash_base: HandleListT = self._build_named_type_hash_template(template) # type: ignore
template_hash = ExpressionHasher.composite_hash(hash_base)
template_hash = self._build_named_type_hash_template(template) # type: ignore
templates_matched = self._retrieve_hash_targets_value(
KeyPrefix.TEMPLATES, template_hash, **kwargs
)
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/adapters/data/atom_mongo_redis.json
Original file line number Diff line number Diff line change
Expand Up @@ -455,5 +455,27 @@
}
],
"is_toplevel": true
},
{
"type": "LinkTest",
"targets": [
{
"name": "triceratops",
"type": "Concept"
},
{
"name": "rhino",
"type": "Concept"
},
{
"name": "ent",
"type": "Concept"
},
{
"name": "reptile",
"type": "Concept"
}
],
"is_toplevel": false
}
]
89 changes: 60 additions & 29 deletions tests/unit/adapters/test_redis_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ def test_get_link_handle_link_does_not_exist(self, link_type, targets, database:
("Similarity", [("Concept", "human"), ("Concept", "chimp")], 2),
("Inheritance", [("Concept", "human"), ("Concept", "mammal")], 2),
("Evaluation", [("Concept", "triceratops"), ("Concept", "rhino")], 2),
(
"LinkTest",
[
("Concept", "triceratops"),
("Concept", "rhino"),
("Concept", "ent"),
("Concept", "reptile"),
],
4,
),
],
)
def test_get_link_targets(self, link_type, targets, expected_count, database: RedisMongoDB):
Expand Down Expand Up @@ -231,30 +241,44 @@ def test_get_all_nodes(self, node_type, names, expected, database: RedisMongoDB)
ret = database.get_all_nodes(node_type, names=names)
assert len(ret) == expected

def test_get_matched_type_template(self, database: RedisMongoDB):
v1 = database.get_matched_type_template(["Inheritance", "Concept", "Concept"])
v2 = database.get_matched_type_template(["Similarity", "Concept", "Concept"])
v3 = database.get_matched_type_template(["Inheritance", "Concept", "blah"])
v4 = database.get_matched_type_template(["Similarity", "blah", "Concept"])
v5 = database.get_matched_links("Inheritance", ["*", "*"])
v6 = database.get_matched_links("Similarity", ["*", "*"])
assert len(v1) == 12
assert len(v2) == 14
assert len(v3) == 0
assert len(v4) == 0
assert v1 == v5
assert v2 == v6
@pytest.mark.parametrize(
    "template,expected",
    [
        (["Inheritance", "Concept", "Concept"], 12),
        (["Similarity", "Concept", "Concept"], 14),
        # NOTE(review): "Inheritanc" looks intentionally misspelled so that the
        # wildcard link query below also returns nothing -- confirm.
        (["Inheritanc", "Concept", "blah"], 0),
    ],
)
def test_get_matched_type_template(self, template, expected, database: RedisMongoDB):
    """Compare a type-template lookup with the two-wildcard link query.

    For each parametrized ``template`` the result of
    ``get_matched_type_template`` must contain ``expected`` entries and must
    equal what ``get_matched_links`` returns for the template's first element
    with two ``"*"`` targets.
    """
    link_type = template[0]
    by_template = database.get_matched_type_template(template)
    by_wildcards = database.get_matched_links(link_type, ["*", "*"])
    assert len(by_template) == expected
    assert by_wildcards == by_template

@pytest.mark.parametrize(
    "template,template_equal,expected",
    [
        (["Inheritance", "Concept", "Concept"], ["Inheritance", ["*", "*"]], 12),
        (["Similarity", "Concept", "Concept"], ["Similarity", ["*", "*"]], 14),
    ],
)
def test_get_matched_type_template_equal(
    self, template, template_equal, expected, database: RedisMongoDB
):
    """Check that a type template and its explicit link pattern agree.

    ``template_equal`` is a ``[link_type, targets]`` pair handed to
    ``get_matched_links``; its result must equal the result of
    ``get_matched_type_template(template)``, which must have ``expected``
    entries.
    """
    by_template = database.get_matched_type_template(template)
    link_type, target_pattern = template_equal
    by_pattern = database.get_matched_links(link_type, target_pattern)
    assert len(by_template) == expected
    assert by_pattern == by_template

@pytest.mark.parametrize(
"template",
"template_list",
[
["Inheritance", "Concept", "Concept", {"aaa": "bbb"}],
["Inheritance", "Concept", "Concept", ["aaa", "bbb"]],
],
)
def test_get_matched_type_template_error(self, template, database: RedisMongoDB):
def test_get_matched_type_template_error(self, template_list, database: RedisMongoDB):
with pytest.raises(ValueError) as exc_info:
database.get_matched_type_template(template)
print(database.get_matched_type_template(template_list))
assert exc_info.type is ValueError

@pytest.mark.parametrize(
Expand Down Expand Up @@ -449,14 +473,14 @@ def test_get_link_type_without_cache(self, database: RedisMongoDB):

def test_atom_count(self, database: RedisMongoDB):
response = database.count_atoms({"precise": True})
assert response == {"atom_count": 42, "node_count": 14, "link_count": 28}
assert response == {"atom_count": 43, "node_count": 14, "link_count": 29}

def test_atom_count_fast(self, database: RedisMongoDB):
response = database.count_atoms()
assert response == {"atom_count": 42}
assert response == {"atom_count": 43}

def test_add_node(self, database: RedisMongoDB):
assert {"atom_count": 42} == database.count_atoms()
assert {"atom_count": 43} == database.count_atoms()
all_nodes_before = database.get_all_nodes("Concept")
database.add_node(
{
Expand All @@ -469,9 +493,9 @@ def test_add_node(self, database: RedisMongoDB):
assert len(all_nodes_before) == 14
assert len(all_nodes_after) == 15
assert {
"atom_count": 43,
"atom_count": 44,
"node_count": 15,
"link_count": 28,
"link_count": 29,
} == database.count_atoms({"precise": True})
new_node_handle = database.get_node_handle("Concept", "lion")
assert new_node_handle == ExpressionHasher.terminal_hash("Concept", "lion")
Expand All @@ -483,7 +507,7 @@ def test_add_node(self, database: RedisMongoDB):
assert new_node["name"] == "lion"

def test_add_link(self, database: RedisMongoDB):
assert {"atom_count": 42} == database.count_atoms()
assert {"atom_count": 43} == database.count_atoms()

all_nodes_before = database.get_all_nodes("Concept")
similarity = database.get_all_links("Similarity")
Expand All @@ -510,9 +534,9 @@ def test_add_link(self, database: RedisMongoDB):
assert len(all_links_before) == 28
assert len(all_links_after) == 29
assert {
"atom_count": 45,
"atom_count": 46,
"node_count": 16,
"link_count": 29,
"link_count": 30,
} == database.count_atoms({"precise": True})

new_node_handle = database.get_node_handle("Concept", "lion")
Expand All @@ -538,8 +562,8 @@ def test_add_link(self, database: RedisMongoDB):
[
(("Concept", "human"), 7),
(("Concept", "monkey"), 5),
(("Concept", "rhino"), 4),
(("Concept", "reptile"), 3),
(("Concept", "rhino"), 5),
(("Concept", "reptile"), 4),
],
)
def test_get_incoming_links_by_node(self, node, expected_count, database: RedisMongoDB):
Expand Down Expand Up @@ -598,6 +622,7 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
assert len(links) > 0
assert all(isinstance(link, str) for link in links)
answer = database.redis.smembers(f"{KeyPrefix.INCOMING_SET.value}:{h}")
assert isinstance(answer, set)
assert sorted(links) == sorted(answer)
assert handle in links
links = database.get_incoming_links(atom_handle=h, handles_only=False)
Expand All @@ -615,6 +640,8 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
@pytest.mark.parametrize(
"link_type,link_targets,expected_count",
[
("*", ["*", "*"], 28),
# ("LinkTest", ["*", "*", "*", "*"], 1),
("Similarity", ["*", "af12f10f9ae2002a1607ba0b47ba8407"], 3),
("Similarity", ["af12f10f9ae2002a1607ba0b47ba8407", "*"], 3),
(
Expand All @@ -638,9 +665,13 @@ def test_get_incoming_links_by_links(self, link_type, link_targets, database: Re
def test_redis_patterns(self, link_type, link_targets, expected_count, database: RedisMongoDB):
links = database.get_matched_links(link_type, link_targets)
pattern_hash = ExpressionHasher.composite_hash(
[ExpressionHasher.named_type_hash(link_type), *link_targets]
[
ExpressionHasher.named_type_hash(link_type) if link_type != "*" else "*",
*link_targets,
]
)
answer = database.redis.smembers(f"{KeyPrefix.PATTERNS.value}:{pattern_hash}")
print(len(answer), len(links))
assert len(answer) == len(links) == expected_count
assert sorted(links) == sorted(answer)
assert len(links) == expected_count
Expand All @@ -656,7 +687,7 @@ def test_redis_patterns(self, link_type, link_targets, expected_count, database:
(["Similarity"], 14),
(["Evaluation", "Concept", "Concept"], 1),
(["Evaluation"], 2),
# (["Evaluation", "Concept", ["Evaluation", "Concept", ["Evaluation", "Concept", "Concept"]]], 1)
(["Evaluation", "Concept", ["Evaluation", "Concept", "Concept"]], 1),
],
)
def test_redis_templates(self, template_values, expected_count, database: RedisMongoDB):
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/test_database_public_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ def _load_db(self, db):
for link in all_links:
db.add_link(link)

def _load_db_redis_mongo(self, db):
    """Populate ``db`` with the atoms stored in the Redis/Mongo JSON fixture.

    The fixture ``adapters/data/atom_mongo_redis.json`` is resolved relative
    to the directory containing this file. Entries that have a ``"name"`` key
    are added with ``db.add_node``; every other entry is added with
    ``db.add_link``, forwarding its ``"is_toplevel"`` value as ``toplevel``.
    Afterwards ``db.commit()`` is attempted on a best-effort basis.

    Args:
        db: Object exposing ``add_node``, ``add_link`` and ``commit``.

    Raises:
        OSError: If the fixture file cannot be opened.
        KeyError: If a link entry has no ``"is_toplevel"`` key.
    """
    import json
    import pathlib

    base_dir = pathlib.Path(__file__).parent.resolve()
    fixture = base_dir / "adapters" / "data" / "atom_mongo_redis.json"
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with fixture.open(encoding="utf-8") as f:
        atoms = json.load(f)
    for atom in atoms:
        if "name" in atom:
            db.add_node(atom)
        else:
            db.add_link(atom, toplevel=atom["is_toplevel"])
    try:
        db.commit()
    except Exception:
        # Best effort: ordinary commit failures are ignored on purpose
        # (presumably some backends have nothing to commit -- confirm).
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        pass

@pytest.mark.parametrize(
"expected",
[
Expand Down Expand Up @@ -551,6 +568,89 @@ def test_get_matched_links(self, database, params, links_len, request):
else:
assert all([check_handle(link) for link in links])

@pytest.mark.parametrize(
    "params,links_len",
    [
        ({"link_type": "*", "target_handles": ["*", "*"]}, 2),
        ({"link_type": "*", "target_handles": ["*"]}, 1),
        ({"link_type": "Aa", "target_handles": ["123123"]}, 0),
        ({"link_type": "Aa", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7"]}, 1),
        ({"link_type": "*", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7"]}, 1),
        (
            {
                "link_type": "Bab",
                "target_handles": [
                    "afdb1c23e7f2da1f33c2a3a91d7959a7",
                    "762745ca7757082780f428ba4116ea46",
                ],
            },
            1,
        ),
        (
            {
                "link_type": "*",
                "target_handles": [
                    "afdb1c23e7f2da1f33c2a3a91d7959a7",
                    "762745ca7757082780f428ba4116ea46",
                ],
            },
            1,
        ),
        ({"link_type": "*", "target_handles": ["afdb1c23e7f2da1f33c2a3a91d7959a7", "*"]}, 2),
        ({"link_type": "*", "target_handles": ["*", "762745ca7757082780f428ba4116ea46"]}, 1),
        ({"link_type": "*", "target_handles": ["*", "47a0059c63c6943615c232a29a315018"]}, 1),
        ({"link_type": "CaA", "target_handles": ["*", "47a0059c63c6943615c232a29a315018"]}, 1),
    ],
)
def test_get_matched_links_more(self, database, params, links_len, request):
    """Check ``get_matched_links`` counts on a small hand-built graph.

    The graph holds two ``Test`` nodes (``a`` and ``b``) and three links:
    ``Aa(a)``, ``Bab(a, b)`` and ``CaA(a, Aa(a))``. ``params`` is forwarded
    as keyword arguments to ``get_matched_links`` and the number of results
    must equal ``links_len``.
    """
    db: AtomDB = request.getfixturevalue(database)
    a_node = {"name": "a", "type": "Test"}
    b_node = {"name": "b", "type": "Test"}
    aa_link = {"type": "Aa", "targets": [a_node]}

    # Handles, as recorded next to each atom in the original fixture:
    #   node a -> afdb1c23e7f2da1f33c2a3a91d7959a7
    #   node b -> 762745ca7757082780f428ba4116ea46
    for node_name in ("a", "b"):
        add_node(db, node_name, "Test", database)

    links_to_add = (
        ("Aa", [a_node]),  # 47a0059c63c6943615c232a29a315018
        ("Bab", [a_node, b_node]),  # 51255240d91ea1e045260355cf19d3b2
        ("CaA", [a_node, aa_link]),  # 2b9c92b0b219881f4b6121f08f4850ba
    )
    for link_type, targets in links_to_add:
        add_link(db, link_type, targets, database)

    matched = db.get_matched_links(**params)
    assert len(matched) == links_len

@pytest.mark.parametrize(
    "link_type,link_targets,expected_count",
    [
        ("*", ["*", "*"], 28),
        # NOTE(review): the four-wildcard LinkTest case below is disabled in
        # the original; the reason is not visible here.
        # ("LinkTest", ["*", "*", "*", "*"], 1),
        ("Similarity", ["*", "af12f10f9ae2002a1607ba0b47ba8407"], 3),
        ("Similarity", ["af12f10f9ae2002a1607ba0b47ba8407", "*"], 3),
        (
            "Inheritance",
            ["c1db9b517073e51eb7ef6fed608ec204", "b99ae727c787f1b13b452fd4c9ce1b9a"],
            1,
        ),
        (
            "Evaluation",
            ["d03e59654221c1e8fcda404fd5c8d6cb", "99d18c702e813b07260baf577c60c455"],
            1,
        ),
        (
            "Evaluation",
            ["d03e59654221c1e8fcda404fd5c8d6cb", "99d18c702e813b07260baf577c60c455"],
            1,
        ),
        ("Evaluation", ["*", "99d18c702e813b07260baf577c60c455"], 1),
    ],
)
def test_patterns(self, link_type, link_targets, expected_count, database, request):
    """Check pattern-match counts after loading the Redis/Mongo JSON fixture.

    The database fixture named by ``database`` is filled through
    ``self._load_db_redis_mongo``; ``get_matched_links(link_type,
    link_targets)`` must then return ``expected_count`` links.
    """
    atom_db: AtomDB = request.getfixturevalue(database)
    self._load_db_redis_mongo(atom_db)
    matched = atom_db.get_matched_links(link_type, link_targets)
    assert len(matched) == expected_count

@pytest.mark.parametrize(
"targets,link_type,expected",
[
Expand Down

0 comments on commit 22fcbb8

Please sign in to comment.