Skip to content

Commit

Permalink
address the comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lu-wang-dl committed Jun 11, 2024
1 parent bceb7b5 commit eb7e9da
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions test/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ def test_fn(x):
class TestGPUSparkCluster(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.num_cpus_per_spark_task = 2
cls.num_gpus_per_spark_task = 1
gpu_discovery_script_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "discover_2_gpu.sh"
)
Expand Down Expand Up @@ -162,10 +160,10 @@ def get_spark_context(x):
from pyspark import TaskContext
taskcontext = TaskContext.get()
assert taskcontext.cpus() == 2
assert len(taskcontext.resources().get("gpu").addresses) == 1
return TaskContext.get()
assert len(taskcontext.resources().get("gpu").addresses) == 2
return taskcontext.cpus()

with parallel_backend('spark',
num_cpus_per_spark_task=self.num_cpus_per_spark_task,
num_gpus_per_spark_task=self.num_gpus_per_spark_task) as (ba, _):
num_cpus_per_spark_task=2,
num_gpus_per_spark_task=2) as (ba, _):
Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))

0 comments on commit eb7e9da

Please sign in to comment.