diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0a261fb..99457e8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,10 +13,10 @@ jobs: include: - PYSPARK_VERSION: "3.4.0" PYTHON_VERSION: "3.11" - JOBLIB_VERSION: "1.3.0" + JOBLIB_VERSION: ["1.3.0", "1.4.2"] - PYSPARK_VERSION: "3.4.0" PYTHON_VERSION: "3.12" - JOBLIB_VERSION: "1.3.0" + JOBLIB_VERSION: ["1.3.0", "1.4.2"] exclude: - PYSPARK_VERSION: "3.0.3" PIN_MODE: true diff --git a/joblibspark/backend.py b/joblibspark/backend.py index 3d39655..9eb1801 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -23,17 +23,19 @@ import uuid from packaging.version import Version, parse -import joblib from joblib.parallel \ import AutoBatchingMixin, ParallelBackendBase, register_parallel_backend, SequentialBackend -if parse(joblib.__version__) >= Version('1.4.0'): +try: + # joblib >=1.4.0 from joblib._utils import _TracebackCapturingWrapper as SafeFunction -elif parse(joblib.__version__) < Version('1.3.0'): - from joblib._parallel_backends import SafeFunction -else: - from joblib._parallel_backends import PoolManagerMixin - SafeFunction = None +except ImportError: + try: + from joblib._parallel_backends import SafeFunction + except ImportError: + # joblib >= 1.3.0 + from joblib._parallel_backends import PoolManagerMixin + SafeFunction = None from py4j.clientserver import ClientServer