diff --git a/tests/mongo/test_mongo.py b/tests/mongo/test_mongo.py index 84fe0fe57..c30940ff0 100644 --- a/tests/mongo/test_mongo.py +++ b/tests/mongo/test_mongo.py @@ -156,6 +156,7 @@ def list_database_names(self): snock = mocker.patch("pymongo.MongoClient") snock.return_value = MongoMock("toucan", "test_col") aggregate = mocker.patch("pymongo.collection.Collection.aggregate") + aggregate.return_value = [{"foo": 42}] mongo_connector = MongoConnector(name="mycon", host="localhost", port=22, username="ubuntu", password="ilovetoucan") diff --git a/toucan_connectors/mongo/mongo_connector.py b/toucan_connectors/mongo/mongo_connector.py index bc5ce4ad5..519d59d90 100644 --- a/toucan_connectors/mongo/mongo_connector.py +++ b/toucan_connectors/mongo/mongo_connector.py @@ -1,3 +1,4 @@ +import itertools from collections.abc import Generator from contextlib import contextmanager from functools import _lru_cache_wrapper, cached_property, lru_cache @@ -260,15 +261,22 @@ def _execute_query(self, data_source: MongoDataSource): col = client[data_source.database][data_source.collection] return col.aggregate(data_source.query) # type: ignore[arg-type] - def _retrieve_data(self, data_source): + def _retrieve_data(self, data_source, chunk_size: int | None = None): data_source.query = normalize_query(data_source.query, data_source.parameters) data = self._execute_query(data_source) - return pd.DataFrame(list(data)) + + if chunk_size: + chunks = [] + while (chunk := list(itertools.islice(data, chunk_size))): + chunks.append(pd.DataFrame.from_records(chunk)) + return pd.concat(chunks) if chunks else pd.DataFrame() + else: + return pd.DataFrame.from_records(data) @decorate_func_with_retry - def get_df(self, data_source, permissions=None): + def get_df(self, data_source, permissions=None, chunk_size: int | None = None): data_source.query = apply_condition_filter(data_source.query, permissions) - return self._retrieve_data(data_source) + return self._retrieve_data(data_source, chunk_size=chunk_size) @decorate_func_with_retry def get_slice( @@ -278,6 +286,7 @@ def get_slice( offset: int = 0, limit: int | None = None, get_row_count: bool | None = False, + chunk_size: int | None = None, ) -> DataSlice: # Create a copy in order to keep the original (deepcopy-like) data_source = data_source.model_copy(deep=True) @@ -310,7 +319,7 @@ def get_slice( total_count = res["count"][0]["value"] if len(res["count"]) > 0 else 0 df = pd.DataFrame(res["df"]) else: - df = self.get_df(data_source, permissions) + df = self.get_df(data_source, permissions, chunk_size=chunk_size) total_count = len(df) # We try to remove the _id from this DataFrame if there is one # ugly for now but we need to handle that in this else case