Add dask to array-api-compat #76

Merged
25 commits merged into data-apis:main from add-dask on Feb 6, 2024
Conversation

lithomas1
Contributor

No description provided.

@rgommers rgommers mentioned this pull request Dec 19, 2023
@lithomas1 lithomas1 marked this pull request as ready for review December 27, 2023 00:46
@lithomas1
Contributor Author

lithomas1 commented Dec 27, 2023

Dask tests should be green, but are currently very slow.

disables the dask scheduler
"""
import dask
# 'synchronous' runs every dask task in the calling thread, avoiding scheduler overhead in the tests
dask.config.set(scheduler='synchronous')
Contributor Author

I don't know if this is needed anymore (it didn't seem to help with the test that OOMed - which is now disabled).

Leaving this in just in case I haven't found the correct options.

@@ -303,6 +303,8 @@ def _asarray(
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask':
Contributor Author

Technically, the array module of dask is called dask.array.

I guess we could match that here, but it feels kind of weird to me.

Member

I think (but am not 100% sure) that using namespace == 'dask.array' is preferred here, since it's meant to identify the actual namespace rather than the package.

Member

I suggest waiting for @asmeurer to weigh in here.

Member

This keyword is only used internally so it doesn't really matter, but I like dask.array.

elif _is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
Contributor Author

I think dask supports cupy on the GPU.

Is this something we also need to take into consideration?

Member

Yes indeed, it does support CuPy. I'm not sure that a stream input can be supported, though. If a Dask array spans multiple machines, I think there'd be multiple streams with numbers that are unrelated to each other, which therefore can't be supported at the Dask level in this API.

That's probably fine - you'd only move arrays in a single process to another device like this I think, so maybe the whole to_device method doesn't quite work for Dask? @jakirkham any thoughts on this?

Member

Even if dask can't really support device transfers or anything I would not call the default device "cpu" as that's misleading. We could just create a proxy DaskDevice object that serves as the device for dask arrays.
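
A rough sketch of that proxy-object idea (purely illustrative; the names are assumptions, not code from this PR):

# Hypothetical placeholder device for dask arrays, whose chunks may live on
# CPU (numpy) or GPU (cupy) workers, so neither "cpu" nor a CUDA device fits.
class _DaskDevice:
    def __repr__(self):
        return "DASK_DEVICE"

_DASK_DEVICE = _DaskDevice()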

Member

@lithomas1 what's the status of this comment? Does the device() helper above need to be updated for dask?

@lithomas1
Contributor Author

@rgommers
This should be ready for a first pass now.

@rgommers
Member

rgommers commented Jan 2, 2024

Nice @lithomas1! Question: did you already try this with scikit-learn or SciPy functionality that uses array_api_compat under the hood, and if so did that work?

Array API Tests (Dask) ... Failing after 130m

Let's try and fix that first I'd say - this isn't really workable. The log says it uses max_examples=200 - how about starting with 1? The first goal here is checking what exists yes/no, for which you need only a single example. That will make it much easier to iterate on this. I assume that'll fix this practical issue - maybe @honno or @asmeurer have better ideas.

@lithomas1
Contributor Author

Thanks for the review.

> Nice @lithomas1! Question: did you already try this with scikit-learn or SciPy functionality that uses array_api_compat under the hood, and if so did that work?

I'm currently testing with scikit-learn, and the biggest issue I'm facing is the array having NaN shapes (since the dask computation hasn't been compute()'ed yet). From a quick glance, it looks like scikit-learn uses the shape when allocating the result array, so there doesn't seem to be an easy way to avoid the shape call.

Currently, I'm working around that by adding a lot of np.asarray calls to force computation - but that probably lowers performance for the other eager backends, right?

Any ideas on how to fix this?

> Array API Tests (Dask) ... Failing after 130m
>
> Let's try and fix that first I'd say - this isn't really workable. The log says it uses max_examples=200 - how about starting with 1? The first goal here is checking what exists yes/no, for which you need only a single example. That will make it much easier to iterate on this. I assume that'll fix this practical issue - maybe @honno or @asmeurer have better ideas.

Hm, I think I got this to green last time, even with the full amount of max-examples.

I'm going to try lowering it to 5.
(I remember that when testing this last time, there were tests that failed when you increased the number of max-examples)

@lithomas1
Contributor Author

Update on the testing with scikit-learn/scipy:

I did a really quick test with scikit-learn's train_test_split function, and it looks like that worked.

@rgommers
Member

rgommers commented Jan 4, 2024

> I did a really quick test with scikit-learn's train_test_split function, and it looks like that worked.

Nice! That's probably the most heavily used function in all of scikit-learn, so a good start :)

> and the biggest issue I'm facing is the array having NaN shapes (since the dask computation hasn't been compute()'ed yet). From a quick glance, it looks like scikit-learn uses the shape when allocating the result array, so there doesn't seem to be an easy way to avoid the shape call.
>
> Currently, I'm working around that by adding a lot of np.asarray calls to force computation - but that probably lowers performance for the other eager backends, right?
>
> Any ideas on how to fix this?

Yes, not by much but there is overhead for each asarray call. In general it's tricky. nan shapes should come from data-dependent operations mostly, so I guess you're running into something like this?

>>> import dask.array as da
>>> x = da.ones((3))
>>> y = da.unique(da.zeros((3)))
>>> x
dask.array<ones_like, shape=(3,), dtype=float64, chunksize=(3,), chunktype=numpy.ndarray>
>>> y
dask.array<getitem, shape=(nan,), dtype=float64, chunksize=(nan,), chunktype=numpy.ndarray>
>>> da.shape(x + y)
...
ValueError: operands could not be broadcast together with shapes (3,) (nan,)

Allocating new arrays with a shape that depends on the result shape of some other operation with a nan in the shape attribute may be tricky. In some cases it does work, e.g.:

>>> z = da.empty_like(y)
>>> z
dask.array<empty_like, shape=(nan,), dtype=float64, chunksize=(nan,), chunktype=numpy.ndarray>

and that can stay lazy just fine. But if it's not the same shape and you need a new array with an extra dimension, things become tricky I guess, since the nan is actually a nan and not some custom object that can be carried forward when constructing a new array.

Can you point to the code in question?

@lithomas1
Contributor Author

This is the one that I've hit so far.

https://github.com/scikit-learn/scikit-learn/blob/e2b3785b6a1f96989c3992bffa2a05ef5c048a7e/sklearn/discriminant_analysis.py#L116

Not too familiar with the sklearn code, but would it be fair to assume that a good chunk of places where we have np.zeros would be converted to xp.zeros (as the code is made array API compatible)?

Then we would have a lot more of these:
https://github.com/search?q=repo%3Ascikit-learn%2Fscikit-learn+np.zeros+language%3APython+path%3A%2F%5Esklearn%5C%2F%2F&type=code

@rgommers
Member

rgommers commented Jan 4, 2024

> Not too familiar with the sklearn code, but would it be fair to assume that a good chunk of places where we have np.zeros would be converted to xp.zeros (as the code is made array API compatible)?

Summary of what we just discussed:

  • Dask is able to do shape computations in many cases, so we don't expect too many cases where this nan issue will be problematic.
  • For those cases, it's probably fine to write a small utility function for use in scikit-learn that checks whether an array is a dask array and, if so, returns its .compute() result, and otherwise is a no-op (see the sketch below).
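
A minimal sketch of such a helper, assuming it would live in scikit-learn's array-API utilities (the name and the check are illustrative, not part of this PR):

def _maybe_compute(x):
    # Force a lazy dask array to a concrete in-memory array so that .shape is
    # fully known; eager arrays (numpy, cupy, torch, ...) pass through
    # untouched and pay no extra cost.
    if type(x).__module__.startswith("dask."):
        return x.compute()
    return x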

@betatim
Member

betatim commented Jan 5, 2024

Joining in for the chat about scikit-learn. In the particular case you linked to, the code computes something from the input data (the number of classes) and then creates a new array with a shape based on the number of classes found in the input data.

I think this means there isn't much we can do, other than getting an actual value for the shape. Ideally the underlying lib (dask in this case) can compute that without doing all of the outstanding work. It sounds like that is what dask is (sometimes) doing?

I'm not sure what I think about adding code to scikit-learn to deal with "if dask then call .compute() else no-op". The whole point of the Array API is that as an array consumer you don't have to know the producer of the array. It is annoying to have that kind of code, but that isn't the reason I dislike it. The reason I dislike it is that it means we can't write "scikit-learn supports all Array API compliant arrays as input". Instead we need to explicitly list the set of Array API compliant inputs that scikit-learn handles.

I'm not sure what the right solution is. It is probably bigger than array-api-compat though. In the discussions about lazy vs not-lazy array producers, my impression was that we kinda didn't really conclude anything. People like me would like accessing things like .shape to implicitly trigger a computation (if the underlying array producer is lazy), while the lazy library implementors prefer explicit .compute() calls or the concept of a LazyInteger (I think).

So I don't know how/where/if we should restart this discussion :-/

@rgommers
Member

rgommers commented Jan 5, 2024

> People like me would like accessing things like .shape to implicitly trigger a computation (if the underlying array producer is lazy), while the lazy library implementors prefer explicit .compute() calls or the concept of a LazyInteger (I think).

To me this feels more like missing functionality in Dask. Accessing .shape[0] to use as input to constructing a new array can stay lazy just fine - if only dask wouldn't use plain nan but a smarter object to represent "not yet computed". It could have a repr of nan while being an object that carries a reference to its parent array's shape. This already works if you use the whole shape (see z = da.empty_like(y) above), it's just not implemented for a partial use of that shape. And having that functionality should be useful for many such cases.

@jakirkham do you think this is implementable in Dask with reasonable effort?
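
For concreteness, a rough illustration of the kind of object being described - purely hypothetical, nothing like this exists in Dask today:

import math

class UnknownDim(float):
    # Prints like nan, but remembers which array/axis it came from, so a
    # later array construction could in principle resolve it lazily instead
    # of failing on a plain float('nan').
    def __new__(cls, parent, axis):
        obj = super().__new__(cls, math.nan)
        obj.parent = parent
        obj.axis = axis
        return obj

    def __repr__(self):
        return "nan"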

@@ -322,10 +324,13 @@ def _asarray(
if copy in COPY_FALSE:
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, xp.ndarray):
# TODO: This feels wrong (__array__ is not in the standard)
# Dask doesn't support DLPack, though, so, this'll do
Member

This function is only used for numpy and cupy (and now dask). The internals of the wrappers can use specific functionality that isn't in the standard because they can assume that they are operating on the specific array library. So this isn't a problem.

if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
# Dask arrays are immutable, so copy doesn't do anything
Member

I think this is probably fine but maybe this sort of thing should be clarified in the standard.

Contributor Author

Yeah, I think what I was mostly worried about was whether I need to return a new dask array object (I guess this would kinda be like a view on the old object) or not.

@asmeurer
Member

asmeurer commented Jan 8, 2024

This is looking good. It looks like the basis for most of this is that the dask functions are basically the same as numpy, so the existing wrappers that are used for numpy and cupy can just be reused for dask (i.e., all of these). I'll assume that you've actually verified that this is the case for these functions, either manually or using the tests.

Note that if some dask function does end up deviating significantly from numpy, don't feel like you always need to use the helpers in common. Those are there just to avoid code duplication between numpy and cupy - for instance, you'll notice that the torch wrappers barely use common at all.

@lithomas1
Contributor Author

> This is looking good. It looks like the basis for most of this is that the dask functions are basically the same as numpy, so the existing wrappers that are used for numpy and cupy can just be reused for dask (i.e., all of these). I'll assume that you've actually verified that this is the case for these functions, either manually or using the tests.

I just finished wrapping the linalg functions, but I haven't done any wrapping of the FFT functions yet.
Should I wait for #78 to be merged first, then?

asmeurer added a commit to asmeurer/array-api-compat that referenced this pull request Jan 17, 2024
@asmeurer asmeurer mentioned this pull request Jan 18, 2024
@lithomas1
Contributor Author

@asmeurer @rgommers

This should be ready for a final pass now.

dask-xfails.txt Outdated
#| Draw 1 (key): (slice(None, None, None), slice(None, None, None))
#| Draw 2 (value): dask.array<zeros_like, shape=(0, 2), dtype=bool, chunksize=(0, 2), chunktype=numpy.ndarray>

# TODO: this also skips test_setitem_masking unnecessarily
Member

Let's change the xfails file matching to require an exact match. No reason not to do that.

Contributor Author

Should I make a PR into array-api-tests?

Member

Yes, go ahead. It should be a trivial change. Just change some `in`s to `==` in conftest.py, and check that the existing xfail files don't need to be updated.
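
A self-contained sketch of the kind of change meant here (names are illustrative; the real conftest.py in array-api-tests is organized differently):

def should_xfail(test_node_id, xfail_ids):
    # Substring matching - any(xfail_id in test_node_id ...) - is what lets an
    # entry for test_setitem also skip test_setitem_masking, as noted in the
    # xfails file above. Exact matching avoids that over-matching:
    return any(xfail_id == test_node_id for xfail_id in xfail_ids)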

Contributor Author

So, it turns out that after fixing some other tests, this test(test_setitem_masking) started failing again (the test was passing because of a bug I had in my asarray version), and the comment isn't relevant anymore.

@asmeurer
Member

The code itself here looks fine. My main comment is the same as before: we should be sure that the numpy/cupy generic wrappers are all actually correct for dask too. But the test suite will also be a good check on this. Have you run the tests with more than the 5 max examples that are used on the CI?

I guess the best thing to do would be to test this out with some consumer library like scipy or scikit-learn. If it's easier to do that after making a release we can do that.

Should we also see if any dask devs want to take a look here?

@@ -77,7 +77,8 @@ def matrix_rank(x: ndarray,
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
S = xp.linalg.svdvals(x, **kwargs)
#S = xp.linalg.svd(x, compute_uv=False, **kwargs)
Member

This change looks like it is causing the numpy array api tests to fail.

Member

The current numpy dev CI failures can be ignored (they should go away once numpy/numpy#25668 is merged), but the other ones are important.

Contributor Author

Thanks for the catch. I think this should be equal to svdvals (if svdvals exists), so I just used a hasattr check. Lemme know if that's too hacky

Member

You should use get_xp(xp).linalg.svdvals so that it uses the svdvals alias. This pattern is already used in a few places in this file.

As for not wrapping svdvals when it is present, we could do that too (I did this for vector_norm

if hasattr(np.linalg, 'vector_norm'):
    vector_norm = np.linalg.vector_norm
else:
    vector_norm = get_xp(np)(_linalg.vector_norm)
but I forgot about svdvals).
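
The analogous pattern for svdvals would presumably look like this (a sketch that just mirrors the snippet above, not necessarily the exact code that landed):

if hasattr(np.linalg, 'svdvals'):
    svdvals = np.linalg.svdvals
else:
    svdvals = get_xp(np)(_linalg.svdvals)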

Contributor Author

Thanks, updated.

concatenate as concat,
)

del da
Member

Let's also manage __all__ in these files the same as in the other submodules.

Contributor Author

Updated for everything except linalg.

I'm not sure if we can do linalg since there's no __all__ for dask.array.linalg.

Member

See how it's done for cupy.linalg, which also doesn't define __all__.
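
One way to build __all__ when the wrapped module doesn't define one is to collect its public names directly - a sketch under that assumption (the wrapper names listed are illustrative, and the actual cupy.linalg/dask handling in this repo may differ):

import dask.array.linalg as da_linalg

# dask.array.linalg has no __all__, so gather its public names manually and
# extend them with the wrappers defined in this file.
linalg_all = [name for name in dir(da_linalg) if not name.startswith('_')]
__all__ = linalg_all + ['SlogdetResult', 'SVDResult', 'qr', 'svdvals', 'matrix_rank']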

Contributor Author

This should be updated as well.

SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
#svd = get_xp(da)(_linalg.svd)
Member

asmeurer commented Jan 24, 2024

Can you get rid of all the little commented-out bits of code like this? Or if one of them actually is useful to keep, add an explanatory comment above it.

Contributor Author

Sorry, had some code left over from debugging.

Everything should be cleaned up now. 🤞

@asmeurer
Member

> Should we also see if any dask devs want to take a look here?

We should also see if the dask devs are interested in having upstream issues for the existing holes here.

@lithomas1 if any upstream issues do exist, or are opened, please cross-reference them in the xfails file (and the code if relevant).

@lithomas1
Contributor Author

> Should we also see if any dask devs want to take a look here?
>
> We should also see if the dask devs are interested in having upstream issues for the existing holes here.

Definitely. I'll try to get followups opened up on the dask side after this is merged (so I don't have to go back and forth if new tests fail/pass).

@asmeurer
Member

asmeurer commented Feb 6, 2024

Alright, I'm going to merge this. I'll make a release with this and if any further issues come up we can fix them up later. I don't think anyone will test this unless it is released.

@asmeurer asmeurer merged commit 9cb5a13 into data-apis:main Feb 6, 2024
26 checks passed
@lithomas1 lithomas1 deleted the add-dask branch February 6, 2024 21:37
@asmeurer asmeurer mentioned this pull request Feb 6, 2024
@asmeurer
Member

asmeurer commented Feb 6, 2024

@lithomas1 I realized there was one important thing missing here, and that's documentation. Specifically, the README needs to be updated. I already added "dask" to a couple of sentences in #84, but it would be really helpful if you could make a PR to add a Dask section to the "known differences" section of the README.

(and yes I know we need real docs. That's on my TODO list)

@asmeurer
Member

asmeurer commented Feb 6, 2024

It also looks like the device helper doesn't work for dask. I added a test at #84 but XFAILed it. Can you make a PR fixing this?

@lithomas1
Contributor Author

Will do. Feel free to put in the changes from #84 first (It'll probably be easier for me to rebase on top of that).

@asmeurer
Member

asmeurer commented Feb 8, 2024

@lithomas1 I just merged #82 which updates some of your dask code. Can you make a PR fixing these dask issues? There are also several test failures on CI related to dask. #84 shouldn't be an issue. Just add support for device but don't worry about testing it.

@lithomas1
Contributor Author

Will do, was slightly busy yesterday, but should have both PRs up tomorrow.
