feat(mongo): when creating a df from a cursor, allow to do it by chunks (saves memory) [TCTC-9496] #1813

Merged: 3 commits, Nov 13, 2024
1 change: 1 addition & 0 deletions tests/mongo/test_mongo.py
@@ -156,6 +156,7 @@ def list_database_names(self):
     snock = mocker.patch("pymongo.MongoClient")
     snock.return_value = MongoMock("toucan", "test_col")
     aggregate = mocker.patch("pymongo.collection.Collection.aggregate")
+    aggregate.return_value = [{"foo": 42}]

     mongo_connector = MongoConnector(name="mycon", host="localhost", port=22, username="ubuntu", password="ilovetoucan")
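One detail worth keeping in mind when reading the connector change below (an observation, not something this PR addresses): itertools.islice only makes forward progress on an iterator. A plain list, like the mocked return value above, is re-read from its start on every call; a real pymongo aggregation cursor is an iterator, so it is consumed chunk by chunk. A minimal sketch:

import itertools

data = [{"foo": 42}]

# A bare list is never exhausted; each islice call re-reads from index 0.
assert list(itertools.islice(data, 1)) == [{"foo": 42}]
assert list(itertools.islice(data, 1)) == [{"foo": 42}]

# Wrapping it in iter() gives cursor-like, consumable behavior.
it = iter(data)
assert list(itertools.islice(it, 1)) == [{"foo": 42}]
assert list(itertools.islice(it, 1)) == []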
19 changes: 14 additions & 5 deletions toucan_connectors/mongo/mongo_connector.py
@@ -1,3 +1,4 @@
+import itertools
 from collections.abc import Generator
 from contextlib import contextmanager
 from functools import _lru_cache_wrapper, cached_property, lru_cache
@@ -260,15 +261,22 @@ def _execute_query(self, data_source: MongoDataSource):
         col = client[data_source.database][data_source.collection]
         return col.aggregate(data_source.query)  # type: ignore[arg-type]

-    def _retrieve_data(self, data_source):
+    def _retrieve_data(self, data_source, chunk_size: int | None = None):
         data_source.query = normalize_query(data_source.query, data_source.parameters)
         data = self._execute_query(data_source)
-        return pd.DataFrame(list(data))
+
+        if chunk_size:
+            chunks = []
+            while (chunk := list(itertools.islice(data, chunk_size))):
+                chunks.append(pd.DataFrame.from_records(chunk))
+            return pd.concat(chunks) if chunks else pd.DataFrame()
Member:
Doesn't it make any difference to create n small dfs and concatenate them afterwards, rather than making a single dataframe that you mutate by appending the chunks each time?
(e.g. df = pd.DataFrame() at the start, then df = df.concat(pd.DataFrame.from_records(chunk)) on each iteration)

Member (author):
I think it comes down to the same thing, yes. Why?
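A minimal, self-contained sketch of the two strategies compared in this thread (assuming the comment's df.concat was meant as pd.concat, which is the actual pandas API). Collecting per-chunk frames and concatenating once copies each row a single time; concatenating into a growing frame re-copies every previously accumulated row on each iteration, which is quadratic in the total row count:

import itertools
import pandas as pd

def frames_then_concat(cursor, chunk_size):
    # Strategy taken in this PR: one small frame per chunk, one final concat.
    chunks = []
    while chunk := list(itertools.islice(cursor, chunk_size)):
        chunks.append(pd.DataFrame.from_records(chunk))
    return pd.concat(chunks) if chunks else pd.DataFrame()

def concat_each_iteration(cursor, chunk_size):
    # Alternative from the comment above: the growing frame is re-copied
    # on every pd.concat call.
    df = pd.DataFrame()
    while chunk := list(itertools.islice(cursor, chunk_size)):
        df = pd.concat([df, pd.DataFrame.from_records(chunk)])
    return df

rows = ({"foo": i} for i in range(10_000))   # stands in for a pymongo cursor
print(len(frames_then_concat(rows, 1_000)))  # 10000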

+        else:
+            return pd.DataFrame.from_records(data)

     @decorate_func_with_retry
-    def get_df(self, data_source, permissions=None):
+    def get_df(self, data_source, permissions=None, chunk_size: int | None = None):
         data_source.query = apply_condition_filter(data_source.query, permissions)
-        return self._retrieve_data(data_source)
+        return self._retrieve_data(data_source, chunk_size=chunk_size)

     @decorate_func_with_retry
     def get_slice(

@@ -278,6 +286,7 @@ def get_slice(
         offset: int = 0,
         limit: int | None = None,
         get_row_count: bool | None = False,
+        chunk_size: int | None = None,
     ) -> DataSlice:
         # Create a copy in order to keep the original (deepcopy-like)
         data_source = data_source.model_copy(deep=True)

@@ -310,7 +319,7 @@
             total_count = res["count"][0]["value"] if len(res["count"]) > 0 else 0
             df = pd.DataFrame(res["df"])
         else:
-            df = self.get_df(data_source, permissions)
+            df = self.get_df(data_source, permissions, chunk_size=chunk_size)
             total_count = len(df)
             # We try to remove the _id from this DataFrame if there is one
             # ugly for now but we need to handle that in this else case
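A hypothetical usage sketch of the new chunk_size parameter. The connection values are adapted from the test above; the MongoDataSource field names and the query are illustrative assumptions, not taken from this diff:

from toucan_connectors.mongo.mongo_connector import MongoConnector, MongoDataSource

connector = MongoConnector(
    name="mycon", host="localhost", port=27017,
    username="ubuntu", password="ilovetoucan",
)
# Field values below are assumptions for illustration only.
data_source = MongoDataSource(
    domain="my_domain", name="mycon",
    database="toucan", collection="test_col",
    query={"foo": {"$exists": True}},
)
# Build the DataFrame from the aggregation cursor 10_000 documents at a
# time instead of materializing the whole result list first.
df = connector.get_df(data_source, chunk_size=10_000)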