diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0a261fb..3d7dd99 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -11,12 +11,18 @@ jobs: PIN_MODE: [false, true] PYSPARK_VERSION: ["3.0.3", "3.1.3", "3.2.3", "3.3.2", "3.4.0"] include: - - PYSPARK_VERSION: "3.4.0" + - PYSPARK_VERSION: "3.5.1" PYTHON_VERSION: "3.11" JOBLIB_VERSION: "1.3.0" - - PYSPARK_VERSION: "3.4.0" + - PYSPARK_VERSION: "3.5.1" + PYTHON_VERSION: "3.11" + JOBLIB_VERSION: "1.4.2" + - PYSPARK_VERSION: "3.5.1" PYTHON_VERSION: "3.12" JOBLIB_VERSION: "1.3.0" + - PYSPARK_VERSION: "3.5.1" + PYTHON_VERSION: "3.12" + JOBLIB_VERSION: "1.4.2" exclude: - PYSPARK_VERSION: "3.0.3" PIN_MODE: true diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 87dbfb8..9eb1801 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -27,11 +27,15 @@ import AutoBatchingMixin, ParallelBackendBase, register_parallel_backend, SequentialBackend try: - from joblib._parallel_backends import SafeFunction + # joblib >=1.4.0 + from joblib._utils import _TracebackCapturingWrapper as SafeFunction except ImportError: - # joblib >= 1.3.0 - from joblib._parallel_backends import PoolManagerMixin - SafeFunction = None + try: + from joblib._parallel_backends import SafeFunction + except ImportError: + # joblib >= 1.3.0 + from joblib._parallel_backends import PoolManagerMixin + SafeFunction = None from py4j.clientserver import ClientServer @@ -216,7 +220,8 @@ def mapper_fn(_): ) return self._get_pool().apply_async( - SafeFunction(run_on_worker_and_fetch_result), callback=callback + SafeFunction(run_on_worker_and_fetch_result), + callback=callback, error_callback=callback )