Remove Dask test skips and add Dask as a test dependency (#2047)
kounelisagis authored Aug 23, 2024
1 parent 0cd0c03 commit b7f8c13
Showing 2 changed files with 20 additions and 37 deletions.
pyproject.toml: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ test = [
     "psutil",
     "pyarrow",
     "pandas",
+    "dask[distributed]",
 ]

 [project.urls]
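A quick way to confirm the new extra is in effect (a sketch, not part of the commit): with "dask[distributed]" installed via the test extra, both dask submodules that test_dask.py relies on should import cleanly.

import importlib

# Both submodules must be importable once "dask[distributed]" is installed;
# otherwise pytest.importorskip in test_dask.py skips the whole module.
for mod in ("dask.array", "dask.distributed"):
    importlib.import_module(mod)
print("dask test dependencies are available")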
tiledb/tests/test_dask.py: 19 additions & 37 deletions
@@ -10,10 +10,10 @@
 from .common import DiskTestCase

 # Skip this test if dask is unavailable
-da = pytest.importorskip("dask.array")
+da_array = pytest.importorskip("dask.array")
+da_distributed = pytest.importorskip("dask.distributed")


-@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
 class TestDaskSupport(DiskTestCase):
     def test_dask_from_numpy_1d(self):
         uri = self.path("np_1attr")
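The module-level pytest.importorskip calls above are what make the removed skip markers redundant: if either dask submodule is missing, every test in the file is skipped at collection time; otherwise the imported module object is bound to the name. A standalone sketch of the pattern (the test function here is illustrative, not from the commit):

import pytest

# Skips the entire module at collection time when dask.array is absent;
# returns the imported module object when it is present.
da_array = pytest.importorskip("dask.array")

def test_uses_dask():
    x = da_array.ones(4, chunks=2)
    assert x.sum().compute() == 4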
@@ -22,10 +22,10 @@ def test_dask_from_numpy_1d(self):
         T.close()

         with tiledb.open(uri) as T:
-            D = da.from_tiledb(T)
+            D = da_array.from_tiledb(T)
             np.testing.assert_array_equal(D, A)

-        D2 = da.from_tiledb(uri)
+        D2 = da_array.from_tiledb(uri)
         np.testing.assert_array_equal(D2, A)
         self.assertAlmostEqual(
             np.mean(A), D2.mean().compute(scheduler="single-threaded")
@@ -43,18 +43,6 @@ def _make_multiattr_2d(self, uri, shape=(0, 100), tile=10):
         tiledb.DenseArray.create(uri, schema)

     @pytest.mark.filterwarnings("ignore:There is no current event loop")
-    @pytest.mark.filterwarnings(
-        # In Python 3.7 on POSIX systems, Tornado outputs a warning message
-        # that "make_current is deprecated." This should be fixed by Dask in
-        # future releases.
-        "ignore:make_current is deprecated"
-    )
-    @pytest.mark.skipif(
-        condition=(
-            sys.version_info >= (3, 11) and (datetime.now() < datetime(2023, 1, 6))
-        ),
-        reason="https://github.com/dask/distributed/issues/6785",
-    )
     def test_dask_multiattr_2d(self):
         uri = self.path("multiattr")

@@ -66,32 +54,30 @@ def test_dask_multiattr_2d(self):
         T[:] = {"attr1": ar1, "attr2": ar2}
         with tiledb.DenseArray(uri, mode="r", attr="attr2") as T:
             # basic round-trip from dask.array
-            D = da.from_tiledb(T, attribute="attr2")
+            D = da_array.from_tiledb(T, attribute="attr2")
             np.testing.assert_array_equal(ar2, np.array(D))

         # smoke-test computation
         # note: re-init from_tiledb each time, or else dask just uses the cached materialization
-        D = da.from_tiledb(uri, attribute="attr2")
+        D = da_array.from_tiledb(uri, attribute="attr2")
         self.assertAlmostEqual(np.mean(ar2), D.mean().compute(scheduler="threads"))
-        D = da.from_tiledb(uri, attribute="attr2")
+        D = da_array.from_tiledb(uri, attribute="attr2")
         self.assertAlmostEqual(
             np.mean(ar2), D.mean().compute(scheduler="single-threaded")
         )
-        D = da.from_tiledb(uri, attribute="attr2")
+        D = da_array.from_tiledb(uri, attribute="attr2")
         self.assertAlmostEqual(np.mean(ar2), D.mean().compute(scheduler="processes"))

         # test dask.distributed
-        from dask.distributed import Client
-
-        D = da.from_tiledb(uri, attribute="attr2")
-        with Client():
+        D = da_array.from_tiledb(uri, attribute="attr2")
+        with da_distributed.Client():
             np.testing.assert_approx_equal(D.mean().compute(), np.mean(ar2))

     def test_dask_write(self):
         uri = self.path("dask_w")
-        D = da.random.random(10, 10)
+        D = da_array.random.random(10, 10)
         D.to_tiledb(uri)
-        DT = da.from_tiledb(uri)
+        DT = da_array.from_tiledb(uri)
         np.testing.assert_array_equal(D, DT)

     def test_dask_overlap_blocks(self):
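The smoke tests in the hunk above run the same mean() under dask's three single-machine schedulers, re-creating the dask array each time so results are not served from a cached materialization. A self-contained sketch of the scheduler options (sizes are arbitrary):

import dask.array as da

def main():
    x = da.ones((100, 100), chunks=(10, 10))
    total = x.sum()
    # threads: the default for dask.array; workers share memory
    print(total.compute(scheduler="threads"))
    # single-threaded: runs in the calling thread; easiest to debug
    print(total.compute(scheduler="single-threaded"))
    # processes: tasks run in separate processes; inputs/outputs are pickled
    print(total.compute(scheduler="processes"))

if __name__ == "__main__":  # guard the entry point: "processes" spawns workers
    main()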
@@ -101,10 +87,10 @@ def test_dask_overlap_blocks(self):
         T.close()

         with tiledb.open(uri) as T:
-            D = da.from_tiledb(T)
+            D = da_array.from_tiledb(T)
             np.testing.assert_array_equal(D, A)

-        D2 = da.from_tiledb(uri)
+        D2 = da_array.from_tiledb(uri)
         np.testing.assert_array_equal(D2, A)

         D3 = D2.map_overlap(
@@ -133,7 +119,7 @@ def test_labeled_dask_overlap_blocks(self):
         with tiledb.open(uri, "w", attr="TDB_VALUES") as T:
             T[:] = A

-        D2 = da.from_tiledb(uri, attribute="TDB_VALUES")
+        D2 = da_array.from_tiledb(uri, attribute="TDB_VALUES")

         D3 = D2.map_overlap(
             lambda x: x + 1, depth={0: 0, 1: 1, 2: 1}, dtype=D2.dtype, boundary="none"
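map_overlap applies a function blockwise after exchanging depth ghost elements between neighbouring chunks; boundary="none" means the outer edges of the array receive no artificial padding. The test above uses a 3-D depth mapping; a one-dimensional sketch of the same call shape:

import dask.array as da
import numpy as np

x = da.arange(16, chunks=4)
# Each chunk sees one overlapping element from each neighbour (depth=1);
# the overlap is trimmed again after the function is applied.
y = x.map_overlap(lambda b: b + 1, depth=1, dtype=x.dtype, boundary="none")
np.testing.assert_array_equal(y.compute(), np.arange(16) + 1)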
@@ -160,19 +146,14 @@ def test_labeled_dask_blocks(self):
         with tiledb.open(uri, "w", attr="TDB_VALUES") as D1:
             D1[:] = A

-        D2 = da.from_tiledb(uri, attribute="TDB_VALUES")
+        D2 = da_array.from_tiledb(uri, attribute="TDB_VALUES")

         D3 = D2.map_blocks(lambda x: x + 1, dtype=D2.dtype).compute(
             scheduler="processes"
         )
         np.testing.assert_array_equal(D2 + 1, D3)


-@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
-@pytest.mark.skipif(
-    sys.version_info[:2] == (3, 8),
-    reason="Fails on Python 3.8 due to dask worker restarts",
-)
 def test_sc33742_dask_array_object_dtype_conversion():
     # This test verifies that an array can be converted to buffer after serialization
     # through several dask.distributed compute steps. The original source of the issue
@@ -182,7 +163,6 @@

     import dask
     import numpy as np
-    from dask.distributed import Client, LocalCluster

     @dask.delayed
     def get_data():
@@ -213,7 +193,9 @@ def use_data(data):
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         global client
-        client = Client(LocalCluster(scheduler_port=9786, dashboard_address=9787))
+        client = da_distributed.Client(
+            da_distributed.LocalCluster(scheduler_port=9786, dashboard_address=9787)
+        )

     w = []

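The last hunk swaps the direct Client/LocalCluster import for calls through the da_distributed module object obtained from importorskip. The underlying pattern, sketched standalone (the fixed ports mirror the test; any free ports would do):

from dask.distributed import Client, LocalCluster

def main():
    # Using the client as a context manager tears down workers
    # and the scheduler cleanly when the block exits.
    with Client(LocalCluster(scheduler_port=9786, dashboard_address=9787)) as client:
        fut = client.submit(sum, range(10))
        assert fut.result() == 45

if __name__ == "__main__":
    main()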
