-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improved duck array wrapping #9798
Changes from 5 commits
fd6b339
893408c
f7866ce
90037fe
5ba1a2f
6225ae3
e2911c2
2ac37f9
1cc344b
69080a5
372439c
0eef2cb
6739504
9e6d6f8
e721011
1fe4131
205c199
7752088
c8d4e5e
e67a819
f306768
18ebdcd
f51e3fb
121af9e
472ae7e
5aa4a39
390df6f
f6074d2
561f21b
bfd6aeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
import numpy as np | ||
|
||
from xarray.namedarray.pycompat import array_type | ||
|
||
|
||
def is_weak_scalar_type(t): | ||
return isinstance(t, bool | int | float | complex | str | bytes) | ||
|
@@ -42,3 +44,29 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: | |
return xp.result_type(*arrays_and_dtypes) | ||
else: | ||
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) | ||
|
||
|
||
def get_array_namespace(*values): | ||
def _get_single_namespace(x): | ||
if hasattr(x, "__array_namespace__"): | ||
return x.__array_namespace__() | ||
elif isinstance(x, array_type("cupy")): | ||
# special case cupy for now | ||
import cupy as cp | ||
|
||
return cp | ||
else: | ||
return np | ||
Review comment: You probably want to wrap array-api-compat's […]
Review comment: Happy to go that route. I did actually try that, but array-api-compat doesn't handle a bunch of things we end up passing through this (scalars, index wrappers, etc.), so it would require some careful prefiltering. The only things this package effectively wraps that don't have […]
Review comment: Support for scalars was just merged half an hour ago! data-apis/array-api-compat#147
Review comment: Looks like this would handle the array-like check well, but would require adding this as a core xarray dependency to use it in […]
Review comment: In SciPy we vendor array-api-compat via a Git submodule. There's a little bit of build system bookkeeping needed, but otherwise it works well without introducing a dependency.
Review comment: I played around with this some more. Getting an object that is compliant from the perspective of […] I was able to get things to the point that xarray can wrap […] It seems torch has had little movement on supporting the standard since 2021. From xarray's perspective, […] So while this compat module is nice in theory, I don't think it's very useful for xarray. cupy, jax, and sparse are good to go without it, and we only need a single special case for cupy to fetch its […]
||
|
||
namespaces = {_get_single_namespace(t) for t in values} | ||
non_numpy = namespaces - {np} | ||
|
||
if len(non_numpy) > 1: | ||
names = [module.__name__ for module in non_numpy] | ||
raise TypeError(f"Mixed array types {names} are not supported.") | ||
elif non_numpy: | ||
[xp] = non_numpy | ||
else: | ||
xp = np | ||
|
||
return xp |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,22 +18,17 @@ | |
import pandas as pd | ||
from numpy import all as array_all # noqa: F401 | ||
from numpy import any as array_any # noqa: F401 | ||
from numpy import concatenate as _concatenate | ||
from numpy import ( # noqa: F401 | ||
full_like, | ||
gradient, | ||
isclose, | ||
isin, | ||
isnat, | ||
take, | ||
tensordot, | ||
transpose, | ||
unravel_index, | ||
) | ||
from packaging.version import Version | ||
from pandas.api.types import is_extension_array_dtype | ||
|
||
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils | ||
from xarray.core.array_api_compat import get_array_namespace | ||
from xarray.core.options import OPTIONS | ||
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available | ||
from xarray.namedarray import pycompat | ||
|
@@ -54,28 +49,6 @@ | |
dask_available = module_available("dask") | ||
|
||
|
||
def get_array_namespace(*values):
    """Return the single array namespace shared by ``values``.

    Each value that exposes ``__array_namespace__`` contributes the namespace
    returned by that method; every other value contributes ``np``.

    Returns
    -------
    The one non-NumPy namespace found among ``values``, or ``np`` when there
    is none (including when ``values`` is empty).

    Raises
    ------
    TypeError
        If more than one distinct non-NumPy namespace is found.
    """

    def _get_array_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        else:
            return np

    # A set de-duplicates namespaces, so repeated arrays of one kind count once.
    namespaces = {_get_array_namespace(t) for t in values}
    non_numpy = namespaces - {np}

    if len(non_numpy) > 1:
        raise TypeError(
            "cannot deal with more than one type supporting the array API at the same time"
        )
    elif non_numpy:
        [xp] = non_numpy
    else:
        xp = np

    return xp
|
||
|
||
def einsum(*args, **kwargs): | ||
from xarray.core.options import OPTIONS | ||
|
||
|
@@ -84,7 +57,23 @@ def einsum(*args, **kwargs): | |
|
||
return opt_einsum.contract(*args, **kwargs) | ||
else: | ||
return np.einsum(*args, **kwargs) | ||
xp = get_array_namespace(*args) | ||
return xp.einsum(*args, **kwargs) | ||
|
||
|
||
def tensordot(*args, **kwargs):
    """Call ``tensordot`` from the array namespace of the positional arguments.

    The namespace is chosen by ``get_array_namespace(*args)``; all positional
    and keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(*args)
    return xp.tensordot(*args, **kwargs)
|
||
|
||
def cross(*args, **kwargs):
    """Call ``cross`` from the array namespace of the positional arguments.

    The namespace is chosen by ``get_array_namespace(*args)``; all positional
    and keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(*args)
    return xp.cross(*args, **kwargs)
|
||
|
||
def gradient(f, *varargs, axis=None, edge_order=1):
    """Call ``gradient`` from the array namespace of ``f``.

    Only ``f`` is used to pick the namespace; ``varargs``, ``axis`` and
    ``edge_order`` are forwarded unchanged.
    """
    xp = get_array_namespace(f)
    return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order)
|
||
|
||
def _dask_or_eager_func( | ||
|
@@ -133,15 +122,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): | |
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" | ||
) | ||
|
||
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), | ||
# so we need to hand-code this. | ||
sliding_window_view = _dask_or_eager_func( | ||
"sliding_window_view", | ||
eager_module=np.lib.stride_tricks, | ||
dask_module=dask_array_compat, | ||
dask_only_kwargs=("automatic_rechunk",), | ||
numpy_only_kwargs=("subok", "writeable"), | ||
) | ||
|
||
def sliding_window_view(array, window_shape, axis=None, **kwargs): | ||
# TODO: some libraries (e.g. jax) don't have this, implement an alternative? | ||
Review comment: This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given that this is way outside the API standard, I wonder whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.
Review comment: You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges), or convolutions, I guess.
Review comment: I have something that works pretty well with this style of gather operation. But only in a […]
||
xp = get_array_namespace(array) | ||
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), | ||
# so we need to hand-code this. | ||
func = _dask_or_eager_func( | ||
"sliding_window_view", | ||
eager_module=xp.lib.stride_tricks, | ||
dask_module=dask_array_compat, | ||
dask_only_kwargs=("automatic_rechunk",), | ||
numpy_only_kwargs=("subok", "writeable"), | ||
) | ||
return func(array, window_shape, axis=axis, **kwargs) | ||
|
||
|
||
def round(array): | ||
|
@@ -174,7 +168,7 @@ def isnull(data): | |
) | ||
): | ||
# these types cannot represent missing values | ||
return full_like(data, dtype=bool, fill_value=False) | ||
return full_like(data, dtype=xp.bool, fill_value=False) | ||
else: | ||
# at this point, array should have dtype=object | ||
if isinstance(data, np.ndarray) or is_extension_array_dtype(data): | ||
|
@@ -215,11 +209,23 @@ def cumulative_trapezoid(y, x, axis): | |
|
||
# Pad so that 'axis' has same length in result as it did in y | ||
pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] | ||
integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) | ||
|
||
xp = get_array_namespace(y, x) | ||
integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) | ||
|
||
return cumsum(integrand, axis=axis, skipna=False) | ||
|
||
|
||
def full_like(a, fill_value, **kwargs):
    """Call ``full_like`` from the array namespace of ``a``.

    ``fill_value`` and any keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(a)
    return xp.full_like(a, fill_value, **kwargs)
|
||
|
||
def empty_like(a, **kwargs):
    """Call ``empty_like`` from the array namespace of ``a``.

    Any keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(a)
    return xp.empty_like(a, **kwargs)
|
||
|
||
def astype(data, dtype, **kwargs): | ||
if hasattr(data, "__array_namespace__"): | ||
xp = get_array_namespace(data) | ||
|
@@ -350,7 +356,8 @@ def array_notnull_equiv(arr1, arr2): | |
|
||
def count(data, axis=None):
    """Count the number of non-NA in this array along the given axis or axes"""
    # Use the namespace of ``data`` rather than NumPy so that duck arrays are
    # reduced by their own library.
    xp = get_array_namespace(data)
    return xp.sum(xp.logical_not(isnull(data)), axis=axis)
|
||
|
||
def sum_where(data, axis=None, dtype=None, where=None): | ||
|
@@ -365,7 +372,7 @@ def sum_where(data, axis=None, dtype=None, where=None): | |
|
||
def where(condition, x, y):
    """Three argument where() with better dtype promotion rules."""
    # Pick the namespace from all three operands, not just ``condition``.
    xp = get_array_namespace(condition, x, y)
    return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
|
||
|
||
|
@@ -382,15 +389,25 @@ def fillna(data, other): | |
return where(notnull(data), data, other) | ||
|
||
|
||
def logical_not(data):
    """Call ``logical_not`` from the array namespace of ``data``."""
    xp = get_array_namespace(data)
    return xp.logical_not(data)
|
||
|
||
def clip(data, min=None, max=None):
    """Call ``clip`` from the array namespace of ``data``.

    ``min`` and ``max`` are passed positionally, in that order, after ``data``.
    """
    xp = get_array_namespace(data)
    return xp.clip(data, min, max)
|
||
|
||
def concatenate(arrays, axis=0):
    """concatenate() with better dtype promotion rules."""
    # TODO: `concat` is the xp compliant name, but fallback to concatenate for
    # older numpy and for cupy
    xp = get_array_namespace(*arrays)
    if hasattr(xp, "concat"):
        return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
    else:
        return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis)
|
||
|
||
def stack(arrays, axis=0): | ||
|
@@ -408,6 +425,26 @@ def ravel(array): | |
return reshape(array, (-1,)) | ||
|
||
|
||
def transpose(array, axes=None):
    """Call ``transpose`` from the array namespace of ``array``.

    ``axes`` is passed positionally after ``array``.
    """
    xp = get_array_namespace(array)
    return xp.transpose(array, axes)
|
||
|
||
def moveaxis(array, source, destination):
    """Call ``moveaxis`` from the array namespace of ``array``.

    ``source`` and ``destination`` are forwarded unchanged.
    """
    xp = get_array_namespace(array)
    return xp.moveaxis(array, source, destination)
|
||
|
||
def pad(array, pad_width, **kwargs):
    """Call ``pad`` from the array namespace of ``array``.

    ``pad_width`` and any keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(array)
    return xp.pad(array, pad_width, **kwargs)
|
||
|
||
def quantile(array, q, axis=None, **kwargs):
    """Call ``quantile`` from the array namespace of ``array``.

    ``q``, ``axis`` and any keyword arguments are forwarded unchanged.
    """
    xp = get_array_namespace(array)
    return xp.quantile(array, q, axis=axis, **kwargs)
|
||
|
||
@contextlib.contextmanager | ||
def _ignore_warnings_if(condition): | ||
if condition: | ||
|
@@ -749,6 +786,11 @@ def last(values, axis, skipna=None): | |
return take(values, -1, axis=axis) | ||
|
||
|
||
def isin(element, test_elements, **kwargs):
    """Call ``isin`` from the array namespace shared by both operands.

    The namespace is chosen by
    ``get_array_namespace(element, test_elements)``; any keyword arguments are
    forwarded unchanged.
    """
    xp = get_array_namespace(element, test_elements)
    return xp.isin(element, test_elements, **kwargs)
|
||
|
||
def least_squares(lhs, rhs, rcond=None, skipna=False): | ||
"""Return the coefficients and residuals of a least-squares fit.""" | ||
if is_duck_dask_array(rhs): | ||
|
Review comment: cupy seems to have full compliance with the standard, but doesn't yet actually have `__array_namespace__` on the core API. Others may be the same?