Skip to content

Commit

Permalink
Clean up imports that won't work with dask>2024.12.1
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jan 7, 2025
1 parent c19a1a7 commit 0557211
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 38 deletions.
18 changes: 10 additions & 8 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import dask
import dask.utils
import dask.dataframe.core
import dask.dataframe as dd
import dask.dataframe.shuffle
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
from dask.dataframe import DASK_EXPR_ENABLED
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

Expand All @@ -19,12 +18,15 @@
from .proxify_device_objects import proxify_decorator, unproxify_decorator


if not DASK_EXPR_ENABLED:
raise ValueError(
"Dask-CUDA no longer supports the legacy Dask DataFrame API. "
"Please set the 'dataframe.query-planning' config to `True` "
"or None, or downgrade RAPIDS to <=24.12."
)
try:
if not dd._dask_expr_enabled():
raise ValueError(
"Dask-CUDA no longer supports the legacy Dask DataFrame API. "
"Please set the 'dataframe.query-planning' config to `True` "
"or None, or downgrade RAPIDS to <=24.12."
)
except AttributeError:
pass


# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
Expand Down
4 changes: 2 additions & 2 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dask.base import tokenize
from dask.dataframe import DataFrame, Series
from dask.dataframe.core import _concat as dd_concat
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker
Expand Down Expand Up @@ -585,7 +585,7 @@ def _layer(self):
# Execute an explicit-comms shuffle
if not hasattr(self, "_ec_shuffled"):
on = self.partitioning_index
df = dask_expr._collection.new_collection(self.frame)
df = dask_expr.new_collection(self.frame)
self._ec_shuffled = shuffle(
df,
[on] if isinstance(on, str) else on,
Expand Down
33 changes: 13 additions & 20 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import dask
import dask.array.core
import dask.dataframe.methods
import dask.dataframe.backends
import dask.dataframe.dispatch
import dask.dataframe.utils
import dask.utils
import distributed.protocol
Expand All @@ -22,16 +23,6 @@

from dask_cuda.disk_io import disk_read

try:
from dask.dataframe.backends import concat_pandas
except ImportError:
from dask.dataframe.methods import concat_pandas

try:
from dask.dataframe.dispatch import make_meta_dispatch as make_meta_dispatch
except ImportError:
from dask.dataframe.utils import make_meta as make_meta_dispatch

from .disk_io import SpillToDiskFile
from .is_device_object import is_device_object

Expand Down Expand Up @@ -893,10 +884,12 @@ def obj_pxy_dask_deserialize(header, frames):
return subclass(pxy)


@dask.dataframe.core.get_parallel_type.register(ProxyObject)
@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
def get_parallel_type_proxy_object(obj: ProxyObject):
# Notice, `get_parallel_type()` needs a instance not a type object
return dask.dataframe.core.get_parallel_type(obj.__class__.__new__(obj.__class__))
return dask.dataframe.dispatch.get_parallel_type(
obj.__class__.__new__(obj.__class__)
)


def unproxify_input_wrapper(func):
Expand All @@ -913,24 +906,24 @@ def wrapper(*args, **kwargs):

# Register dispatch of ProxyObject on all known dispatch objects
for dispatch in (
dask.dataframe.core.hash_object_dispatch,
make_meta_dispatch,
dask.dataframe.dispatch.hash_object_dispatch,
dask.dataframe.dispatch.make_meta_dispatch,
dask.dataframe.utils.make_scalar,
dask.dataframe.core.group_split_dispatch,
dask.dataframe.dispatch.group_split_dispatch,
dask.array.core.tensordot_lookup,
dask.array.core.einsum_lookup,
dask.array.core.concatenate_lookup,
):
dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

dask.dataframe.methods.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.methods.concat)
dask.dataframe.dispatch.concat_dispatch.register(
ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
)


# We overwrite the Dask dispatch of Pandas objects in order to
# deserialize all ProxyObjects before concatenating
dask.dataframe.methods.concat_dispatch.register(
dask.dataframe.dispatch.concat_dispatch.register(
(pandas.DataFrame, pandas.Series, pandas.Index),
unproxify_input_wrapper(concat_pandas),
unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
)
16 changes: 8 additions & 8 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,27 +504,27 @@ def test_pandas():
df1 = pandas.DataFrame({"a": range(10)})
df2 = pandas.DataFrame({"a": range(10)})

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert_frame_equal(res, got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert_frame_equal(res, got)

df1 = pandas.Series(range(10))
df2 = pandas.Series(range(10))

res = dask.dataframe.methods.concat([df1, df2])
got = dask.dataframe.methods.concat([df1, df2])
res = dask.dataframe.dispatch.concat([df1, df2])
got = dask.dataframe.dispatch.concat([df1, df2])
assert all(res == got)

got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
assert all(res == got)

got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
assert all(res == got)


Expand Down

0 comments on commit 0557211

Please sign in to comment.