diff --git a/test/test_backend.py b/test/test_backend.py
index d000f06..3253062 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,20 +1,33 @@
 import warnings
 from unittest.mock import MagicMock
 
+from pyspark.sql import SparkSession
+
 from joblibspark.backend import SparkDistributedBackend
 
 
-def test_effective_n_jobs():
+class TestLocalSparkCluster:
+    @classmethod
+    def setup_class(cls):
+        cls.spark = (
+            SparkSession.builder.master("local[*]").getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_effective_n_jobs(self):
-    backend = SparkDistributedBackend()
-    max_num_concurrent_tasks = 8
-    backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
+        backend = SparkDistributedBackend()
+        max_num_concurrent_tasks = 8
+        backend._get_max_num_concurrent_tasks = MagicMock(return_value=max_num_concurrent_tasks)
 
-    assert backend.effective_n_jobs(n_jobs=None) == 1
-    assert backend.effective_n_jobs(n_jobs=-1) == 8
-    assert backend.effective_n_jobs(n_jobs=4) == 4
+        assert backend.effective_n_jobs(n_jobs=None) == 1
+        assert backend.effective_n_jobs(n_jobs=-1) == 8
+        assert backend.effective_n_jobs(n_jobs=4) == 4
 
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        assert backend.effective_n_jobs(n_jobs=16) == 16
-        assert len(w) == 1
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            assert backend.effective_n_jobs(n_jobs=16) == 16
+            assert len(w) == 1
 
diff --git a/test/test_spark.py b/test/test_spark.py
index 2e2d24e..12a85d0 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -31,6 +31,8 @@
 from sklearn import datasets
 from sklearn import svm
 
+from pyspark.sql import SparkSession
+
 register_spark()
 
 
@@ -44,57 +46,75 @@ def slow_raise_value_error(condition, duration=0.05):
         raise ValueError("condition evaluated to True")
 
 
-def test_simple():
-    with parallel_backend('spark') as (ba, _):
-        seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
-        assert seq == [inc(i) for i in range(10)]
+class TestSparkCluster:
+    spark = None
+    @classmethod
+    def setup_class(cls):
+        cls.num_cpus_per_spark_task = 1
+        cls.num_gpus_per_spark_task = 1
+
+        cls.spark = (
+            SparkSession.builder.master("local-cluster[1, 2, 1024]")
+            .config("spark.task.cpus", "1")
+            .config("spark.task.maxFailures", "1")
+            .getOrCreate()
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        cls.spark.stop()
+
+    def test_simple(self):
+        with parallel_backend('spark') as (ba, _):
+            seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
+            assert seq == [inc(i) for i in range(10)]
 
-    with pytest.raises(BaseException):
-        Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
-                           for i in range(10))
+        with pytest.raises(BaseException):
+            Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
+                               for i in range(10))
 
 
-def test_sklearn_cv():
-    iris = datasets.load_iris()
-    clf = svm.SVC(kernel='linear', C=1)
-    with parallel_backend('spark', n_jobs=3):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+    def test_sklearn_cv(self):
+        iris = datasets.load_iris()
+        clf = svm.SVC(kernel='linear', C=1)
+        with parallel_backend('spark', n_jobs=3):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)
 
-    expected = [0.97, 1.0, 0.97, 0.97, 1.0]
+        expected = [0.97, 1.0, 0.97, 0.97, 1.0]
 
-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])
 
-    # test with default n_jobs=-1
-    with parallel_backend('spark'):
-        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+        # test with default n_jobs=-1
+        with parallel_backend('spark'):
+            scores = cross_val_score(clf, iris.data, iris.target, cv=5)
 
-    for i in range(5):
-        assert(pytest.approx(scores[i], 0.01) == expected[i])
+        for i in range(5):
+            assert(pytest.approx(scores[i], 0.01) == expected[i])
 
 
-def test_job_cancelling():
-    from joblib import Parallel, delayed
-    import time
-    import tempfile
-    import os
+    def test_job_cancelling(self):
+        from joblib import Parallel, delayed
+        import time
+        import tempfile
+        import os
 
-    tmp_dir = tempfile.mkdtemp()
+        tmp_dir = tempfile.mkdtemp()
 
-    def test_fn(x):
-        if x == 0:
-            # make the task-0 fail, then it will cause task 1/2/3 to be canceled.
-            raise RuntimeError()
-        else:
-            time.sleep(15)
-            # if the task finished successfully, it will write a flag file to tmp dir.
-            with open(os.path.join(tmp_dir, str(x)), 'w'):
-                pass
+        def test_fn(x):
+            if x == 0:
+                # make the task-0 fail, then it will cause task 1/2/3 to be canceled.
+                raise RuntimeError()
+            else:
+                time.sleep(15)
+                # if the task finished successfully, it will write a flag file to tmp dir.
+                with open(os.path.join(tmp_dir, str(x)), 'w'):
+                    pass
 
-    with pytest.raises(Exception):
-        with parallel_backend('spark', n_jobs=2):
-            Parallel()(delayed(test_fn)(i) for i in range(2))
+        with pytest.raises(Exception):
+            with parallel_backend('spark', n_jobs=2):
+                Parallel()(delayed(test_fn)(i) for i in range(2))
 
-    time.sleep(30)  # wait until we can ensure all task finish or cancelled.
-    # assert all jobs was cancelled, no flag file will be written to tmp dir.
-    assert len(os.listdir(tmp_dir)) == 0
+        time.sleep(30)  # wait until we can ensure all task finish or cancelled.
+        # assert all jobs was cancelled, no flag file will be written to tmp dir.
+        assert len(os.listdir(tmp_dir)) == 0