pyarrow fix
svittoz committed Apr 26, 2024
1 parent af4fb1e commit fbde576
Showing 2 changed files with 31 additions and 13 deletions.
17 changes: 5 additions & 12 deletions eds_scikit/__init__.py
@@ -1,7 +1,7 @@
"""Top-level package for eds_scikit."""

__author__ = """eds_scikit"""
__version__ = "0.1.7"
__version__ = "0.1.6"

import warnings

@@ -12,36 +12,29 @@
import importlib
import os
import pathlib
-import sys
import time
from packaging import version
from typing import List, Tuple
from pathlib import Path

import pandas as pd
-import pyarrow
-import pyarrow.ipc
import pyspark
from loguru import logger
from pyspark import SparkContext
from pyspark.sql import SparkSession

import eds_scikit.biology # noqa: F401 --> To register functions
-from eds_scikit.io import improve_performances

-pyarrow.open_stream = pyarrow.ipc.open_stream
-
-sys.path.insert(
-    0, (pathlib.Path(__file__).parent / "package-override").absolute().as_posix()
-)
-os.environ["PYTHONPATH"] = ":".join(sys.path)
+from eds_scikit.io import improve_performances, pyarrow_fix

# Remove SettingWithCopyWarning
pd.options.mode.chained_assignment = None

+pyarrow_fix()

logger.warning(
    """To improve performances when using Spark and Koalas, please call `eds_scikit.improve_performances()`.
This function optimally configures Spark. Use it as:
`spark, sc, sql = eds_scikit.improve_performances()`
The function returns, respectively, a SparkSession, a SparkContext and an sql method"""
)
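In short: the inline pyarrow shim that used to live in __init__.py (aliasing pyarrow.open_stream and prepending the package-override directory to sys.path) now lives in a reusable pyarrow_fix() helper in eds_scikit.io, which __init__.py calls at import time. The underlying trick is a monkeypatch: newer pyarrow releases dropped the top-level open_stream alias, while the pyspark code paths targeted here still call pyarrow.open_stream. A minimal standalone sketch of the idea, assuming a pyarrow version without the alias (the hasattr guard is illustrative, not part of the commit):

import pyarrow
import pyarrow.ipc

# Older pyspark code calls pyarrow.open_stream; recent pyarrow only
# exposes it as pyarrow.ipc.open_stream, so restore the old name.
if not hasattr(pyarrow, "open_stream"):
    pyarrow.open_stream = pyarrow.ipc.open_stream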

27 changes: 26 additions & 1 deletion eds_scikit/io/improve_performance.py
@@ -6,6 +6,7 @@
from typing import List, Tuple

import pyarrow
+import pyarrow.ipc
import pyspark
from packaging import version
from pyspark import SparkContext
@@ -48,6 +49,29 @@ def set_env_variables() -> None:

    if version.parse(pyarrow.__version__) >= version.parse("2.0.0"):
        os.environ["PYARROW_IGNORE_TIMEZONE"] = "0"

+def pyarrow_fix():
+
+    pyarrow.open_stream = pyarrow.ipc.open_stream
+
+    sys.path.insert(
+        0, (Path(__file__).parent / "package-override").absolute().as_posix()
+    )
+    os.environ["PYTHONPATH"] = ":".join(sys.path)
+
+    global spark, sc, sql
+
+    spark = SparkSession.builder.getOrCreate()
+
+    conf = spark.sparkContext.getConf()
+    conf.set("spark.executorEnv.PYTHONPATH", f"{Path(__file__).parent}/package-override")
+
+    spark = SparkSession.builder.enableHiveSupport().config(conf=conf).getOrCreate()
+
+    sc = spark.sparkContext
+
+    sql = spark.sql
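Driver-side monkeypatching alone does not help Spark executors, which run their own Python interpreters; that is why pyarrow_fix() also points spark.executorEnv.PYTHONPATH at the bundled package-override directory and rebuilds the session so executors pick up the patched module too. A hedged sanity check, assuming pyarrow_fix() has already run for the live session (pyspark's SparkConf.get accepts a default value):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Expect the package-override directory set by pyarrow_fix(), else the default.
print(spark.sparkContext.getConf().get("spark.executorEnv.PYTHONPATH", "not set"))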



def improve_performances(
@@ -88,13 +112,14 @@ def improve_performances(
    tz = os.environ.get("TZ", "UTC")
    os.environ["TZ"] = tz
    time.tzset()

    to_add_conf.extend(
        [
            ("spark.app.name", f"{os.environ.get('USER')}_{app_name}_scikit"),
            ("spark.sql.session.timeZone", tz),
            ("spark.sql.execution.arrow.enabled", "true"),
            ("spark.sql.execution.arrow.pyspark.enabled", "true"),
+            ("spark.executorEnv.PYTHONPATH", f"{BASE_DIR}/package-override")
        ]
    )

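As the import-time warning in __init__.py spells out, callers are expected to configure Spark through the library entry point; a minimal usage sketch:

import eds_scikit  # runs pyarrow_fix() and emits the warning above

spark, sc, sql = eds_scikit.improve_performances()
# spark is a SparkSession, sc a SparkContext, and sql the spark.sql method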
