Skip to content

Commit

Permalink
splits pandas and arrow imports
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Mar 18, 2024
1 parent 7e30318 commit 1d0fd25
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 9 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/test_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ jobs:
name: Run smoke tests with minimum deps Windows
shell: cmd
- name: Install pyarrow
run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk

- run: |
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow
if: runner.os != 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed
- run: |
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow
if: runner.os == 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed Windows
shell: cmd
- name: Install pipeline dependencies
run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline

Expand Down
9 changes: 8 additions & 1 deletion dlt/common/libs/pandas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Any
from dlt.common.exceptions import MissingDependencyException

try:
import pandas
from pandas.io.sql import _wrap_result
except ModuleNotFoundError:
raise MissingDependencyException("DLT Pandas Helpers", ["pandas"])


def pandas_to_arrow(df: pandas.DataFrame) -> Any:
    """Convert the pandas DataFrame `df` into a pyarrow Table.

    The pyarrow import is deferred to call time so this module stays importable
    when only pandas is installed; if pyarrow is absent, the deferred import
    raises a MissingDependencyException (from dlt.common.libs.pyarrow).
    """
    # lazy import: keeps the pandas helpers usable without pyarrow installed
    from dlt.common.libs.pyarrow import pyarrow

    return pyarrow.Table.from_pandas(df)
7 changes: 7 additions & 0 deletions dlt/common/libs/pandas_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dlt.common.exceptions import MissingDependencyException


# This lives in its own module (separate from dlt.common.libs.pandas) so that
# the sql helper can be imported independently of the other pandas helpers.
try:
    # NOTE(review): `_wrap_result` is a private pandas API (pandas.io.sql) —
    # confirm it is available across all pandas versions dlt supports.
    from pandas.io.sql import _wrap_result
except ModuleNotFoundError:
    raise MissingDependencyException("dlt pandas helper for sql", ["pandas"])
4 changes: 3 additions & 1 deletion dlt/common/libs/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import pyarrow.compute
except ModuleNotFoundError:
raise MissingDependencyException(
"dlt parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "dlt Helpers for for parquet."
"dlt pyarrow helpers",
[f"{version.DLT_PKG_NAME}[parquet]"],
"Install pyarrow to be allow to load arrow tables, panda frames and to use parquet files.",
)


Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _get_columns(self) -> List[str]:
return [c[0] for c in self.native_cursor.description]

def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]:
from dlt.common.libs.pandas import _wrap_result
from dlt.common.libs.pandas_sql import _wrap_result

columns = self._get_columns()
if chunk_size is None:
Expand Down
5 changes: 3 additions & 2 deletions dlt/extract/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem
except MissingDependencyException:
pyarrow = None
pa = None

try:
from dlt.common.libs.pandas import pandas
from dlt.common.libs.pandas import pandas, pandas_to_arrow
except MissingDependencyException:
pandas = None

Expand Down Expand Up @@ -224,7 +225,7 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No
for tbl in (
(
# 1. Convert pandas frame(s) to arrow Table
pa.Table.from_pandas(item)
pandas_to_arrow(item)
if (pandas and isinstance(item, pandas.DataFrame))
else item
)
Expand Down
8 changes: 6 additions & 2 deletions dlt/extract/incremental/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@

try:
from dlt.common.libs import pyarrow
from dlt.common.libs.pandas import pandas
from dlt.common.libs.numpy import numpy
from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem
from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar
except MissingDependencyException:
pa = None
pyarrow = None
numpy = None

# NOTE: always import pandas independently from pyarrow
try:
from dlt.common.libs.pandas import pandas, pandas_to_arrow
except MissingDependencyException:
pandas = None


Expand Down Expand Up @@ -220,7 +224,7 @@ def __call__(
) -> Tuple[TDataItem, bool, bool]:
is_pandas = pandas is not None and isinstance(tbl, pandas.DataFrame)
if is_pandas:
tbl = pa.Table.from_pandas(tbl)
tbl = pandas_to_arrow(tbl)

primary_key = self.primary_key(tbl) if callable(self.primary_key) else self.primary_key
if primary_key:
Expand Down
10 changes: 8 additions & 2 deletions dlt/extract/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@

try:
from dlt.common.libs.pandas import pandas

PandaFrame = pandas.DataFrame
except MissingDependencyException:
PandaFrame = NoneType

try:
from dlt.common.libs.pyarrow import pyarrow

PandaFrame, ArrowTable, ArrowRecords = pandas.DataFrame, pyarrow.Table, pyarrow.RecordBatch
ArrowTable, ArrowRecords = pyarrow.Table, pyarrow.RecordBatch
except MissingDependencyException:
PandaFrame, ArrowTable, ArrowRecords = NoneType, NoneType, NoneType
ArrowTable, ArrowRecords = NoneType, NoneType


def wrap_additional_type(data: Any) -> Any:
Expand Down
24 changes: 24 additions & 0 deletions tests/pipeline/test_pipeline_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,27 @@ class Parent(BaseModel):
}

assert loaded_values == {"data_dictionary__child_attribute": "any string"}


def test_arrow_no_pandas() -> None:
    """Run an arrow table through an incremental resource, then verify that
    materializing query results as a data frame fails when pandas is missing."""
    import pyarrow as pa

    rows = {
        "Numbers": [1, 2, 3, 4, 5],
        "Strings": ["apple", "banana", "cherry", "date", "elderberry"],
    }
    # this is a pyarrow Table, not a pandas frame — pandas must not be required
    item = pa.table(rows)

    @dlt.resource
    def pandas_incremental(numbers=dlt.sources.incremental("Numbers")):
        yield item

    info = dlt.run(
        pandas_incremental(), write_disposition="append", table_name="data", destination="duckdb"
    )

    with info.pipeline.sql_client() as client:  # type: ignore
        with client.execute_query("SELECT * FROM data") as cursor:
            # df() needs pandas, which is deliberately absent in this test environment
            with pytest.raises(ImportError):
                cursor.df()

0 comments on commit 1d0fd25

Please sign in to comment.