diff --git a/CHANGELOG b/CHANGELOG index f82f785f..649cf152 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,4 @@ [#112] Fix return of the functions get_matched_links(), get_incoming_links(), get_matched_type_template(), get_matched_type() from set to list [#114] Add create_field_index() to RedisMongoDB adapter -[#120] Refactor Collections in RedisMongoDB adapter \ No newline at end of file +[#120] Refactor Collections in RedisMongoDB adapter +[#118] Create a new set in Redis to save custom index filters \ No newline at end of file diff --git a/hyperon_das_atomdb/adapters/ram_only.py b/hyperon_das_atomdb/adapters/ram_only.py index 30ead233..80a2ebe6 100644 --- a/hyperon_das_atomdb/adapters/ram_only.py +++ b/hyperon_das_atomdb/adapters/ram_only.py @@ -498,5 +498,11 @@ def delete_atom(self, handle: str, **kwargs) -> None: details=f'handle: {handle}', ) - def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str: + def create_field_index( + self, + atom_type: str, + field: str, + type: Optional[str] = None, + composite_type: Optional[List[Any]] = None, + ) -> str: pass diff --git a/hyperon_das_atomdb/adapters/redis_mongo_db.py b/hyperon_das_atomdb/adapters/redis_mongo_db.py index 20097646..5cac7485 100644 --- a/hyperon_das_atomdb/adapters/redis_mongo_db.py +++ b/hyperon_das_atomdb/adapters/redis_mongo_db.py @@ -1,3 +1,5 @@ +import base64 +import pickle import sys from copy import deepcopy from enum import Enum @@ -31,7 +33,6 @@ class MongoCollectionNames(str, Enum): ATOMS = 'atoms' ATOM_TYPES = 'atom_types' DAS_CONFIG = 'das_config' - CUSTOM_INDEXES = 'custom_indexes' class MongoFieldNames(str, Enum): @@ -51,6 +52,7 @@ class KeyPrefix(str, Enum): PATTERNS = 'patterns' TEMPLATES = 'templates' NAMED_ENTITIES = 'names' + CUSTOM_INDEXES = 'custom_indexes' class NodeDocuments: @@ -90,14 +92,14 @@ class MongoDBIndex(Index): def __init__(self, collection: Collection) -> None: self.collection = collection - def create(self, field: str, **kwargs) -> Tuple[str, Any]: + def create(self, atom_type: str, field: str, **kwargs) -> Tuple[str, Any]: conditionals = None for key, value in kwargs.items(): conditionals = {key: {"$eq": value}} break # only one key-value pair - index_id = self.generate_index_id(field) + index_id = f"{atom_type}_{self.generate_index_id(field, conditionals)}" index_conditionals = {"name": index_id} @@ -128,9 +130,6 @@ def __init__(self, **kwargs: Optional[Dict[str, Any]]) -> None: self._setup_databases(**kwargs) self.mongo_atoms_collection = self.mongo_db.get_collection(MongoCollectionNames.ATOMS) self.mongo_types_collection = self.mongo_db.get_collection(MongoCollectionNames.ATOM_TYPES) - self.mongo_custom_indexes_collection = self.mongo_db.get_collection( - MongoCollectionNames.CUSTOM_INDEXES - ) self.all_mongo_collections = [ (MongoCollectionNames.ATOMS, self.mongo_atoms_collection), (MongoCollectionNames.ATOM_TYPES, self.mongo_types_collection), @@ -661,6 +660,30 @@ def _delete_smember_template(self, handle: str, smember: str) -> None: def _retrieve_pattern(self, handle: str, **kwargs) -> Tuple[int, List[str]]: return self._retrieve_hash_targets_value(KeyPrefix.PATTERNS, handle, **kwargs) + def _retrieve_custom_index(self, index_id: str) -> dict: + try: + key = _build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id) + custom_index_str = self.redis.get(key) + + if custom_index_str is None: + logger().info(f"Custom index with ID {index_id} not found in Redis") + return None + + custom_index_bytes = base64.b64decode(custom_index_str) + custom_index = pickle.loads(custom_index_bytes) + + if not isinstance(custom_index, dict): + logger().error(f"Custom index with ID {index_id} is not a dictionary") + raise ValueError("Custom index is not a dictionary") + + return custom_index + except ConnectionError as e: + logger().error(f"Error connecting to Redis: {e}") + raise e + except Exception as e: + logger().error(f"Unexpected error retrieving custom index with ID {index_id}: {e}") + raise e + def _get_redis_members(self, key, **kwargs) -> Tuple[int, list]: """ Retrieve members from a Redis set, with optional cursor-based paging. @@ -776,6 +799,52 @@ def _process_matched_results( def _is_document_link(self, document: Dict[str, Any]) -> bool: return True if MongoFieldNames.COMPOSITE_TYPE in document else False + def _calculate_composite_type_hash(self, composite_type: List[Any]) -> str: + def calculate_composite_type_hashes(composite_type: List[Any]) -> List[str]: + response = [] + for type in composite_type: + if isinstance(type, list): + _hash = calculate_composite_type_hashes(type) + response.append(ExpressionHasher.composite_hash(_hash)) + else: + response.append(ExpressionHasher.named_type_hash(type)) + return response + + composite_type_hashes_list = calculate_composite_type_hashes(composite_type) + return ExpressionHasher.composite_hash(composite_type_hashes_list) + + def _retrieve_mongo_documents_by_index( + self, collection: Collection, index_id: str, **kwargs + ) -> Tuple[int, List[Dict[str, Any]]]: + if MongoDBIndex(collection).index_exists(index_id): + cursor = kwargs.pop('cursor', None) + chunk_size = kwargs.pop('chunk_size', 500) + + try: + kwargs.update(self._retrieve_custom_index(index_id)) + except Exception as e: + raise e + + # Using the hint() method is an additional measure to ensure its use + pymongo_cursor = collection.find(kwargs).hint(index_id) + + if cursor is not None: + pymongo_cursor.skip(cursor).limit(chunk_size) + + documents = [document for document in pymongo_cursor] + + if not documents: + return 0, [] + + if len(documents) < chunk_size: + return 0, documents + else: + return cursor + chunk_size, documents + + return 0, [document for document in pymongo_cursor] + else: + raise ValueError(f"Index '{index_id}' does not exist in collection '{collection}'") + def reindex(self, pattern_index_templates: Optional[Dict[str, Dict[str, Any]]] = None): if pattern_index_templates is not None: self.pattern_index_templates = deepcopy(pattern_index_templates) @@ -799,18 +868,33 @@ def delete_atom(self, handle: str, **kwargs) -> None: ) self._update_atom_indexes([document], delete_atom=True) - def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str: + def create_field_index( + self, + atom_type: str, + field: str, + type: Optional[str] = None, + composite_type: Optional[List[Any]] = None, + ) -> str: + if type and composite_type: + raise ValueError("Both type and composite_type cannot be specified") + + if type: + kwargs = {MongoFieldNames.TYPE_NAME: type} + elif composite_type: + kwargs = {MongoFieldNames.TYPE: self._calculate_composite_type_hash(composite_type)} + collection = self.mongo_atoms_collection index_id = "" try: exc = "" - index_id, conditionals = MongoDBIndex(collection).create(field, named_type=type) - self.mongo_custom_indexes_collection.update_one( - filter={'_id': index_id}, - update={'$set': {'_id': index_id, 'conditionals': conditionals}}, - upsert=True, + index_id, conditionals = MongoDBIndex(collection).create(atom_type, field, **kwargs) + serialized_conditionals = pickle.dumps(conditionals) + serialized_conditionals_str = base64.b64encode(serialized_conditionals).decode('utf-8') + self.redis.set( + _build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id), + serialized_conditionals_str, ) except pymongo_errors.OperationFailure as e: exc = e @@ -828,14 +912,13 @@ def create_field_index(self, atom_type: str, field: str, type: Optional[str] = N return index_id - def retrieve_mongo_document_by_index( - self, collection: Collection, index_id: str, **kwargs - ) -> List[Dict[str, Any]]: - if MongoDBIndex(collection).index_exists(index_id): - kwargs.update( - self.mongo_custom_indexes_collection.find_one({'_id': index_id})['conditionals'] + def get_atoms_by_index(self, index_id: str, **kwargs) -> Union[Tuple[int, list], list]: + try: + documents = self._retrieve_mongo_documents_by_index( + self.mongo_atoms_collection, index_id, **kwargs ) - pymongo_cursor = collection.find(kwargs).hint( - index_id - ) # Using the hint() method is an additional measure to ensure its use - return [document for document in pymongo_cursor] + cursor, documents = documents + return cursor, [self.get_atom(document['_id']) for document in documents] + except Exception as e: + logger().error(f"Error retrieving atoms by index: {str(e)}") + raise e diff --git a/hyperon_das_atomdb/database.py b/hyperon_das_atomdb/database.py index e78b3dd5..0048b877 100644 --- a/hyperon_das_atomdb/database.py +++ b/hyperon_das_atomdb/database.py @@ -594,5 +594,11 @@ def delete_atom(self, handle: str, **kwargs) -> None: ... # pragma no cover @abstractmethod - def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str: + def create_field_index( + self, + atom_type: str, + field: str, + type: Optional[str] = None, + composite_type: Optional[List[Any]] = None, + ) -> str: ... # pragma no cover diff --git a/hyperon_das_atomdb/index.py b/hyperon_das_atomdb/index.py index 06d1a9ab..0a04d4c0 100644 --- a/hyperon_das_atomdb/index.py +++ b/hyperon_das_atomdb/index.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod -from typing import Any, Tuple +from typing import Any, Dict, Tuple from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher class Index(ABC): @staticmethod - def generate_index_id(field: str) -> str: + def generate_index_id(field: str, conditionals: Dict[str, Any]) -> str: """Generates an index ID based on the field name. Args: @@ -15,10 +15,10 @@ def generate_index_id(field: str) -> str: Returns: str: The index ID. """ - return f"index_{ExpressionHasher._compute_hash(field)}" + return ExpressionHasher._compute_hash(f'{field}{conditionals}') @abstractmethod - def create(self, field: str, **kwargs) -> Tuple[str, Any]: + def create(self, atom_type: str, field: str, **kwargs) -> Tuple[str, Any]: """Creates an index on the given field. Args: diff --git a/tests/integration/adapters/test_redis_mongo.py b/tests/integration/adapters/test_redis_mongo.py index 5612997e..dd32862f 100644 --- a/tests/integration/adapters/test_redis_mongo.py +++ b/tests/integration/adapters/test_redis_mongo.py @@ -969,11 +969,10 @@ def test_create_field_index(self): assert my_index == response['queryPlanner']['winningPlan']['inputStage']['indexName'] # Retrieve the document using the index - doc = db.retrieve_mongo_document_by_index(collection, my_index, tag='DAS') - assert doc[0]['_id'] == ExpressionHasher.expression_hash( + _, doc = db.get_atoms_by_index(my_index, tag='DAS') + assert doc[0]['handle'] == ExpressionHasher.expression_hash( ExpressionHasher.named_type_hash("Similarity"), [human, monkey] ) - assert doc[0]['key_0'] == human - assert doc[0]['key_1'] == monkey + assert doc[0]['targets'] == [human, monkey] _db_down() diff --git a/tests/unit/adapters/test_redis_mongo_db.py b/tests/unit/adapters/test_redis_mongo_db.py index 6eeca989..7f0e4aca 100644 --- a/tests/unit/adapters/test_redis_mongo_db.py +++ b/tests/unit/adapters/test_redis_mongo_db.py @@ -1897,24 +1897,24 @@ def test_create_field_index_node_collection(self, database): assert result == 'name_index_asc' database.mongo_atoms_collection.create_index.assert_called_once_with( [('name', 1)], - name='name_index_asc', - partialFilterExpression={'named_type': {'$eq': 'Type'}}, + name='node_name_index_asc', + partialFilterExpression={MongoFieldNames.TYPE_NAME: {'$eq': 'Type'}}, ) def test_create_field_index_link_collection(self, database): database.mongo_atoms_collection = mock.Mock() - database.mongo_atoms_collection.create_index.return_value = 'link_index_asc' + database.mongo_atoms_collection.create_index.return_value = 'field_index_asc' with mock.patch( 'hyperon_das_atomdb.index.Index.generate_index_id', - return_value='link_index_asc', + return_value='field_index_asc', ): result = database.create_field_index('link', 'field', 'Type') - assert result == 'link_index_asc' + assert result == 'field_index_asc' database.mongo_atoms_collection.create_index.assert_called_once_with( [('field', 1)], - name='link_index_asc', - partialFilterExpression={'named_type': {'$eq': 'Type'}}, + name='link_field_index_asc', + partialFilterExpression={MongoFieldNames.TYPE_NAME: {'$eq': 'Type'}}, ) @pytest.mark.skip(reason="Maybe change the way to handle this test")