Skip to content

Commit

Permalink
feat: support pyspark 3 (via a databricks.koalas stub)
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Dec 7, 2023
1 parent 31ba5f2 commit 81621f6
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 37 deletions.
29 changes: 22 additions & 7 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ jobs:
needs: check_skip
if: ${{ needs.check_skip.outputs.skip == 'false' }}
runs-on: "ubuntu-latest"
strategy:
fail-fast: true
matrix:
include:
- python-version: "3.7"
spark: "spark2"
- python-version: "3.7"
spark: "spark3"
- python-version: "3.8"
spark: "spark3"
- python-version: "3.9"
spark: "spark3"
name: 'Testing on ubuntu'
defaults:
run:
Expand All @@ -61,13 +73,16 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.7'
- name: Install eds-scikit
shell: bash {0}
run: ./build_tools/github/install.sh
- name: Run tests
shell: bash {0}
run: ./build_tools/github/test.sh
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
pip install -U "pip<23"
echo Installing eds-scikit with spark version ${{ matrix.spark }}
pip install --progress-bar off ".[${{ matrix.spark }}, dev, doc]"
- name: Run pytest
run: |
python -m pytest --pyargs tests -m "" --cov=eds_scikit
- name: Upload coverage to CodeCov
uses: codecov/codecov-action@v3
if: success()
Expand Down
4 changes: 0 additions & 4 deletions build_tools/github/install.sh

This file was deleted.

4 changes: 0 additions & 4 deletions build_tools/github/test.sh

This file was deleted.

1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Changed

- Support for pyarrow > 0.17.0
- Support for pyspark 3 (to force pyspark 2, use `pip install eds-scikit[spark2]`)

### Fixed
- Caching in spark instead of koalas to improve speed
Expand Down
1 change: 1 addition & 0 deletions docs/project_description.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ The goal of **Koalas** is precisely to avoid this issue. It aims at allowing cod

```python
from databricks import koalas as ks
# or from pyspark import pandas as ks, if you have spark 3

# Converting the Spark DataFrame into a Koalas DataFrame
visit_occurrence_koalas = visit_occurrence_spark.to_koalas()
Expand Down
4 changes: 2 additions & 2 deletions eds_scikit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
from pyspark import SparkContext
from pyspark.sql import SparkSession

import eds_scikit.biology # noqa: F401 --> To register functions

pyarrow.open_stream = pyarrow.ipc.open_stream

sys.path.insert(
0, (pathlib.Path(__file__).parent / "package-override").absolute().as_posix()
)
os.environ["PYTHONPATH"] = ":".join(sys.path)

import eds_scikit.biology # noqa: F401 --> To register functions

# Remove SettingWithCopyWarning
pd.options.mode.chained_assignment = None

Expand Down
2 changes: 2 additions & 0 deletions eds_scikit/biology/viz/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def aggregate_concepts_set(

# Extract concept-set
measurement_std_filtered = get_measurement_std(measurement_valid, src_to_std)
if is_koalas(measurement_std_filtered):
measurement_std_filtered.spark.cache()
measurement_std_filtered = measurement_std_filtered.drop(
columns="source_concept_id"
)
Expand Down
Empty file.
17 changes: 17 additions & 0 deletions eds_scikit/package-override/databricks/koalas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# This file is used to override the databricks.koalas package with the pyspark.pandas
# package, if the databricks.koalas package is not available (python >= 3.8)
import sys
import pyarrow # noqa: E402, F401

old_sys_path = sys.path.copy()
sys.path.remove(next((p for p in sys.path if "package-override" in p), None))
databricks = sys.modules.pop("databricks")
sys.modules.pop("databricks.koalas")
try:
from databricks.koalas import * # noqa: E402, F401, F403
except ImportError:
from pyspark.pandas import * # noqa: E402, F401, F403

Check warning on line 13 in eds_scikit/package-override/databricks/koalas/__init__.py

View check run for this annotation

Codecov / codecov/patch

eds_scikit/package-override/databricks/koalas/__init__.py#L12-L13

Added lines #L12 - L13 were not covered by tests

sys.modules["databricks"] = databricks
sys.modules["databricks.koalas"] = sys.modules["pyspark.pandas"]

Check warning on line 16 in eds_scikit/package-override/databricks/koalas/__init__.py

View check run for this annotation

Codecov / codecov/patch

eds_scikit/package-override/databricks/koalas/__init__.py#L15-L16

Added lines #L15 - L16 were not covered by tests
sys.path[:] = old_sys_path
13 changes: 6 additions & 7 deletions eds_scikit/package-override/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@
is the only one that resolves to this very module, still gets what it asked for:
the pyarrow module's content.
"""

import sys

old_sys_path = sys.path.copy()
sys.path.remove(next((p for p in sys.path if "package-override" in p), None))
del sys.modules["pyarrow"]
import pyarrow # noqa: E402, F401

try:
import pyarrow.ipc
import pyarrow # noqa: E402, F401
from pyarrow.ipc import open_stream # noqa: E402, F401

pyarrow.open_stream = pyarrow.ipc.open_stream
except ImportError:
pass
pyarrow.open_stream = open_stream

from pyarrow import * # noqa: F401, F403, E402

sys.path[:] = old_sys_path
15 changes: 9 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,18 @@ dependencies = [
"pgpasslib>=1.1.0, <2.0.0",
"psycopg2-binary>=2.9.0, <3.0.0",
"pandas>=1.3.0, <2.0.0",
"numpy>=1.0.0, <1.20",
"koalas>=1.8.1, <2.0.0",
"numpy>=1.0.0",
"altair>=5.0.0, <6.0.0",
"loguru==0.7.0",
"pypandoc==1.7.5",
"pyspark==2.4.3",
"pyspark",
"pyarrow>=0.10.0",
"pretty-html-table>=0.9.15, <0.10.0",
"catalogue",
"schemdraw>=0.15.0, <1.0.0",
"ipython>=7.32.0, <8.0.0",
"packaging==21.3",
"tomli==2.0.1",
"ipython>=7.32.0",
"packaging>=21.3",
"tomli>=2.0.1",
]
dynamic = ['version']

Expand All @@ -66,6 +65,10 @@ Documentation = "https://aphp.github.io/eds-scikit"
"Bug Tracker" = "https://github.com/aphp/eds-scikit/issues"

[project.optional-dependencies]
spark2 = [
"pyspark==2.4.3",
"koalas>=1.8.1,<2.0.0",
]
dev = [
"black>=22.3.0, <23.0.0",
"flake8==3.9.2",
Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# isort: skip_file
import logging
import os

import eds_scikit
import pandas as pd
import pytest
from _pytest.logging import caplog as _caplog # noqa F401
Expand Down Expand Up @@ -78,17 +80,13 @@ def spark_session(pytestconfig, tmpdir_factory):
print("!! Creating spark session !!")

from pyspark import SparkConf
from pyspark import __version__ as pyspark_version

temp_warehouse_dir = tmpdir_factory.mktemp("spark")
conf = (
SparkConf()
.setMaster("local")
.setAppName("testing")
# used to overwrite hive tables
.set(
"spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation",
"true",
)
# Path to data and metastore
# Note: the option "hive.metastore.warehouse.dir" is deprecated
# But javax.jdo.option.ConnectionURL can be used for the path of 'metastrore_db'
Expand All @@ -101,8 +99,17 @@ def spark_session(pytestconfig, tmpdir_factory):
"javax.jdo.option.ConnectionURL",
f"jdbc:derby:;databaseName={temp_warehouse_dir}/metastore_db;create=true",
)
.set("spark.executor.cores", 1)
)

if pyspark_version < "3":

# used to overwrite hive tables
conf = conf.set(
"spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation",
"true",
)

session, _, _ = improve_performances(to_add_conf=list(conf.getAll()))

# session is ready
Expand Down
4 changes: 2 additions & 2 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_framework_koalas(example_objects):
def test_unconvertible_objects():
objects = [1, "coucou", {"a": [1, 2]}, [1, 2, 3], 2.5, ks, pd]
for obj in objects:
with pytest.raises(ValueError):
with pytest.raises((ValueError, TypeError)):
framework.pandas(obj)

for obj in objects:
with pytest.raises(ValueError):
with pytest.raises((ValueError, TypeError)):
framework.koalas(obj)

0 comments on commit 81621f6

Please sign in to comment.