From 1519768abf3e83f3e5464a836e3d6e31f3b903f6 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Fri, 25 Aug 2023 14:25:32 -0400 Subject: [PATCH] MemoryStore: mongomock -> pymongo-inmemory --- setup.cfg | 5 ++ setup.py | 1 + src/maggma/stores/mongolike.py | 120 ++++++++++++--------------------- tests/stores/test_mongolike.py | 4 +- 4 files changed, 50 insertions(+), 80 deletions(-) create mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..0ddc78462 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +# TODO - this entire file can be removed once pymongo-inmemory supports pyproject.toml +# see https://github.com/kaizendorks/pymongo_inmemory/issues/81 +[pymongo_inmemory] +use_local_mongod = False +mongod_port = 27019 diff --git a/setup.py b/setup.py index 8f1617453..639ab5f97 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "pydantic<2.0", "pydantic>=0.32.2", "pymongo>=4.2.0", + "pymongo-inmemory", "monty>=1.0.2", "mongomock>=3.10.0", "pydash>=4.1.0", diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index 6cd734b7b..563f3ad8f 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -5,7 +5,7 @@ """ import warnings -from itertools import chain, groupby +from itertools import chain from pathlib import Path from socket import socket @@ -18,15 +18,15 @@ from typing_extensions import Literal -import mongomock import orjson from monty.dev import requires from monty.io import zopen from monty.json import MSONable, jsanitize from monty.serialization import loadfn -from pydash import get, has, set_ +from pydash import has, set_ from pymongo import MongoClient, ReplaceOne, uri_parser from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure +from pymongo_inmemory import MongoClient as MemoryClient from sshtunnel import SSHTunnelForwarder from maggma.core import Sort, Store, StoreError @@ -139,10 +139,12 @@ def __init__( port: TCP port to connect to username: Username for the collection password: Password to connect with + ssh_tunnel: SSHTunnel instance to use for connection. safe_update: fail gracefully on DocumentTooLarge errors on update auth_source: The database to authenticate on. Defaults to the database name. default_sort: Default sort field and direction to use when querying. Can be used to ensure determinacy in query results. + mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient() """ self.database = database self.collection_name = collection_name @@ -578,95 +580,57 @@ class MemoryStore(MongoStore): to a MongoStore """ - def __init__(self, collection_name: str = "memory_db", **kwargs): + def __init__( + self, + database: str = "mem", + collection_name: str = "memory_store", + host: str = "localhost", + port: int = 27019, # to avoid conflicts with localhost + safe_update: bool = False, + mongoclient_kwargs: Optional[Dict] = None, + default_sort: Optional[Dict[str, Union[Sort, int]]] = None, + **kwargs, + ): """ - Initializes the Memory Store Args: - collection_name: name for the collection in memory - """ - self.collection_name = collection_name - self.default_sort = None - self._coll = None - self.kwargs = kwargs - super(MongoStore, self).__init__(**kwargs) + database: The database name + collection_name: The collection name + host: Hostname for the database + port: TCP port to connect to + safe_update: fail gracefully on DocumentTooLarge errors on update + default_sort: Default sort field and direction to use when querying. + Can be used to ensure determinacy in query results. + mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient() + """ + super().__init__( + database=database, + collection_name=collection_name, + host=host, + port=port, + safe_update=safe_update, + mongoclient_kwargs=mongoclient_kwargs, + default_sort=default_sort, + **kwargs, + ) def connect(self, force_reset: bool = False): """ Connect to the source data """ + conn: MemoryClient = MemoryClient( + host=self.host, + port=self.port, + **self.mongoclient_kwargs, + ) - if self._coll is None or force_reset: - self._coll = mongomock.MongoClient().db[self.name] # type: ignore - - def close(self): - """Close up all collections""" - self._coll.database.client.close() + db = conn[self.database] + self._coll = db[self.collection_name] # type: ignore @property def name(self): """Name for the store""" return f"mem://{self.collection_name}" - def __hash__(self): - """Hash for the store""" - return hash((self.name, self.last_updated_field)) - - def groupby( - self, - keys: Union[List[str], str], - criteria: Optional[Dict] = None, - properties: Union[Dict, List, None] = None, - sort: Optional[Dict[str, Union[Sort, int]]] = None, - skip: int = 0, - limit: int = 0, - ) -> Iterator[Tuple[Dict, List[Dict]]]: - """ - Simple grouping function that will group documents - by keys. - - Args: - keys: fields to group documents - criteria: PyMongo filter for documents to search in - properties: properties to return in grouped documents - sort: Dictionary of sort order for fields. Keys are field names and - values are 1 for ascending or -1 for descending. - skip: number documents to skip - limit: limit on total number of documents returned - - Returns: - generator returning tuples of (key, list of elements) - """ - keys = keys if isinstance(keys, list) else [keys] - - if properties is None: - properties = [] - if isinstance(properties, dict): - properties = list(properties.keys()) - - data = [ - doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys) - ] - - def grouping_keys(doc): - return tuple(get(doc, k) for k in keys) - - for vals, group in groupby(sorted(data, key=grouping_keys), key=grouping_keys): - doc = {} # type: ignore - for k, v in zip(keys, vals): - set_(doc, k, v) - yield doc, list(group) - - def __eq__(self, other: object) -> bool: - """ - Check equality for MemoryStore - other: other MemoryStore to compare with - """ - if not isinstance(other, MemoryStore): - return False - - fields = ["collection_name", "last_updated_field"] - return all(getattr(self, f) == getattr(other, f) for f in fields) - class JSONStore(MemoryStore): """ diff --git a/tests/stores/test_mongolike.py b/tests/stores/test_mongolike.py index 68784a639..901e717ee 100644 --- a/tests/stores/test_mongolike.py +++ b/tests/stores/test_mongolike.py @@ -4,7 +4,6 @@ from pathlib import Path from unittest import mock -import mongomock.collection import orjson import pymongo.collection import pytest @@ -238,8 +237,9 @@ def test_mongostore_newer_in(mongostore): def test_memory_store_connect(): memorystore = MemoryStore() assert memorystore._coll is None + assert "mem:" in memorystore.name memorystore.connect() - assert isinstance(memorystore._collection, mongomock.collection.Collection) + assert isinstance(memorystore._collection, pymongo.collection.Collection) def test_groupby(memorystore):