Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-23.12' into plugin-inherit
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 4, 2023
2 parents e54d522 + b6212ea commit 316e8ba
Showing 1 changed file with 64 additions and 30 deletions.
94 changes: 64 additions & 30 deletions dask_cuda/tests/test_spill.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import os
from time import sleep

Expand Down Expand Up @@ -58,15 +59,21 @@ def assert_device_host_file_size(


def worker_assert(
dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
total_size,
device_chunk_overhead,
serialized_chunk_overhead,
dask_worker=None,
):
assert_device_host_file_size(
dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
)


def delayed_worker_assert(
dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
total_size,
device_chunk_overhead,
serialized_chunk_overhead,
dask_worker=None,
):
start = time()
while not device_host_file_size_matches(
Expand All @@ -82,6 +89,18 @@ def delayed_worker_assert(
)


def assert_host_chunks(spills_to_disk, dask_worker=None):
if spills_to_disk is False:
assert len(dask_worker.data.host)


def assert_disk_chunks(spills_to_disk, dask_worker=None):
if spills_to_disk is True:
assert len(dask_worker.data.disk or list()) > 0
else:
assert len(dask_worker.data.disk or list()) == 0


@pytest.mark.parametrize(
"params",
[
Expand Down Expand Up @@ -122,7 +141,7 @@ def delayed_worker_assert(
},
],
)
@gen_test(timeout=120)
@gen_test(timeout=30)
async def test_cupy_cluster_device_spill(params):
cupy = pytest.importorskip("cupy")
with dask.config.set(
Expand All @@ -144,6 +163,8 @@ async def test_cupy_cluster_device_spill(params):
) as cluster:
async with Client(cluster, asynchronous=True) as client:

await client.wait_for_workers(1)

rs = da.random.RandomState(RandomState=cupy.random.RandomState)
x = rs.random(int(50e6), chunks=2e6)
await wait(x)
Expand All @@ -153,7 +174,10 @@ async def test_cupy_cluster_device_spill(params):

# Allow up to 1024 bytes overhead per chunk serialized
await client.run(
lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
worker_assert,
x.nbytes,
1024,
1024,
)

y = client.compute(x.sum())
Expand All @@ -162,20 +186,19 @@ async def test_cupy_cluster_device_spill(params):
assert (abs(res / x.size) - 0.5) < 1e-3

await client.run(
lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024)
worker_assert,
x.nbytes,
1024,
1024,
)
host_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.host)
await client.run(
assert_host_chunks,
params["spills_to_disk"],
)
disk_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.disk or list())
await client.run(
assert_disk_chunks,
params["spills_to_disk"],
)
for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
if params["spills_to_disk"]:
assert dc > 0
else:
assert hc > 0
assert dc == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -218,7 +241,7 @@ async def test_cupy_cluster_device_spill(params):
},
],
)
@gen_test(timeout=120)
@gen_test(timeout=30)
async def test_cudf_cluster_device_spill(params):
cudf = pytest.importorskip("cudf")

Expand All @@ -243,6 +266,8 @@ async def test_cudf_cluster_device_spill(params):
) as cluster:
async with Client(cluster, asynchronous=True) as client:

await client.wait_for_workers(1)

# There's a known issue with datetime64:
# https://github.com/numpy/numpy/issues/4983#issuecomment-441332940
# The same error above happens when spilling datetime64 to disk
Expand All @@ -264,26 +289,35 @@ async def test_cudf_cluster_device_spill(params):
await wait(cdf2)

del cdf
gc.collect()

host_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.host)
await client.run(
assert_host_chunks,
params["spills_to_disk"],
)
disk_chunks = await client.run(
lambda dask_worker: len(dask_worker.data.disk or list())
await client.run(
assert_disk_chunks,
params["spills_to_disk"],
)
for hc, dc in zip(host_chunks.values(), disk_chunks.values()):
if params["spills_to_disk"]:
assert dc > 0
else:
assert hc > 0
assert dc == 0

await client.run(
lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048)
worker_assert,
nbytes,
32,
2048,
)

del cdf2

await client.run(
lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0)
)
while True:
try:
await client.run(
delayed_worker_assert,
0,
0,
0,
)
except AssertionError:
gc.collect()
else:
break

0 comments on commit 316e8ba

Please sign in to comment.