From 104a2dec5e655fe2264d4b3eb635cb05cba01fdd Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 20 Sep 2024 14:49:23 +0200 Subject: [PATCH 1/2] Add test for support for loading MongoDB data without pymongoarrow. Signed-off-by: Marcel Coetzee --- tests/mongodb/test_mongodb_source.py | 37 ++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/mongodb/test_mongodb_source.py b/tests/mongodb/test_mongodb_source.py index 1a84c3d07..dfabcc198 100644 --- a/tests/mongodb/test_mongodb_source.py +++ b/tests/mongodb/test_mongodb_source.py @@ -1,11 +1,11 @@ -import bson import json -import pyarrow -import pytest -from pendulum import DateTime, timezone from unittest import mock +import bson import dlt +import pyarrow +import pytest +from pendulum import DateTime, timezone from sources.mongodb import mongodb, mongodb_collection from sources.mongodb_pipeline import ( @@ -151,12 +151,8 @@ def test_incremental( @pytest.mark.parametrize("data_item_format", ["object", "arrow"]) def test_parallel_loading(data_item_format): - st_records = load_select_collection_db_items_parallel( - data_item_format, parallel=False - ) - parallel_records = load_select_collection_db_items_parallel( - data_item_format, parallel=True - ) + st_records = load_select_collection_db_items_parallel(data_item_format, parallel=False) + parallel_records = load_select_collection_db_items_parallel(data_item_format, parallel=True) assert len(st_records) == len(parallel_records) @@ -356,3 +352,24 @@ def test_arrow_types(destination_name): info = pipeline.run(res, table_name="types_test") assert info.loads_ids != [] + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("data_item_format", ["object", "arrow"]) +def test_mongodb_without_pymongoarrow(destination_name: str, data_item_format: str) -> None: + with mock.patch.dict("sys.modules", {"pymongoarrow": None}): + pipeline = dlt.pipeline( + pipeline_name="test_mongodb_without_pymongoarrow", + destination=destination_name, + dataset_name="test_mongodb_without_pymongoarrow_data", + full_refresh=True, + ) + + comments = mongodb_collection( + collection="comments", limit=10, data_item_format=data_item_format + ) + load_info = pipeline.run(comments) + + assert load_info.loads_ids != [] + table_counts = load_table_counts(pipeline, "comments") + assert table_counts["comments"] == 10 From b40fbe3fae82a25b80a80a8961bb38f2c666a77d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 20 Sep 2024 19:15:14 +0200 Subject: [PATCH 2/2] Make pymongoarrow optional in requirements file Signed-off-by: Marcel Coetzee --- sources/mongodb/helpers.py | 17 +++++++++-- sources/mongodb/requirements.txt | 3 +- sources/sql_database_pipeline.py | 6 ++-- tests/mongodb/test_mongodb_source.py | 12 ++++++-- .../test_rest_api_source_processed.py | 5 ---- .../sql_database/test_sql_database_source.py | 30 +++++++++---------- 6 files changed, 43 insertions(+), 30 deletions(-) diff --git a/sources/mongodb/helpers.py b/sources/mongodb/helpers.py index 9769e4fbd..86ec34e7e 100644 --- a/sources/mongodb/helpers.py +++ b/sources/mongodb/helpers.py @@ -29,6 +29,13 @@ TCollection = Any TCursor = Any +try: + import pymongoarrow # type: ignore + + PYMONGOARROW_AVAILABLE = True +except ImportError: + PYMONGOARROW_AVAILABLE = False + class CollectionLoader: def __init__( @@ -301,15 +308,21 @@ def collection_documents( Returns: Iterable[DltResource]: A list of DLT resources for each collection to be loaded. """ + if data_item_format == "arrow" and not PYMONGOARROW_AVAILABLE: + dlt.common.logger.warn( + "'pymongoarrow' is not installed; falling back to standard MongoDB CollectionLoader." + ) + data_item_format = "object" + if parallel: if data_item_format == "arrow": LoaderClass = CollectionArrowLoaderParallel - elif data_item_format == "object": + else: LoaderClass = CollectionLoaderParallel # type: ignore else: if data_item_format == "arrow": LoaderClass = CollectionArrowLoader # type: ignore - elif data_item_format == "object": + else: LoaderClass = CollectionLoader # type: ignore loader = LoaderClass( diff --git a/sources/mongodb/requirements.txt b/sources/mongodb/requirements.txt index 45ac0bc3d..5240a44e2 100644 --- a/sources/mongodb/requirements.txt +++ b/sources/mongodb/requirements.txt @@ -1,3 +1,2 @@ -pymongo>=4.3.3 -pymongoarrow>=1.3.0 +pymongo>=3 dlt>=0.5.1 diff --git a/sources/sql_database_pipeline.py b/sources/sql_database_pipeline.py index 86e21845d..605895a48 100644 --- a/sources/sql_database_pipeline.py +++ b/sources/sql_database_pipeline.py @@ -338,9 +338,9 @@ def specify_columns_to_load() -> None: ) # Columns can be specified per table in env var (json array) or in `.dlt/config.toml` - os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = ( - '["rfam_acc", "description"]' - ) + os.environ[ + "SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS" + ] = '["rfam_acc", "description"]' sql_alchemy_source = sql_database( "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true", diff --git a/tests/mongodb/test_mongodb_source.py b/tests/mongodb/test_mongodb_source.py index dfabcc198..0917462b5 100644 --- a/tests/mongodb/test_mongodb_source.py +++ b/tests/mongodb/test_mongodb_source.py @@ -151,8 +151,12 @@ def test_incremental( @pytest.mark.parametrize("data_item_format", ["object", "arrow"]) def test_parallel_loading(data_item_format): - st_records = load_select_collection_db_items_parallel(data_item_format, parallel=False) - parallel_records = load_select_collection_db_items_parallel(data_item_format, parallel=True) + st_records = load_select_collection_db_items_parallel( + data_item_format, parallel=False + ) + parallel_records = load_select_collection_db_items_parallel( + data_item_format, parallel=True + ) assert len(st_records) == len(parallel_records) @@ -356,7 +360,9 @@ def test_arrow_types(destination_name): @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) @pytest.mark.parametrize("data_item_format", ["object", "arrow"]) -def test_mongodb_without_pymongoarrow(destination_name: str, data_item_format: str) -> None: +def test_mongodb_without_pymongoarrow( + destination_name: str, data_item_format: str +) -> None: with mock.patch.dict("sys.modules", {"pymongoarrow": None}): pipeline = dlt.pipeline( pipeline_name="test_mongodb_without_pymongoarrow", diff --git a/tests/rest_api/test_rest_api_source_processed.py b/tests/rest_api/test_rest_api_source_processed.py index cc04c27a6..85e07f9e6 100644 --- a/tests/rest_api/test_rest_api_source_processed.py +++ b/tests/rest_api/test_rest_api_source_processed.py @@ -40,7 +40,6 @@ def test_rest_api_source_filtered(mock_api_server) -> None: def test_rest_api_source_exclude_columns(mock_api_server) -> None: - def exclude_columns(columns: List[str]) -> Callable: def pop_columns(resource: DltResource) -> DltResource: for col in columns: @@ -73,7 +72,6 @@ def pop_columns(resource: DltResource) -> DltResource: def test_rest_api_source_anonymize_columns(mock_api_server) -> None: - def anonymize_columns(columns: List[str]) -> Callable: def empty_columns(resource: DltResource) -> DltResource: for col in columns: @@ -106,7 +104,6 @@ def empty_columns(resource: DltResource) -> DltResource: def test_rest_api_source_map(mock_api_server) -> None: - def lower_title(row): row["title"] = row["title"].lower() return row @@ -133,7 +130,6 @@ def lower_title(row): def test_rest_api_source_filter_and_map(mock_api_server) -> None: - def id_by_10(row): row["id"] = row["id"] * 10 return row @@ -211,7 +207,6 @@ def test_rest_api_source_filtered_child(mock_api_server) -> None: def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None: - def extend_body(row): row["body"] = f"{row['_posts_title']} - {row['body']}" return row diff --git a/tests/sql_database/test_sql_database_source.py b/tests/sql_database/test_sql_database_source.py index abecde9a8..9d578ff7e 100644 --- a/tests/sql_database/test_sql_database_source.py +++ b/tests/sql_database/test_sql_database_source.py @@ -97,9 +97,9 @@ def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None: def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: # set the credentials per table name - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = ( - sql_source_db.engine.url.render_as_string(False) - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS" + ] = sql_source_db.engine.url.render_as_string(False) table = sql_table(table="chat_message", schema=sql_source_db.schema) assert table.name == "chat_message" assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] @@ -119,9 +119,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: assert len(list(table)) == 10 # make it fail on cursor - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at_x" - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at_x" table = sql_table(table="chat_message", schema=sql_source_db.schema) with pytest.raises(ResourceExtractionError) as ext_ex: len(list(table)) @@ -130,9 +130,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: # set the credentials per table name - os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = ( - sql_source_db.engine.url.render_as_string(False) - ) + os.environ[ + "SOURCES__SQL_DATABASE__CREDENTIALS" + ] = sql_source_db.engine.url.render_as_string(False) # applies to both sql table and sql database table = sql_table(table="chat_message", schema=sql_source_db.schema) assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] @@ -155,9 +155,9 @@ def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: assert len(list(database)) == 10 # make it fail on cursor - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at_x" - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at_x" table = sql_table(table="chat_message", schema=sql_source_db.schema) with pytest.raises(ResourceExtractionError) as ext_ex: len(list(table)) @@ -275,9 +275,9 @@ def test_load_sql_table_incremental( """Run pipeline twice. Insert more rows after first run and ensure only those rows are stored after the second run. """ - os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = ( - "updated_at" - ) + os.environ[ + "SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH" + ] = "updated_at" pipeline = make_pipeline(destination_name) tables = ["chat_message"]