Skip to content

Commit

Permalink
Make pymongoarrow optional (#568)
Browse files Browse the repository at this point in the history
* Add test for support for loading MongoDB data without pymongoarrow.

Signed-off-by: Marcel Coetzee <[email protected]>

* Make pymongoarrow optional in requirements file

Signed-off-by: Marcel Coetzee <[email protected]>

---------

Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy authored Oct 1, 2024
1 parent f413db6 commit 5312382
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 28 deletions.
17 changes: 15 additions & 2 deletions sources/mongodb/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
TCollection = Any
TCursor = Any

try:
import pymongoarrow # type: ignore

PYMONGOARROW_AVAILABLE = True
except ImportError:
PYMONGOARROW_AVAILABLE = False


class CollectionLoader:
def __init__(
Expand Down Expand Up @@ -345,15 +352,21 @@ def collection_documents(
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
"""
if data_item_format == "arrow" and not PYMONGOARROW_AVAILABLE:
dlt.common.logger.warn(
"'pymongoarrow' is not installed; falling back to standard MongoDB CollectionLoader."
)
data_item_format = "object"

if parallel:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoaderParallel
elif data_item_format == "object":
else:
LoaderClass = CollectionLoaderParallel # type: ignore
else:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoader # type: ignore
elif data_item_format == "object":
else:
LoaderClass = CollectionLoader # type: ignore

loader = LoaderClass(
Expand Down
3 changes: 1 addition & 2 deletions sources/mongodb/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
pymongo>=4.3.3
pymongoarrow>=1.3.0
pymongo>=3
dlt>=0.5.1
6 changes: 3 additions & 3 deletions sources/sql_database_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ def specify_columns_to_load() -> None:
)

# Columns can be specified per table in env var (json array) or in `.dlt/config.toml`
os.environ["SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"] = (
'["rfam_acc", "description"]'
)
os.environ[
"SOURCES__SQL_DATABASE__FAMILY__INCLUDED_COLUMNS"
] = '["rfam_acc", "description"]'

sql_alchemy_source = sql_database(
"mysql+pymysql://[email protected]:4497/Rfam?&binary_prefix=true",
Expand Down
28 changes: 27 additions & 1 deletion tests/mongodb/test_mongodb_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import bson
import json
from unittest import mock

import bson
import dlt
import pyarrow
import pytest
from pendulum import DateTime, timezone
Expand Down Expand Up @@ -404,3 +407,26 @@ def test_filter_intersect(destination_name):

with pytest.raises(PipelineStepFailed):
pipeline.run(movies)


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
@pytest.mark.parametrize("data_item_format", ["object", "arrow"])
def test_mongodb_without_pymongoarrow(
destination_name: str, data_item_format: str
) -> None:
with mock.patch.dict("sys.modules", {"pymongoarrow": None}):
pipeline = dlt.pipeline(
pipeline_name="test_mongodb_without_pymongoarrow",
destination=destination_name,
dataset_name="test_mongodb_without_pymongoarrow_data",
full_refresh=True,
)

comments = mongodb_collection(
collection="comments", limit=10, data_item_format=data_item_format
)
load_info = pipeline.run(comments)

assert load_info.loads_ids != []
table_counts = load_table_counts(pipeline, "comments")
assert table_counts["comments"] == 10
5 changes: 0 additions & 5 deletions tests/rest_api/test_rest_api_source_processed.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_rest_api_source_filtered(mock_api_server) -> None:


def test_rest_api_source_exclude_columns(mock_api_server) -> None:

def exclude_columns(columns: List[str]) -> Callable:
def pop_columns(resource: DltResource) -> DltResource:
for col in columns:
Expand Down Expand Up @@ -73,7 +72,6 @@ def pop_columns(resource: DltResource) -> DltResource:


def test_rest_api_source_anonymize_columns(mock_api_server) -> None:

def anonymize_columns(columns: List[str]) -> Callable:
def empty_columns(resource: DltResource) -> DltResource:
for col in columns:
Expand Down Expand Up @@ -106,7 +104,6 @@ def empty_columns(resource: DltResource) -> DltResource:


def test_rest_api_source_map(mock_api_server) -> None:

def lower_title(row):
row["title"] = row["title"].lower()
return row
Expand All @@ -133,7 +130,6 @@ def lower_title(row):


def test_rest_api_source_filter_and_map(mock_api_server) -> None:

def id_by_10(row):
row["id"] = row["id"] * 10
return row
Expand Down Expand Up @@ -211,7 +207,6 @@ def test_rest_api_source_filtered_child(mock_api_server) -> None:


def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None:

def extend_body(row):
row["body"] = f"{row['_posts_title']} - {row['body']}"
return row
Expand Down
30 changes: 15 additions & 15 deletions tests/sql_database/test_sql_database_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None:

def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
# set the credentials per table name
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = (
sql_source_db.engine.url.render_as_string(False)
)
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"
] = sql_source_db.engine.url.render_as_string(False)
table = sql_table(table="chat_message", schema=sql_source_db.schema)
assert table.name == "chat_message"
assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
Expand All @@ -119,9 +119,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:
assert len(list(table)) == 10

# make it fail on cursor
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at_x"
)
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at_x"
table = sql_table(table="chat_message", schema=sql_source_db.schema)
with pytest.raises(ResourceExtractionError) as ext_ex:
len(list(table))
Expand All @@ -130,9 +130,9 @@ def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None:

def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
# set the credentials per table name
os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = (
sql_source_db.engine.url.render_as_string(False)
)
os.environ[
"SOURCES__SQL_DATABASE__CREDENTIALS"
] = sql_source_db.engine.url.render_as_string(False)
# applies to both sql table and sql database
table = sql_table(table="chat_message", schema=sql_source_db.schema)
assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"]
Expand All @@ -155,9 +155,9 @@ def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None:
assert len(list(database)) == 10

# make it fail on cursor
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at_x"
)
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at_x"
table = sql_table(table="chat_message", schema=sql_source_db.schema)
with pytest.raises(ResourceExtractionError) as ext_ex:
len(list(table))
Expand Down Expand Up @@ -275,9 +275,9 @@ def test_load_sql_table_incremental(
"""Run pipeline twice. Insert more rows after first run
and ensure only those rows are stored after the second run.
"""
os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = (
"updated_at"
)
os.environ[
"SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"
] = "updated_at"

pipeline = make_pipeline(destination_name)
tables = ["chat_message"]
Expand Down

0 comments on commit 5312382

Please sign in to comment.