From fbde5769b1ecf49ea1c68ee903198fa0b37089b8 Mon Sep 17 00:00:00 2001
From: svittoz
Date: Fri, 26 Apr 2024 17:05:45 +0000
Subject: [PATCH] pyarrow fix

---
 eds_scikit/__init__.py               | 17 +++++------------
 eds_scikit/io/improve_performance.py | 27 ++++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/eds_scikit/__init__.py b/eds_scikit/__init__.py
index 6da07cb7..ef2f6be8 100644
--- a/eds_scikit/__init__.py
+++ b/eds_scikit/__init__.py
@@ -1,7 +1,7 @@
 """Top-level package for eds_scikit."""
 
 __author__ = """eds_scikit"""
-__version__ = "0.1.7"
+__version__ = "0.1.6"
 
 import warnings
 
@@ -12,36 +12,29 @@
 import importlib
 import os
 import pathlib
-import sys
 import time
 from packaging import version
 from typing import List, Tuple
 from pathlib import Path
 
 import pandas as pd
-import pyarrow
-import pyarrow.ipc
 import pyspark
 from loguru import logger
 from pyspark import SparkContext
 from pyspark.sql import SparkSession
 
 import eds_scikit.biology  # noqa: F401 --> To register functions
-from eds_scikit.io import improve_performances
-
-pyarrow.open_stream = pyarrow.ipc.open_stream
-
-sys.path.insert(
-    0, (pathlib.Path(__file__).parent / "package-override").absolute().as_posix()
-)
-os.environ["PYTHONPATH"] = ":".join(sys.path)
+from eds_scikit.io import improve_performances, pyarrow_fix
 
 # Remove SettingWithCopyWarning
 pd.options.mode.chained_assignment = None
 
+pyarrow_fix()
+
 logger.warning(
     """To improve performances when using Spark and Koalas, please call `eds_scikit.improve_performances()`
 This function optimally configures Spark. Use it as:
 `spark, sc, sql = eds_scikit.improve_performances()`
 The functions respectively returns a SparkSession, a SparkContext and an sql method"""
 )
+
diff --git a/eds_scikit/io/improve_performance.py b/eds_scikit/io/improve_performance.py
index 42b6f19a..7df00085 100644
--- a/eds_scikit/io/improve_performance.py
+++ b/eds_scikit/io/improve_performance.py
@@ -6,6 +6,7 @@
 from typing import List, Tuple
 
 import pyarrow
+import pyarrow.ipc
 import pyspark
 from packaging import version
 from pyspark import SparkContext
@@ -48,6 +49,29 @@ def set_env_variables() -> None:
     if version.parse(pyarrow.__version__) >= version.parse("2.0.0"):
         os.environ["PYARROW_IGNORE_TIMEZONE"] = "0"
 
+
+def pyarrow_fix():
+
+    pyarrow.open_stream = pyarrow.ipc.open_stream
+
+    sys.path.insert(
+        0, (Path(__file__).parent / "package-override").absolute().as_posix()
+    )
+    os.environ["PYTHONPATH"] = ":".join(sys.path)
+
+    global spark, sc, sql
+
+    spark = SparkSession.builder.getOrCreate()
+
+    conf = spark.sparkContext.getConf()
+    conf.set("spark.executorEnv.PYTHONPATH", f"{Path(__file__).parent}/package-override")
+
+    spark = SparkSession.builder.enableHiveSupport().config(conf=conf).getOrCreate()
+
+    sc = spark.sparkContext
+
+    sql = spark.sql
+
 
 
 def improve_performances(
@@ -88,13 +112,14 @@
     tz = os.environ.get("TZ", "UTC")
     os.environ["TZ"] = tz
     time.tzset()
-
+
     to_add_conf.extend(
         [
             ("spark.app.name", f"{os.environ.get('USER')}_{app_name}_scikit"),
             ("spark.sql.session.timeZone", tz),
             ("spark.sql.execution.arrow.enabled", "true"),
             ("spark.sql.execution.arrow.pyspark.enabled", "true"),
+            ("spark.executorEnv.PYTHONPATH", f"{BASE_DIR}/package-override")
         ]
     )
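
Context for reviewers: PySpark 2.x's Arrow serializer still calls
`pyarrow.open_stream`, which recent pyarrow releases expose only as
`pyarrow.ipc.open_stream`. The patch restores the old name on the driver
and points the executors' PYTHONPATH at a `package-override` directory so
worker processes see the same alias. A minimal standalone sketch of the
driver-side half (illustrative, assuming a pyarrow recent enough to have
dropped the alias):

    import pyarrow
    import pyarrow.ipc

    # Old PySpark calls pyarrow.open_stream(); newer pyarrow only ships
    # pyarrow.ipc.open_stream(). Restore the removed alias so Arrow
    # stream deserialization keeps working in this process.
    if not hasattr(pyarrow, "open_stream"):
        pyarrow.open_stream = pyarrow.ipc.open_stream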
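
The driver-side monkeypatch does not reach executors, which run their own
Python interpreters; that is what `spark.executorEnv.PYTHONPATH` and the
`package-override` directory are for. A quick way to check that executors
actually see the alias (a sketch, not part of the patch; the helper name
is made up):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    def worker_has_alias(_):
        # Runs on an executor: True if the override (or an old enough
        # pyarrow) provides the legacy entry point.
        import pyarrow
        return hasattr(pyarrow, "open_stream")

    # All four partitions should report True once the fix is active.
    print(spark.sparkContext.parallelize(range(4), 4).map(worker_has_alias).collect())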
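
End-user setup is unchanged and matches the logger warning kept in
`__init__.py`; the only difference is that `pyarrow_fix()` now runs
automatically at import time:

    import eds_scikit  # importing the package triggers pyarrow_fix()

    # As recommended by the warning message:
    spark, sc, sql = eds_scikit.improve_performances()

    # `sql` is the session's sql method:
    sql("SELECT 1 AS ok").show()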