[#118] Create a new set in Redis to save custom index filters #119

Merged · 4 commits · Mar 22, 2024
Changes from all commits
3 changes: 2 additions & 1 deletion CHANGELOG
@@ -1,3 +1,4 @@
[#112] Fix return of the functions get_matched_links(), get_incoming_links(), get_matched_type_template(), get_matched_type() from set to list
[#114] Add create_field_index() to RedisMongoDB adapter
[#120] Refactor Collections in RedisMongoDB adapter
[#120] Refactor Collections in RedisMongoDB adapter
[#118] Create a new set in Redis to save custom index filters
8 changes: 7 additions & 1 deletion hyperon_das_atomdb/adapters/ram_only.py
@@ -498,5 +498,11 @@ def delete_atom(self, handle: str, **kwargs) -> None:
details=f'handle: {handle}',
)

def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str:
def create_field_index(
self,
atom_type: str,
field: str,
type: Optional[str] = None,
composite_type: Optional[List[Any]] = None,
) -> str:
pass
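For context, a minimal call sketch of the extended signature introduced by this PR; the db instance, field name, and type names below are illustrative and not taken from the diff:

    # Exactly one of `type` or `composite_type` may be given; the RedisMongoDB
    # adapter below raises ValueError when both are passed.
    index_id = db.create_field_index('link', field='tag', type='Similarity')
    index_id = db.create_field_index(
        'link',
        field='tag',
        composite_type=['Similarity', 'Concept', 'Concept'],
    )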
127 changes: 105 additions & 22 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
@@ -1,3 +1,5 @@
import base64
import pickle
import sys
from copy import deepcopy
from enum import Enum
@@ -31,7 +33,6 @@ class MongoCollectionNames(str, Enum):
ATOMS = 'atoms'
ATOM_TYPES = 'atom_types'
DAS_CONFIG = 'das_config'
CUSTOM_INDEXES = 'custom_indexes'


class MongoFieldNames(str, Enum):
@@ -51,6 +52,7 @@ class KeyPrefix(str, Enum):
PATTERNS = 'patterns'
TEMPLATES = 'templates'
NAMED_ENTITIES = 'names'
CUSTOM_INDEXES = 'custom_indexes'


class NodeDocuments:
@@ -90,14 +92,14 @@ class MongoDBIndex(Index):
def __init__(self, collection: Collection) -> None:
self.collection = collection

def create(self, field: str, **kwargs) -> Tuple[str, Any]:
def create(self, atom_type: str, field: str, **kwargs) -> Tuple[str, Any]:
conditionals = None

for key, value in kwargs.items():
conditionals = {key: {"$eq": value}}
break # only one key-value pair

index_id = self.generate_index_id(field)
index_id = f"{atom_type}_{self.generate_index_id(field, conditionals)}"

index_conditionals = {"name": index_id}

@@ -128,9 +130,6 @@ def __init__(self, **kwargs: Optional[Dict[str, Any]]) -> None:
self._setup_databases(**kwargs)
self.mongo_atoms_collection = self.mongo_db.get_collection(MongoCollectionNames.ATOMS)
self.mongo_types_collection = self.mongo_db.get_collection(MongoCollectionNames.ATOM_TYPES)
self.mongo_custom_indexes_collection = self.mongo_db.get_collection(
MongoCollectionNames.CUSTOM_INDEXES
)
self.all_mongo_collections = [
(MongoCollectionNames.ATOMS, self.mongo_atoms_collection),
(MongoCollectionNames.ATOM_TYPES, self.mongo_types_collection),
@@ -661,6 +660,30 @@ def _delete_smember_template(self, handle: str, smember: str) -> None:
def _retrieve_pattern(self, handle: str, **kwargs) -> Tuple[int, List[str]]:
return self._retrieve_hash_targets_value(KeyPrefix.PATTERNS, handle, **kwargs)

def _retrieve_custom_index(self, index_id: str) -> dict:
try:
key = _build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id)
custom_index_str = self.redis.get(key)

if custom_index_str is None:
logger().info(f"Custom index with ID {index_id} not found in Redis")
return None

custom_index_bytes = base64.b64decode(custom_index_str)
custom_index = pickle.loads(custom_index_bytes)

if not isinstance(custom_index, dict):
logger().error(f"Custom index with ID {index_id} is not a dictionary")
raise ValueError("Custom index is not a dictionary")

return custom_index
except ConnectionError as e:
logger().error(f"Error connecting to Redis: {e}")
raise e
except Exception as e:
logger().error(f"Unexpected error retrieving custom index with ID {index_id}: {e}")
raise e

def _get_redis_members(self, key, **kwargs) -> Tuple[int, list]:
"""
Retrieve members from a Redis set, with optional cursor-based paging.
@@ -776,6 +799,52 @@ def _process_matched_results(
def _is_document_link(self, document: Dict[str, Any]) -> bool:
return True if MongoFieldNames.COMPOSITE_TYPE in document else False

def _calculate_composite_type_hash(self, composite_type: List[Any]) -> str:
def calculate_composite_type_hashes(composite_type: List[Any]) -> List[str]:
response = []
for type in composite_type:
if isinstance(type, list):
_hash = calculate_composite_type_hashes(type)
response.append(ExpressionHasher.composite_hash(_hash))
else:
response.append(ExpressionHasher.named_type_hash(type))
return response

composite_type_hashes_list = calculate_composite_type_hashes(composite_type)
return ExpressionHasher.composite_hash(composite_type_hashes_list)

def _retrieve_mongo_documents_by_index(
self, collection: Collection, index_id: str, **kwargs
) -> Tuple[int, List[Dict[str, Any]]]:
if MongoDBIndex(collection).index_exists(index_id):
cursor = kwargs.pop('cursor', None)
chunk_size = kwargs.pop('chunk_size', 500)

try:
kwargs.update(self._retrieve_custom_index(index_id))
except Exception as e:
raise e

# Using the hint() method is an additional measure to ensure its use
pymongo_cursor = collection.find(kwargs).hint(index_id)

if cursor is not None:
pymongo_cursor.skip(cursor).limit(chunk_size)

documents = [document for document in pymongo_cursor]

if not documents:
return 0, []

if len(documents) < chunk_size:
return 0, documents
else:
return cursor + chunk_size, documents

return 0, [document for document in pymongo_cursor]
else:
raise ValueError(f"Index '{index_id}' does not exist in collection '{collection}'")

def reindex(self, pattern_index_templates: Optional[Dict[str, Dict[str, Any]]] = None):
if pattern_index_templates is not None:
self.pattern_index_templates = deepcopy(pattern_index_templates)
@@ -799,18 +868,33 @@ def delete_atom(self, handle: str, **kwargs) -> None:
)
self._update_atom_indexes([document], delete_atom=True)

def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str:
def create_field_index(
self,
atom_type: str,
field: str,
type: Optional[str] = None,
composite_type: Optional[List[Any]] = None,
) -> str:
if type and composite_type:
raise ValueError("Both type and composite_type cannot be specified")

if type:
kwargs = {MongoFieldNames.TYPE_NAME: type}
elif composite_type:
kwargs = {MongoFieldNames.TYPE: self._calculate_composite_type_hash(composite_type)}

collection = self.mongo_atoms_collection

index_id = ""

try:
exc = ""
index_id, conditionals = MongoDBIndex(collection).create(field, named_type=type)
self.mongo_custom_indexes_collection.update_one(
filter={'_id': index_id},
update={'$set': {'_id': index_id, 'conditionals': conditionals}},
upsert=True,
index_id, conditionals = MongoDBIndex(collection).create(atom_type, field, **kwargs)
serialized_conditionals = pickle.dumps(conditionals)
serialized_conditionals_str = base64.b64encode(serialized_conditionals).decode('utf-8')
self.redis.set(
_build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id),
serialized_conditionals_str,
)
except pymongo_errors.OperationFailure as e:
exc = e
@@ -828,14 +912,13 @@ def create_field_index(self, atom_type: str, field: str, type: Optional[str] = N

return index_id

def retrieve_mongo_document_by_index(
self, collection: Collection, index_id: str, **kwargs
) -> List[Dict[str, Any]]:
if MongoDBIndex(collection).index_exists(index_id):
kwargs.update(
self.mongo_custom_indexes_collection.find_one({'_id': index_id})['conditionals']
def get_atoms_by_index(self, index_id: str, **kwargs) -> Union[Tuple[int, list], list]:
try:
documents = self._retrieve_mongo_documents_by_index(
self.mongo_atoms_collection, index_id, **kwargs
)
pymongo_cursor = collection.find(kwargs).hint(
index_id
) # Using the hint() method is an additional measure to ensure its use
return [document for document in pymongo_cursor]
cursor, documents = documents
return cursor, [self.get_atom(document['_id']) for document in documents]
except Exception as e:
logger().error(f"Error retrieving atoms by index: {str(e)}")
raise e
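A minimal sketch of the Redis round-trip used above for the index conditionals, assuming a connected client named redis_client plus the module's _build_redis_key helper and KeyPrefix enum; the index_id and conditionals values are illustrative:

    import base64
    import pickle

    index_id = 'link_<hash>'
    conditionals = {'named_type': {'$eq': 'Similarity'}}

    # Store, as in create_field_index: pickle, then base64-encode to a UTF-8 string.
    payload = base64.b64encode(pickle.dumps(conditionals)).decode('utf-8')
    redis_client.set(_build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id), payload)

    # Load, as in _retrieve_custom_index: fetch, base64-decode, then unpickle.
    raw = redis_client.get(_build_redis_key(KeyPrefix.CUSTOM_INDEXES, index_id))
    restored = pickle.loads(base64.b64decode(raw))
    assert restored == conditionals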
8 changes: 7 additions & 1 deletion hyperon_das_atomdb/database.py
@@ -594,5 +594,11 @@ def delete_atom(self, handle: str, **kwargs) -> None:
... # pragma no cover

@abstractmethod
def create_field_index(self, atom_type: str, field: str, type: Optional[str] = None) -> str:
def create_field_index(
self,
atom_type: str,
field: str,
type: Optional[str] = None,
composite_type: Optional[List[Any]] = None,
) -> str:
... # pragma no cover
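When composite_type is passed, the RedisMongoDB adapter reduces the list to a single hash via _calculate_composite_type_hash before building the partial filter. A rough standalone sketch equivalent to that recursion, assuming ExpressionHasher.named_type_hash and ExpressionHasher.composite_hash behave as used in the diff above:

    from typing import Any, List

    from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher

    def composite_type_hash(composite_type: List[Any]) -> str:
        # Hash each element; a nested list is itself reduced to one composite
        # hash, mirroring _calculate_composite_type_hash above.
        hashes = []
        for element in composite_type:
            if isinstance(element, list):
                hashes.append(composite_type_hash(element))
            else:
                hashes.append(ExpressionHasher.named_type_hash(element))
        return ExpressionHasher.composite_hash(hashes)

    # e.g. composite_type_hash(['Similarity', 'Concept', 'Concept'])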
8 changes: 4 additions & 4 deletions hyperon_das_atomdb/index.py
@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any, Tuple
from typing import Any, Dict, Tuple

from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher


class Index(ABC):
@staticmethod
def generate_index_id(field: str) -> str:
def generate_index_id(field: str, conditionals: Dict[str, Any]) -> str:
"""Generates an index ID based on the field name.

Args:
@@ -15,10 +15,10 @@ def generate_index_id(field: str) -> str:
Returns:
str: The index ID.
"""
return f"index_{ExpressionHasher._compute_hash(field)}"
return ExpressionHasher._compute_hash(f'{field}{conditionals}')

@abstractmethod
def create(self, field: str, **kwargs) -> Tuple[str, Any]:
def create(self, atom_type: str, field: str, **kwargs) -> Tuple[str, Any]:
"""Creates an index on the given field.

Args:
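A small sketch of how the final index name is now assembled: generate_index_id hashes the field name together with its conditionals (index.py above), and MongoDBIndex.create in the RedisMongoDB adapter prefixes the result with the atom type; the field and conditionals values are illustrative:

    from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher

    atom_type = 'link'
    field = 'tag'
    conditionals = {'named_type': {'$eq': 'Similarity'}}

    # index.py: the hash now covers the conditionals as well as the field name,
    # so the same field with different filters yields distinct index IDs.
    index_hash = ExpressionHasher._compute_hash(f'{field}{conditionals}')

    # redis_mongo_db.py (MongoDBIndex.create): prefix with the atom type.
    index_id = f'{atom_type}_{index_hash}'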
7 changes: 3 additions & 4 deletions tests/integration/adapters/test_redis_mongo.py
@@ -969,11 +969,10 @@ def test_create_field_index(self):
assert my_index == response['queryPlanner']['winningPlan']['inputStage']['indexName']

# Retrieve the document using the index
doc = db.retrieve_mongo_document_by_index(collection, my_index, tag='DAS')
assert doc[0]['_id'] == ExpressionHasher.expression_hash(
_, doc = db.get_atoms_by_index(my_index, tag='DAS')
assert doc[0]['handle'] == ExpressionHasher.expression_hash(
ExpressionHasher.named_type_hash("Similarity"), [human, monkey]
)
assert doc[0]['key_0'] == human
assert doc[0]['key_1'] == monkey
assert doc[0]['targets'] == [human, monkey]

_db_down()
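For completeness, a paging sketch for the new get_atoms_by_index API exercised here; cursor and chunk_size are the optional kwargs consumed by _retrieve_mongo_documents_by_index, while db, the index ID, and the tag filter are illustrative:

    cursor = 0
    while True:
        # A returned cursor of 0 signals the last page.
        cursor, atoms = db.get_atoms_by_index(my_index, tag='DAS', cursor=cursor, chunk_size=100)
        for atom in atoms:
            print(atom['handle'])
        if cursor == 0:
            break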
14 changes: 7 additions & 7 deletions tests/unit/adapters/test_redis_mongo_db.py
@@ -1897,24 +1897,24 @@ def test_create_field_index_node_collection(self, database):
assert result == 'name_index_asc'
database.mongo_atoms_collection.create_index.assert_called_once_with(
[('name', 1)],
name='name_index_asc',
partialFilterExpression={'named_type': {'$eq': 'Type'}},
name='node_name_index_asc',
partialFilterExpression={MongoFieldNames.TYPE_NAME: {'$eq': 'Type'}},
)

def test_create_field_index_link_collection(self, database):
database.mongo_atoms_collection = mock.Mock()
database.mongo_atoms_collection.create_index.return_value = 'link_index_asc'
database.mongo_atoms_collection.create_index.return_value = 'field_index_asc'
with mock.patch(
'hyperon_das_atomdb.index.Index.generate_index_id',
return_value='link_index_asc',
return_value='field_index_asc',
):
result = database.create_field_index('link', 'field', 'Type')

assert result == 'link_index_asc'
assert result == 'field_index_asc'
database.mongo_atoms_collection.create_index.assert_called_once_with(
[('field', 1)],
name='link_index_asc',
partialFilterExpression={'named_type': {'$eq': 'Type'}},
name='link_field_index_asc',
partialFilterExpression={MongoFieldNames.TYPE_NAME: {'$eq': 'Type'}},
)

@pytest.mark.skip(reason="Maybe change the way to handle this test")