Add resource_profile support (#52)
* add resource_profile support

* address the comments

* fix test

* fix test

* add spark cluster test

* update tests

* update tests

* address the comments
lu-wang-dl authored Jun 11, 2024
1 parent e716ac4 commit 66ec069
Showing 6 changed files with 323 additions and 81 deletions.
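
For orientation, a minimal usage sketch of the feature this commit adds (illustrative only, not part of the diff). It assumes the usual joblib-spark flow in which extra keyword arguments given to joblib's parallel_backend are forwarded to the SparkDistributedBackend constructor.

from joblib import Parallel, delayed, parallel_backend
from joblibspark import register_spark

register_spark()  # register the "spark" backend with joblib

# The new per-task settings are turned into a stage-level ResourceProfile
# when supported (Spark >= 3.4.0 and a non-local master); otherwise the
# backend falls back to plain Spark tasks.
with parallel_backend("spark", n_jobs=4,
                      num_cpus_per_spark_task=1,
                      num_gpus_per_spark_task=1):
    results = Parallel()(delayed(pow)(i, 2) for i in range(10))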
50 changes: 44 additions & 6 deletions joblibspark/backend.py
@@ -21,6 +21,7 @@
import warnings
from multiprocessing.pool import ThreadPool
import uuid
from typing import Optional
from packaging.version import Version, parse

from joblib.parallel \
@@ -40,10 +41,11 @@
from py4j.clientserver import ClientServer

import pyspark
from pyspark.sql import SparkSession
from pyspark import cloudpickle
from pyspark.util import VersionUtils

from .utils import create_resource_profile, get_spark_session


def register():
"""
@@ -71,15 +73,15 @@ class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
by `SequentialBackend`
"""

def __init__(self, **backend_args):
def __init__(self,
num_cpus_per_spark_task: Optional[int] = None,
num_gpus_per_spark_task: Optional[int] = None,
**backend_args):
# pylint: disable=super-with-arguments
super(SparkDistributedBackend, self).__init__(**backend_args)
self._pool = None
self._n_jobs = None
self._spark = SparkSession \
.builder \
.appName("JoblibSparkBackend") \
.getOrCreate()
self._spark = get_spark_session()
self._spark_context = self._spark.sparkContext
self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
self._spark_pinned_threads_enabled = isinstance(
@@ -96,6 +98,40 @@ def __init__(self, **backend_args):
except ImportError:
self._ipython = None

self._support_stage_scheduling = self._is_support_stage_scheduling()
self._resource_profile = self._create_resource_profile(num_cpus_per_spark_task,
num_gpus_per_spark_task)

    def _is_support_stage_scheduling(self):
        spark_master = self._spark_context.master
        is_spark_local_mode = spark_master == "local" or spark_master.startswith("local[")
        if is_spark_local_mode:
            support_stage_scheduling = False
            warnings.warn("Spark local mode doesn't support stage-level scheduling.")
        else:
            support_stage_scheduling = hasattr(
                self._spark_context.parallelize([1]), "withResources"
            )
            # Only warn when the running Spark version lacks RDD.withResources.
            if not support_stage_scheduling:
                warnings.warn("Spark version does not support stage-level scheduling.")
        return support_stage_scheduling

def _create_resource_profile(self,
num_cpus_per_spark_task,
num_gpus_per_spark_task) -> Optional[object]:
resource_profile = None
if self._support_stage_scheduling:
self.using_stage_scheduling = True

default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
default_gpus_per_task = int(self._spark.conf.get("spark.task.resource.gpu.amount", "0"))
num_cpus_per_spark_task = num_cpus_per_spark_task or default_cpus_per_task
num_gpus_per_spark_task = num_gpus_per_spark_task or default_gpus_per_task

resource_profile = create_resource_profile(num_cpus_per_spark_task,
num_gpus_per_spark_task)

return resource_profile

def _cancel_all_jobs(self):
self._is_running = False
if not self._spark_supports_job_cancelling:
@@ -178,6 +214,8 @@ def run_on_worker_and_fetch_result():

# TODO: handle possible spark exception here. # pylint: disable=fixme
worker_rdd = self._spark.sparkContext.parallelize([0], 1)
if self._resource_profile:
worker_rdd = worker_rdd.withResources(self._resource_profile)
def mapper_fn(_):
return cloudpickle.dumps(func())
if self._spark_supports_job_cancelling:
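
Once the resource profile is attached to worker_rdd via withResources, the user function running inside each Spark task can discover its assigned accelerators. A hedged sketch using PySpark's TaskContext API, under the assumption that a GPU resource was requested; it is an illustration, not part of this diff.

from pyspark import TaskContext

def gpu_addresses_for_this_task():
    # Inside a Spark task, TaskContext.get().resources() exposes the resources
    # granted by the active ResourceProfile, keyed by resource name.
    resources = TaskContext.get().resources()
    if "gpu" not in resources:
        return []
    # e.g. export these to CUDA_VISIBLE_DEVICES before running GPU code.
    return resources["gpu"].addresses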
58 changes: 58 additions & 0 deletions joblibspark/utils.py
@@ -0,0 +1,58 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The utils functions for joblib spark backend.
"""
from packaging.version import Version
import pyspark


# pylint: disable=import-outside-toplevel
def get_spark_session():
"""
Get the spark session from the active session or create a new one.
"""
from pyspark.sql import SparkSession

spark_session = SparkSession.getActiveSession()
if spark_session is None:
spark_session = SparkSession \
.builder \
.appName("JoblibSparkBackend") \
.getOrCreate()
return spark_session


def create_resource_profile(num_cpus_per_spark_task, num_gpus_per_spark_task):
    """
    Create a resource profile for the task.
    :param num_cpus_per_spark_task: Number of cpus for each Spark task of current spark job stage.
    :param num_gpus_per_spark_task: Number of gpus for each Spark task of current spark job stage.
    :return: Spark ResourceProfile, or None if the stage-level scheduling API is unavailable.
    """
    resource_profile = None
    if Version(pyspark.__version__).release >= (3, 4, 0):
        try:
            from pyspark.resource.profile import ResourceProfileBuilder
            from pyspark.resource.requests import TaskResourceRequests
        except ImportError:
            # Stage-level scheduling classes are unavailable; fall back to no profile.
            return resource_profile
        task_res_req = TaskResourceRequests().cpus(num_cpus_per_spark_task)
        if num_gpus_per_spark_task > 0:
            task_res_req = task_res_req.resource("gpu", num_gpus_per_spark_task)
        # ResourceProfileBuilder.build is a property, so no trailing parentheses.
        resource_profile = ResourceProfileBuilder().require(task_res_req).build
    return resource_profile
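
As a quick check of what create_resource_profile returns (assuming PySpark >= 3.4.0, where the stage-level scheduling classes import successfully), the resulting ResourceProfile exposes its per-task requests via taskResources, which is what the tests further down assert on:

from joblibspark.utils import create_resource_profile

profile = create_resource_profile(num_cpus_per_spark_task=1, num_gpus_per_spark_task=1)
if profile is not None:  # None when the API is unavailable
    print(profile.taskResources['cpus'].amount)  # 1.0
    print(profile.taskResources['gpu'].amount)   # 1.0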
1 change: 1 addition & 0 deletions requirements.txt
@@ -1 +1,2 @@
joblib>=0.14
packaging
4 changes: 4 additions & 0 deletions test/discover_2_gpu.sh
@@ -0,0 +1,4 @@
#!/bin/bash

# This script is used in spark GPU cluster config for discovering available GPU.
echo "{\"name\":\"gpu\",\"addresses\":[\"0\",\"1\"]}"
94 changes: 83 additions & 11 deletions test/test_backend.py
@@ -1,20 +1,92 @@
import warnings
import os
from packaging.version import Version
import unittest
from unittest.mock import MagicMock

import pyspark
from pyspark.sql import SparkSession

from joblibspark.backend import SparkDistributedBackend


def test_effective_n_jobs():
class TestLocalSparkCluster(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.spark = (
SparkSession.builder.master("local[*]").getOrCreate()
)

@classmethod
def teardown_class(cls):
cls.spark.stop()

def test_effective_n_jobs(self):
backend = SparkDistributedBackend()
max_num_concurrent_tasks = 8
backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)

assert backend.effective_n_jobs(n_jobs=None) == 1
assert backend.effective_n_jobs(n_jobs=-1) == 8
assert backend.effective_n_jobs(n_jobs=4) == 4

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
assert backend.effective_n_jobs(n_jobs=16) == 16
assert len(w) == 1

def test_resource_profile_supported(self):
backend = SparkDistributedBackend()
        # This fixture uses a local[*] master, i.e. Spark local mode,
        # which doesn't support stage-level scheduling.
assert not backend._support_stage_scheduling


class TestBasicSparkCluster(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.num_cpus_per_spark_task = 1
cls.num_gpus_per_spark_task = 1
gpu_discovery_script_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
)

cls.spark = (
SparkSession.builder.master("local-cluster[1, 2, 1024]")
.config("spark.task.cpus", "1")
.config("spark.task.resource.gpu.amount", "1")
.config("spark.executor.cores", "2")
.config("spark.worker.resource.gpu.amount", "2")
.config("spark.executor.resource.gpu.amount", "2")
.config("spark.task.maxFailures", "1")
.config(
"spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
)
.getOrCreate()
)

@classmethod
def teardown_class(cls):
cls.spark.stop()

@unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
"Resource group is only supported since spark 3.4.0")
def test_resource_profile(self):
backend = SparkDistributedBackend(
num_cpus_per_spark_task=self.num_cpus_per_spark_task,
num_gpus_per_spark_task=self.num_gpus_per_spark_task)

assert backend._support_stage_scheduling

resource_group = backend._resource_profile
assert resource_group.taskResources['cpus'].amount == 1.0
assert resource_group.taskResources['gpu'].amount == 1.0

    @unittest.skipIf(Version(pyspark.__version__).release < (3, 4, 0),
                     "Resource group is only supported since spark 3.4.0")
    def test_resource_with_default(self):
        backend = SparkDistributedBackend()

        assert backend._support_stage_scheduling

        resource_group = backend._resource_profile
        assert resource_group.taskResources['cpus'].amount == 1.0
