diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index 2c97e40..fd92f2d 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -21,6 +21,7 @@
 import warnings
 from multiprocessing.pool import ThreadPool
 import uuid
+from typing import Optional
 
 from packaging.version import Version, parse
 from joblib.parallel \
@@ -36,10 +37,11 @@
 from py4j.clientserver import ClientServer
 
 import pyspark
-from pyspark.sql import SparkSession
 from pyspark import cloudpickle
 from pyspark.util import VersionUtils
 
+from .utils import create_resource_profile, get_spark_session
+
 
 def register():
     """
@@ -67,15 +69,15 @@ class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
     by `SequentialBackend`
     """
 
-    def __init__(self, resource_profile=None, **backend_args):
+    def __init__(self,
+                 num_cpus_per_spark_task: Optional[int] = None,
+                 num_gpus_per_spark_task: Optional[int] = None,
+                 **backend_args):
         # pylint: disable=super-with-arguments
         super(SparkDistributedBackend, self).__init__(**backend_args)
         self._pool = None
         self._n_jobs = None
-        self._spark = SparkSession \
-            .builder \
-            .appName("JoblibSparkBackend") \
-            .getOrCreate()
+        self._spark = get_spark_session()
         self._spark_context = self._spark.sparkContext
         self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
         self._spark_pinned_threads_enabled = isinstance(
@@ -91,18 +93,41 @@ def __init__(self, resource_profile=None, **backend_args):
             self._ipython = get_ipython()
         except ImportError:
             self._ipython = None
-        self._spark_supports_resource_profile = hasattr(
-            self._spark_context.parallelize([1]), "withResources"
-        ) and not self._spark.conf.get("spark.master", "").startswith("local")
-        if self._spark_supports_resource_profile:
-            self._resource_profile = resource_profile
+
+        self._support_stage_scheduling = self._is_support_stage_scheduling()
+        self._resource_profile = self._create_resource_profile(num_cpus_per_spark_task,
+                                                               num_gpus_per_spark_task)
+
+    def _is_support_stage_scheduling(self):
+        spark_master = self._spark_context.master
+        is_spark_local_mode = spark_master == "local" or spark_master.startswith("local[")
+        if is_spark_local_mode:
+            support_stage_scheduling = False
+            warnings.warn("Spark local mode doesn't support stage-level scheduling.")
         else:
-            self._resource_profile = None
-            if resource_profile is not None:
-                warnings.warn(
-                    "Joblib-spark was constructed with a ResourceProfile, but this Apache "
-                    "Spark version does not support stage-level scheduling."
-                )
+            support_stage_scheduling = hasattr(
+                self._spark_context.parallelize([1]), "withResources"
+            )
+            if not support_stage_scheduling:
+                warnings.warn("Spark version does not support stage-level scheduling.")
+        return support_stage_scheduling
+
+    def _create_resource_profile(self,
+                                 num_cpus_per_spark_task,
+                                 num_gpus_per_spark_task) -> Optional[object]:
+        resource_profile = None
+        if self._support_stage_scheduling:
+            self.using_stage_scheduling = True
+
+            default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
+            default_gpus_per_task = int(self._spark.conf.get("spark.task.resource.gpu.amount", "0"))
+            num_cpus_per_spark_task = num_cpus_per_spark_task or default_cpus_per_task
+            num_gpus_per_spark_task = num_gpus_per_spark_task or default_gpus_per_task
+
+            resource_profile = create_resource_profile(num_cpus_per_spark_task,
+                                                       num_gpus_per_spark_task)
+
+        return resource_profile
 
     def _cancel_all_jobs(self):
         self._is_running = False
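Note (not part of this diff): the new constructor keywords are normally reached through joblib rather than by instantiating the backend directly, since joblib forwards extra `parallel_backend` keyword arguments to the registered backend's constructor. A minimal usage sketch with illustrative values:

from joblib import Parallel, delayed, parallel_backend
from joblibspark import register_spark

register_spark()

# Request 2 CPUs and 1 GPU for each Spark task of the stages run in this block.
# In Spark local mode, or on Spark versions without stage-level scheduling,
# the backend only warns and runs with the cluster's default task resources.
with parallel_backend("spark", n_jobs=4,
                      num_cpus_per_spark_task=2,
                      num_gpus_per_spark_task=1):
    results = Parallel()(delayed(pow)(i, 2) for i in range(10))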
diff --git a/joblibspark/utils.py b/joblibspark/utils.py
new file mode 100644
index 0000000..81918a5
--- /dev/null
+++ b/joblibspark/utils.py
@@ -0,0 +1,58 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Utility functions for the joblib spark backend.
+"""
+from packaging.version import Version
+import pyspark
+
+
+# pylint: disable=import-outside-toplevel
+def get_spark_session():
+    """
+    Get the spark session from the active session or create a new one.
+    """
+    from pyspark.sql import SparkSession
+
+    spark_session = SparkSession.getActiveSession()
+    if spark_session is None:
+        spark_session = SparkSession \
+            .builder \
+            .appName("JoblibSparkBackend") \
+            .getOrCreate()
+    return spark_session
+
+
+def create_resource_profile(num_cpus_per_spark_task, num_gpus_per_spark_task):
+    """
+    Create a resource profile for the task.
+    :param num_cpus_per_spark_task: Number of cpus for each Spark task of current spark job stage.
+    :param num_gpus_per_spark_task: Number of gpus for each Spark task of current spark job stage.
+    :return: Spark ResourceProfile
+    """
+    resource_profile = None
+    if Version(pyspark.__version__).release > (3, 1, 0):
+        try:
+            from pyspark.resource.profile import ResourceProfileBuilder
+            from pyspark.resource.requests import TaskResourceRequests
+        except ImportError:
+            return None
+        task_res_req = TaskResourceRequests().cpus(num_cpus_per_spark_task)
+        if num_gpus_per_spark_task > 0:
+            task_res_req = task_res_req.resource("gpu", num_gpus_per_spark_task)
+        resource_profile = ResourceProfileBuilder().require(task_res_req).build
+    return resource_profile
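For context, the sketch below shows the public PySpark stage-level scheduling API that `create_resource_profile` wraps (available in PySpark 3.1+). The session and RDD here are illustrative assumptions, not part of this change:

from pyspark.sql import SparkSession
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests

spark = SparkSession.builder.getOrCreate()

# Per-task requirements for one job stage: 2 CPUs and 1 GPU per Spark task.
task_reqs = TaskResourceRequests().cpus(2).resource("gpu", 1)
# ResourceProfileBuilder.build is a property, not a method, hence no parentheses.
profile = ResourceProfileBuilder().require(task_reqs).build

# The backend attaches such a profile to the RDD that runs a joblib batch,
# roughly like this (only honored by cluster managers that support it):
rdd = spark.sparkContext.parallelize(range(10), 10).withResources(profile)
print(rdd.getResourceProfile())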
diff --git a/requirements.txt b/requirements.txt
index e7d569b..8a41183 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
 joblib>=0.14
+packaging
diff --git a/test/test_backend.py b/test/test_backend.py
index b2dcd45..2ee6f38 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,26 +1,79 @@
 import warnings
+from packaging.version import Version
+from unittest import mock
 from unittest.mock import MagicMock
+
+import pyspark
+from pyspark.sql import SparkSession
+
 from joblibspark.backend import SparkDistributedBackend
 
 
+class TestLocalSparkCluster:
+    @classmethod
+    def setup_class(cls):
+        cls.spark = (
+            SparkSession.builder.master("local[*]").getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_effective_n_jobs(self):
+        backend = SparkDistributedBackend()
+        max_num_concurrent_tasks = 8
+        backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
+
+        assert backend.effective_n_jobs(n_jobs=None) == 1
+        assert backend.effective_n_jobs(n_jobs=-1) == 8
+        assert backend.effective_n_jobs(n_jobs=4) == 4
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            assert backend.effective_n_jobs(n_jobs=16) == 16
+            assert len(w) == 1
+
+    def test_resource_profile_supported(self):
+        backend = SparkDistributedBackend()
+        # The test fixture uses a Spark local-mode session, which doesn't support stage-level scheduling.
+        assert not backend._support_stage_scheduling
+
+
+class TestBasicSparkCluster:
+    spark = None
+    @classmethod
+    def setup_class(cls):
+        cls.num_cpus_per_spark_task = 1
+        cls.num_gpus_per_spark_task = 1
+
+        cls.spark = (
+            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+            .config("spark.task.cpus", "1")
+            .config("spark.task.maxFailures", "1")
+            .getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
 
-def test_effective_n_jobs():
+    def test_resource_profile(self):
+        backend = SparkDistributedBackend(
+            num_cpus_per_spark_task=self.num_cpus_per_spark_task,
+            num_gpus_per_spark_task=self.num_gpus_per_spark_task)
 
-    backend = SparkDistributedBackend()
-    max_num_concurrent_tasks = 8
-    backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
+        if Version(pyspark.__version__).release >= (3, 4, 0):
+            assert backend._support_stage_scheduling
 
-    assert backend.effective_n_jobs(n_jobs=None) == 1
-    assert backend.effective_n_jobs(n_jobs=-1) == 8
-    assert backend.effective_n_jobs(n_jobs=4) == 4
+            resource_group = backend._resource_profile
+            assert resource_group.taskResources['cpus'].amount == 1.0
+            assert resource_group.taskResources['gpu'].amount == 1.0
 
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        assert backend.effective_n_jobs(n_jobs=16) == 16
-        assert len(w) == 1
+    def test_resource_with_default(self):
+        backend = SparkDistributedBackend()
 
+        if Version(pyspark.__version__).release >= (3, 4, 0):
+            assert backend._support_stage_scheduling
 
+            resource_group = backend._resource_profile
+            assert resource_group.taskResources['cpus'].amount == 1.0
-def test_resource_profile_supported(self):
-    backend = SparkDistributedBackend()
-    # The test fixture uses a local (standalone) Spark instance, which doesn't support stage-level scheduling.
-    assert backend._spark_supports_resource_profile == False
diff --git a/test/test_spark.py b/test/test_spark.py
index 2e2d24e..ecf0fab 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -31,6 +31,8 @@
 from sklearn import datasets
 from sklearn import svm
 
+from pyspark.sql import SparkSession
+
 
 register_spark()
 
@@ -54,47 +56,64 @@ def test_simple():
         for i in range(10))
 
 
-def test_sklearn_cv():
-    iris = datasets.load_iris()
-    clf = svm.SVC(kernel='linear', C=1)
-    with parallel_backend('spark', n_jobs=3):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
-
-    expected = [0.97, 1.0, 0.97, 0.97, 1.0]
-
-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
-
-    # test with default n_jobs=-1
-    with parallel_backend('spark'):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
-
-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
-
-
-def test_job_cancelling():
-    from joblib import Parallel, delayed
-    import time
-    import tempfile
-    import os
-
-    tmp_dir = tempfile.mkdtemp()
-
-    def test_fn(x):
-        if x == 0:
-            # make the task-0 fail, then it will cause task 1/2/3 to be canceled.
-            raise RuntimeError()
-        else:
-            time.sleep(15)
-            # if the task finished successfully, it will write a flag file to tmp dir.
-            with open(os.path.join(tmp_dir, str(x)), 'w'):
-                pass
-
-    with pytest.raises(Exception):
-        with parallel_backend('spark', n_jobs=2):
-            Parallel()(delayed(test_fn)(i) for i in range(2))
-
-    time.sleep(30)  # wait until we can ensure all task finish or cancelled.
-    # assert all jobs was cancelled, no flag file will be written to tmp dir.
-    assert len(os.listdir(tmp_dir)) == 0
+class TestSparkCluster:
+    spark = None
+    @classmethod
+    def setup_class(cls):
+        cls.num_cpus_per_spark_task = 1
+        cls.num_gpus_per_spark_task = 1
+
+        cls.spark = (
+            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+            .config("spark.task.cpus", "1")
+            .config("spark.task.maxFailures", "1")
+            .getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_sklearn_cv(self):
+        iris = datasets.load_iris()
+        clf = svm.SVC(kernel='linear', C=1)
+        with parallel_backend('spark', n_jobs=3):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+        expected = [0.97, 1.0, 0.97, 0.97, 1.0]
+
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+        # test with default n_jobs=-1
+        with parallel_backend('spark'):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+    def test_job_cancelling(self):
+        from joblib import Parallel, delayed
+        import time
+        import tempfile
+        import os
+
+        tmp_dir = tempfile.mkdtemp()
+
+        def test_fn(x):
+            if x == 0:
+                # make task-0 fail; this will cause the remaining tasks to be cancelled.
+                raise RuntimeError()
+            else:
+                time.sleep(15)
+                # if the task finishes successfully, it writes a flag file to the tmp dir.
+                with open(os.path.join(tmp_dir, str(x)), 'w'):
+                    pass
+
+        with pytest.raises(Exception):
+            with parallel_backend('spark', n_jobs=2):
+                Parallel()(delayed(test_fn)(i) for i in range(2))
+
+        time.sleep(30)  # wait long enough to ensure all tasks have finished or been cancelled.
+        # assert all tasks were cancelled: no flag file should have been written to the tmp dir.
+        assert len(os.listdir(tmp_dir)) == 0
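A quick manual check that mirrors what the new tests assert. This is a hypothetical snippet, assuming a Spark installation that can start a `local-cluster[1, 2, 1024]` master and a pyspark version that reports stage-level scheduling support (>= 3.4.0 per the tests above):

from pyspark.sql import SparkSession
from joblibspark.backend import SparkDistributedBackend

spark = (
    SparkSession.builder.master("local-cluster[1, 2, 1024]")
    .config("spark.task.cpus", "1")
    .getOrCreate()
)

backend = SparkDistributedBackend(num_cpus_per_spark_task=1,
                                  num_gpus_per_spark_task=1)
print(backend._support_stage_scheduling)                      # expected: True
print(backend._resource_profile.taskResources["gpu"].amount)  # expected: 1.0

spark.stop()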