
Commit

update tests
lu-wang-dl committed Jun 11, 2024
1 parent 13d7a53 commit 166099d
Showing 1 changed file with 56 additions and 4 deletions.
test/test_spark.py (60 changes: 56 additions & 4 deletions)
@@ -16,8 +16,10 @@
#
from time import sleep
import pytest
import os
from packaging.version import Version, parse
import sklearn
import unittest

if parse(sklearn.__version__) < Version('0.21'):
    raise RuntimeError("Test requires sklearn version >=0.21")
@@ -36,13 +38,11 @@
register_spark()


class TestSparkCluster:
class TestSparkCluster(unittest.TestCase):
    spark = None

    @classmethod
    def setup_class(cls):
        cls.num_cpus_per_spark_task = 1
        cls.num_gpus_per_spark_task = 1

        cls.spark = (
            SparkSession.builder.master("local-cluster[1, 2, 1024]")
            .config("spark.task.cpus", "1")
@@ -114,3 +114,55 @@ def test_fn(x):
        time.sleep(30)  # wait until we can be sure all tasks have finished or been cancelled.
        # assert all jobs were cancelled, so no flag file was written to the tmp dir.
        assert len(os.listdir(tmp_dir)) == 0


class TestGPUSparkCluster(unittest.TestCase):
    @classmethod
    def setup_class(cls):
        cls.num_cpus_per_spark_task = 2
        cls.num_gpus_per_spark_task = 1
        gpu_discovery_script_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
        )
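        # Note: discover_2_gpu.sh is not part of this diff; presumably it is a standard
        # Spark resource discovery script, i.e. an executable that prints a JSON payload
        # such as {"name": "gpu", "addresses": ["0", "1"]} so the worker advertises two GPUs.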

        cls.spark = (
            SparkSession.builder.master("local-cluster[1, 2, 1024]")
            .config("spark.task.cpus", "1")
            .config("spark.task.resource.gpu.amount", "1")
            .config("spark.executor.cores", "2")
            .config("spark.worker.resource.gpu.amount", "2")
            .config("spark.executor.resource.gpu.amount", "2")
            .config("spark.task.maxFailures", "1")
            .config(
                "spark.worker.resource.gpu.discoveryScript", gpu_discovery_script_path
            )
            .getOrCreate()
        )
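        # The config above (assuming local-cluster[1, 2, 1024] means one worker with two
        # cores and 1024 MB) gives each Spark task 1 CPU and 1 GPU by default; the
        # customized test below expects the backend to raise this to 2 CPUs per task
        # via num_cpus_per_spark_task.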

    @classmethod
    def teardown_class(cls):
        cls.spark.stop()

    def test_resource_group(self):
        def get_spark_context(x):
            from pyspark import TaskContext
            taskcontext = TaskContext.get()
            assert taskcontext.cpus() == 1
            assert len(taskcontext.resources().get("gpu").addresses) == 1
            return TaskContext.get()

        with parallel_backend('spark') as (ba, _):
            Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))

    def test_customized_resource_group(self):
        def get_spark_context(x):
            from pyspark import TaskContext
            taskcontext = TaskContext.get()
            assert taskcontext.cpus() == 2
            assert len(taskcontext.resources().get("gpu").addresses) == 1
            return TaskContext.get()

        with parallel_backend('spark',
                              num_cpus_per_spark_task=self.num_cpus_per_spark_task,
                              num_gpus_per_spark_task=self.num_gpus_per_spark_task) as (ba, _):
            Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
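
For reference, the API pattern that the new GPU tests exercise (registering the Spark backend and passing per-task CPU/GPU requests to parallel_backend) would look roughly like this in application code. This is only a sketch: the joblibspark import path and the example n_jobs value are assumptions, while the resource keyword arguments are taken from the test above and require a cluster that actually exposes GPUs.

from joblib import Parallel, delayed, parallel_backend
from joblibspark import register_spark  # assumed import path for register_spark()

register_spark()  # make the "spark" joblib backend available

def task_fn(i):
    # Runs on a Spark executor; with the arguments below each task is expected to
    # be scheduled with 2 CPUs and 1 GPU, mirroring test_customized_resource_group.
    return i

with parallel_backend("spark",
                      num_cpus_per_spark_task=2,
                      num_gpus_per_spark_task=1):
    results = Parallel(n_jobs=4)(delayed(task_fn)(i) for i in range(8))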
