diff --git a/test/test_spark.py b/test/test_spark.py
index 66beb31..9c78257 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -122,8 +122,6 @@ def test_fn(x):
 class TestGPUSparkCluster(unittest.TestCase):
     @classmethod
     def setup_class(cls):
-        cls.num_cpus_per_spark_task = 2
-        cls.num_gpus_per_spark_task = 1
         gpu_discovery_script_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
         )
@@ -162,10 +160,10 @@ def get_spark_context(x):
             from pyspark import TaskContext
             taskcontext = TaskContext.get()
             assert taskcontext.cpus() == 2
-            assert len(taskcontext.resources().get("gpu").addresses) == 1
-            return TaskContext.get()
+            assert len(taskcontext.resources().get("gpu").addresses) == 2
+            return taskcontext.cpus()
 
         with parallel_backend('spark',
-                              num_cpus_per_spark_task=self.num_cpus_per_spark_task,
-                              num_gpus_per_spark_task=self.num_gpus_per_spark_task) as (ba, _):
+                              num_cpus_per_spark_task=2,
+                              num_gpus_per_spark_task=2) as (ba, _):
             Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
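
For reference, a minimal standalone sketch of the pattern the updated test exercises. This is not the test itself; it assumes a Spark 3.x cluster configured with a GPU discovery script so that each task can be granted 2 CPUs and 2 GPUs, and it assumes the joblib-spark backend forwards the num_cpus_per_spark_task / num_gpus_per_spark_task parameters shown in the diff to the task-level resource requests.

    # Hedged sketch, not the PR's test code. Requires pyspark, joblib and
    # joblibspark, plus a GPU-enabled Spark cluster (e.g. spark.task.cpus=2,
    # spark.task.resource.gpu.amount=2 with a discovery script configured).
    from joblib import Parallel, delayed, parallel_backend
    from joblibspark import register_spark

    register_spark()  # make the 'spark' backend available to joblib

    def check_task_resources(_):
        # Runs inside a Spark task; TaskContext reports the CPUs and GPU
        # addresses that the scheduler actually assigned to this task.
        from pyspark import TaskContext
        taskcontext = TaskContext.get()
        assert taskcontext.cpus() == 2
        assert len(taskcontext.resources().get("gpu").addresses) == 2
        return taskcontext.cpus()

    # num_cpus_per_spark_task / num_gpus_per_spark_task are taken from the
    # diff above; they are backend-specific keyword arguments, not core
    # joblib options.
    with parallel_backend('spark',
                          num_cpus_per_spark_task=2,
                          num_gpus_per_spark_task=2):
        Parallel(n_jobs=5)(delayed(check_task_resources)(i) for i in range(10))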