update tests
lu-wang-dl committed May 30, 2024
1 parent 7a1b03a commit b5ff9cc
Showing 2 changed files with 85 additions and 52 deletions. The change moves the module-level tests into pytest test classes that create an explicit SparkSession in setup_class and stop it in teardown_class, so each suite manages its own Spark cluster lifecycle.

test/test_backend.py (24 additions, 11 deletions)
@@ -1,20 +1,33 @@
 import warnings
 from unittest.mock import MagicMock

+from pyspark.sql import SparkSession
+
 from joblibspark.backend import SparkDistributedBackend


-def test_effective_n_jobs():
+class TestLocalSparkCluster:
+    @classmethod
+    def setup_class(cls):
+        cls.spark = (
+            SparkSession.builder.master("local[*]").getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_effective_n_jobs(self):

-    backend = SparkDistributedBackend()
-    max_num_concurrent_tasks = 8
-    backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
+        backend = SparkDistributedBackend()
+        max_num_concurrent_tasks = 8
+        backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)

-    assert backend.effective_n_jobs(n_jobs=None) == 1
-    assert backend.effective_n_jobs(n_jobs=-1) == 8
-    assert backend.effective_n_jobs(n_jobs=4) == 4
+        assert backend.effective_n_jobs(n_jobs=None) == 1
+        assert backend.effective_n_jobs(n_jobs=-1) == 8
+        assert backend.effective_n_jobs(n_jobs=4) == 4

-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        assert backend.effective_n_jobs(n_jobs=16) == 16
-        assert len(w) == 1
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            assert backend.effective_n_jobs(n_jobs=16) == 16
+            assert len(w) == 1
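For context, these assertions pin down how the backend resolves n_jobs: None falls back to a single job, -1 expands to every available concurrent task slot (mocked to 8 here), and a request beyond cluster capacity is honored but emits a warning. A minimal sketch of driving the same code path outside pytest, assuming the backend expects an active SparkSession (which is what the new setup_class provides for the tests):

from unittest.mock import MagicMock

from pyspark.sql import SparkSession

from joblibspark.backend import SparkDistributedBackend

# the new setup_class suggests the backend expects an active session
spark = SparkSession.builder.master("local[*]").getOrCreate()

backend = SparkDistributedBackend()
# pretend the cluster can run 8 tasks concurrently, exactly as the test mocks it
backend._get_max_num_concurrent_tasks = MagicMock(return_value=8)

print(backend.effective_n_jobs(n_jobs=None))  # 1: joblib's default
print(backend.effective_n_jobs(n_jobs=-1))    # 8: all available task slots
print(backend.effective_n_jobs(n_jobs=16))    # 16, with a warning that it exceeds capacity

spark.stop()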
test/test_spark.py (61 additions, 41 deletions)
@@ -31,6 +31,8 @@
 from sklearn import datasets
 from sklearn import svm

+from pyspark.sql import SparkSession
+
 register_spark()


@@ -44,57 +46,75 @@ def slow_raise_value_error(condition, duration=0.05):
         raise ValueError("condition evaluated to True")


-def test_simple():
-    with parallel_backend('spark') as (ba, _):
-        seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
-        assert seq == [inc(i) for i in range(10)]
+class TestSparkCluster:
+    spark = None
+    @classmethod
+    def setup_class(cls):
+        cls.num_cpus_per_spark_task = 1
+        cls.num_gpus_per_spark_task = 1
+
+        cls.spark = (
+            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+            .config("spark.task.cpus", "1")
+            .config("spark.task.maxFailures", "1")
+            .getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_simple(self):
+        with parallel_backend('spark') as (ba, _):
+            seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
+            assert seq == [inc(i) for i in range(10)]

-        with pytest.raises(BaseException):
-            Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
-                               for i in range(10))
+            with pytest.raises(BaseException):
+                Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
+                                   for i in range(10))


-def test_sklearn_cv():
-    iris = datasets.load_iris()
-    clf = svm.SVC(kernel='linear', C=1)
-    with parallel_backend('spark', n_jobs=3):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+    def test_sklearn_cv(self):
+        iris = datasets.load_iris()
+        clf = svm.SVC(kernel='linear', C=1)
+        with parallel_backend('spark', n_jobs=3):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)

-    expected = [0.97, 1.0, 0.97, 0.97, 1.0]
+        expected = [0.97, 1.0, 0.97, 0.97, 1.0]

-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])

-    # test with default n_jobs=-1
-    with parallel_backend('spark'):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+        # test with the default n_jobs=-1
+        with parallel_backend('spark'):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)

-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])


-def test_job_cancelling():
-    from joblib import Parallel, delayed
-    import time
-    import tempfile
-    import os
+    def test_job_cancelling(self):
+        from joblib import Parallel, delayed
+        import time
+        import tempfile
+        import os

-    tmp_dir = tempfile.mkdtemp()
+        tmp_dir = tempfile.mkdtemp()

-    def test_fn(x):
-        if x == 0:
-            # make task 0 fail, which causes the remaining tasks to be cancelled
-            raise RuntimeError()
-        else:
-            time.sleep(15)
-            # a task that finishes successfully writes a flag file to the tmp dir
-            with open(os.path.join(tmp_dir, str(x)), 'w'):
-                pass
+        def test_fn(x):
+            if x == 0:
+                # make task 0 fail, which causes the remaining tasks to be cancelled
+                raise RuntimeError()
+            else:
+                time.sleep(15)
+                # a task that finishes successfully writes a flag file to the tmp dir
+                with open(os.path.join(tmp_dir, str(x)), 'w'):
+                    pass

-    with pytest.raises(Exception):
-        with parallel_backend('spark', n_jobs=2):
-            Parallel()(delayed(test_fn)(i) for i in range(2))
+        with pytest.raises(Exception):
+            with parallel_backend('spark', n_jobs=2):
+                Parallel()(delayed(test_fn)(i) for i in range(2))

-    time.sleep(30)  # wait until every task has finished or been cancelled
-    # assert all tasks were cancelled: no flag file was written to the tmp dir
-    assert len(os.listdir(tmp_dir)) == 0
+        time.sleep(30)  # wait until every task has finished or been cancelled
+        # assert all tasks were cancelled: no flag file was written to the tmp dir
+        assert len(os.listdir(tmp_dir)) == 0
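Taken together, the suites exercise the public joblib-spark flow that the test classes wrap: register the Spark backend once, then run joblib work inside a parallel_backend('spark') block. A minimal end-to-end sketch under the same assumptions the tests make (a local pyspark installation; register_spark is called at the top of test_spark.py, and its import path here follows the joblib-spark README):

from joblib import Parallel, delayed, parallel_backend
from pyspark.sql import SparkSession

from joblibspark import register_spark


def square(x):
    return x * x


spark = SparkSession.builder.master("local[*]").getOrCreate()
register_spark()  # makes the 'spark' backend name available to joblib

# each delayed call runs as a Spark task; n_jobs caps how many run at once
with parallel_backend('spark', n_jobs=2):
    results = Parallel()(delayed(square)(i) for i in range(8))

assert results == [0, 1, 4, 9, 16, 25, 36, 49]
spark.stop()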
