diff --git a/tests/integration/adapters/animals_kb.py b/tests/integration/adapters/animals_kb.py index ae60668..3bea970 100644 --- a/tests/integration/adapters/animals_kb.py +++ b/tests/integration/adapters/animals_kb.py @@ -22,7 +22,7 @@ node_docs[human] = {"type": "Concept", "name": "human"} node_docs[monkey] = {"type": "Concept", "name": "monkey"} node_docs[chimp] = {"type": "Concept", "name": "chimp"} -node_docs[mammal] = {"type": "Concept", "name": "mammal"} +node_docs[mammal] = {"type": "Concept", "name": "mammal", "custom_attributes": {"name": "mammal"}} node_docs[reptile] = {"type": "Concept", "name": "reptile"} node_docs[snake] = {"type": "Concept", "name": "snake"} node_docs[dinosaur] = {"type": "Concept", "name": "dinosaur"} diff --git a/tests/integration/adapters/test_redis_mongo.py b/tests/integration/adapters/test_redis_mongo.py index 4adec95..094d471 100644 --- a/tests/integration/adapters/test_redis_mongo.py +++ b/tests/integration/adapters/test_redis_mongo.py @@ -803,9 +803,6 @@ def test_get_matched_with_pagination(self, _cleanup, _db: RedisMongoDB): ) def test_create_field_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip( - "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" - ) db = _db self._add_atoms(db) db.commit() @@ -825,25 +822,33 @@ def test_create_field_index(self, _cleanup, _db: RedisMongoDB): collection = db.mongo_atoms_collection - response = collection.find({"named_type": "Similarity", "tag": "DAS"}).explain() + response = collection.find( + {"named_type": "Similarity", "custom_attributes.tag": "DAS"} + ).explain() with pytest.raises(KeyError): response["queryPlanner"]["winningPlan"]["inputStage"]["indexName"] # Create the index - my_index = db.create_field_index(atom_type="link", fields=["tag"], named_type="Similarity") + my_index = db.create_field_index( + atom_type="link", fields=["custom_attributes.tag"], named_type="Similarity" + ) collection_index_names = [idx.get("name") for idx in collection.list_indexes()] # assert my_index in collection_index_names # # Using the index - response = collection.find({"named_type": "Similarity", "tag": "DAS"}).explain() + response = collection.find( + {"named_type": "Similarity", "custom_attributes.tag": "DAS"} + ).explain() assert my_index == response["queryPlanner"]["winningPlan"]["inputStage"]["indexName"] with PyMongoFindExplain(db.mongo_atoms_collection) as explain: - _, doc = db.get_atoms_by_index(my_index, [{"field": "tag", "value": "DAS"}]) + _, doc = db.get_atoms_by_index( + my_index, [{"field": "custom_attributes.tag", "value": "DAS"}] + ) assert doc[0].handle == ExpressionHasher.expression_hash( ExpressionHasher.named_type_hash("Similarity"), [human, monkey] ) @@ -884,13 +889,13 @@ def test_create_text_index(self, _cleanup, _db: RedisMongoDB): # Create the index my_index = db.create_field_index( atom_type="link", - fields=["tag"], + fields=["custom_attributes.tag"], named_type="Similarity", index_type=FieldIndexType.TOKEN_INVERTED_LIST, ) collection_index_names = [idx.get("name") for idx in collection.list_indexes()] - # + print(my_index) assert my_index in collection_index_names def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): @@ -913,7 +918,7 @@ def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): # Create the index my_index = db.create_field_index( atom_type="link", - fields=["type", "tag"], + fields=["custom_attributes.type", "custom_attributes.tag"], named_type="Similarity", index_type=FieldIndexType.BINARY_TREE, ) @@ -921,9 +926,6 @@ def test_create_compound_index(self, _cleanup, _db: RedisMongoDB): assert my_index in collection_index_names def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip( - "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" - ) db: RedisMongoDB = _db self._add_atoms(db) db.add_link( @@ -941,16 +943,16 @@ def test_get_atoms_by_field_no_index(self, _cleanup, _db: RedisMongoDB): db.commit() with PyMongoFindExplain(db.mongo_atoms_collection) as explain: - result = db.get_atoms_by_field([{"field": "tag", "value": "DAS"}]) + result = db.get_atoms_by_field([{"field": "custom_attributes.tag", "value": "DAS"}]) assert len(result) == 1 assert explain[0]["executionStats"]["executionSuccess"] assert explain[0]["queryPlanner"]["winningPlan"]["stage"] == "COLLSCAN" assert explain[0]["executionStats"]["totalKeysExamined"] == 0 def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip( - "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" - ) + # pytest.skip( + # "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" + # ) db: RedisMongoDB = _db self._add_atoms(db) db.add_link( @@ -966,10 +968,10 @@ def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): ) ) db.commit() - my_index = db.create_field_index(atom_type="link", fields=["tag"]) + my_index = db.create_field_index(atom_type="link", fields=["custom_attributes.tag"]) with PyMongoFindExplain(db.mongo_atoms_collection) as explain: - result = db.get_atoms_by_field([{"field": "tag", "value": "DAS"}]) + result = db.get_atoms_by_field([{"field": "custom_attributes.tag", "value": "DAS"}]) assert len(result) == 1 assert explain[0]["executionStats"]["executionSuccess"] assert explain[0]["executionStats"]["nReturned"] == 1 @@ -986,9 +988,6 @@ def test_get_atoms_by_field_with_index(self, _cleanup, _db: RedisMongoDB): ) def test_get_atoms_by_index(self, _cleanup, _db: RedisMongoDB): - pytest.skip( - "Requires new implementation since the new custom attributes were introduced. See https://github.com/singnet/das-atom-db/issues/255" - ) db: RedisMongoDB = _db db.add_link( dict_to_link_params( @@ -1016,10 +1015,14 @@ def test_get_atoms_by_index(self, _cleanup, _db: RedisMongoDB): ) db.commit() - my_index = db.create_field_index(atom_type="link", fields=["tag"], named_type="Similarity") + my_index = db.create_field_index( + atom_type="link", fields=["custom_attributes.tag"], named_type="Similarity" + ) with PyMongoFindExplain(db.mongo_atoms_collection) as explain: - _, doc = db.get_atoms_by_index(my_index, [{"field": "tag", "value": "DAS2"}]) + _, doc = db.get_atoms_by_index( + my_index, [{"field": "custom_attributes.tag", "value": "DAS2"}] + ) assert doc[0].handle == ExpressionHasher.expression_hash( ExpressionHasher.named_type_hash("Similarity"), [mammal, monkey] ) @@ -1061,12 +1064,12 @@ def test_get_atoms_by_text_field_with_index(self, _cleanup, _db: RedisMongoDB): db.create_field_index( atom_type="node", - fields=["name"], + fields=["custom_attributes.name"], index_type=FieldIndexType.TOKEN_INVERTED_LIST, ) with PyMongoFindExplain(db.mongo_atoms_collection) as explain: - result = db.get_atoms_by_text_field("mammal") + result = db.get_atoms_by_text_field("custom_attributes.mammal") assert len(result) == 1 assert result[0] == db.get_node_handle("Concept", "mammal") assert explain[0]["executionStats"]["executionSuccess"] @@ -1182,6 +1185,21 @@ def test_add_fields_to_atoms(self, _cleanup, _db: RedisMongoDB): assert db.get_atom(link_handle).custom_attributes["score"] == 0.5 + @pytest.mark.parametrize( + "node", [({"type": "A", "name": "type_a", "custom_attributes": {"status": "ready"}})] + ) + def test_get_atoms_by_index_custom_att(self, node, _cleanup, _db: RedisMongoDB): + node = _db.add_node(NodeT(**node)) + _db.commit() + result = _db.create_field_index("node", fields=["custom_attributes.status"]) + cursor, actual = _db.get_atoms_by_index( + result, [{"field": "custom_attributes.status", "value": "ready"}] + ) + assert cursor == 0 + assert isinstance(actual, list) + assert len(actual) == 1 + assert all([a.handle == node.handle for a in actual]) + def test_commit_with_buffer(self, _cleanup, _db: RedisMongoDB): db = _db assert db.count_atoms() == {"atom_count": 0} diff --git a/tests/unit/adapters/test_redis_mongo_db.py b/tests/unit/adapters/test_redis_mongo_db.py index b33d67d..060f8de 100644 --- a/tests/unit/adapters/test_redis_mongo_db.py +++ b/tests/unit/adapters/test_redis_mongo_db.py @@ -10,7 +10,7 @@ from hyperon_das_atomdb.database import FieldIndexType, FieldNames, LinkT from hyperon_das_atomdb.exceptions import AtomDoesNotExist from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher -from tests.helpers import dict_to_link_params, dict_to_node_params +from tests.helpers import add_node, dict_to_link_params, dict_to_node_params from tests.unit.fixtures import redis_mongo_db # noqa: F401 FILE_CACHE = {} @@ -400,6 +400,18 @@ def test_get_atoms_by_index(self, atom_type, fields, query, expected, database: assert isinstance(actual, list) assert all([a.handle in expected for a in actual]) + @pytest.mark.parametrize("node", [("A", "type_a", "redis_mongo_db", {"status": "ready"})]) + def test_get_atoms_by_index_custom_att(self, node, database: RedisMongoDB): + node = add_node(database, *node) + result = database.create_field_index("node", fields=["custom_attributes.status"]) + cursor, actual = database.get_atoms_by_index( + result, [{"field": "custom_attributes.status", "value": "ready"}] + ) + assert cursor == 0 + assert isinstance(actual, list) + assert len(actual) == 1 + assert all([a.handle == node.handle for a in actual]) + @pytest.mark.parametrize( "text_value,field,expected", [