More fixes for 2023.12 support #166

Merged: 29 commits (from more-2023 into main) on Sep 30, 2024

Commits (29)
0734064  Remove floating-point promotion from sum, prod, and trace (asmeurer, Jul 29, 2024)
25da0e4  Wrap trace for dask (asmeurer, Jul 29, 2024)
e383441  Wrap hypot in torch (asmeurer, Jul 29, 2024)
5eb7abf  Add unstack to all wrapped libraries (asmeurer, Jul 29, 2024)
e70bcc8  Run the test suite against the 2023 version of the standard on CI (asmeurer, Jul 29, 2024)
d751db6  Update xfails for failing 2023 tests (asmeurer, Jul 30, 2024)
b96e84b  Merge branch 'main' into more-2023 (asmeurer, Aug 7, 2024)
5f8b5d6  Improvements to the clip wrapper (asmeurer, Aug 12, 2024)
2995a0f  Remove clip from dask xfails (asmeurer, Aug 12, 2024)
284bd99  Update numpy 1.21 xfails (asmeurer, Aug 13, 2024)
11cb6ef  Add NumPy inspection namespace (asmeurer, Aug 19, 2024)
4c9dd0e  Add CuPy inspection APIs (asmeurer, Aug 19, 2024)
8e3f0b6  Merge branch 'main' into more-2023 (asmeurer, Sep 11, 2024)
d3c4b3c  Add inspection namespace for torch (asmeurer, Sep 11, 2024)
be4fa68  Handle torch versions that do not have uint dtypes (asmeurer, Sep 11, 2024)
774d175  Remove uint16, uint32, and uint64 from the pytorch dtypes() output (asmeurer, Sep 12, 2024)
231ef95  Add inspection namespace for dask (asmeurer, Sep 12, 2024)
92ebdff  Add cumulative_sum wrapper for the numpy-likes (asmeurer, Sep 18, 2024)
176a66a  Remove a bunch of 2023.12 xfails (asmeurer, Sep 18, 2024)
46a2227  Hard-code the default torch integral type to int64 (asmeurer, Sep 24, 2024)
ab822f5  Ignore bare except in ruff checks (asmeurer, Sep 24, 2024)
cb9acd4  Add a comment (asmeurer, Sep 24, 2024)
c0dd5b0  Add cumulative_sum to torch (asmeurer, Sep 24, 2024)
1b47d96  Fix the tests (asmeurer, Sep 30, 2024)
e472dcb  Add repeat to the torch xfails (asmeurer, Sep 30, 2024)
470e41a  Update torch xfails (asmeurer, Sep 30, 2024)
198d6d7  Update numpy dev xfails (asmeurer, Sep 30, 2024)
ef7ad7a  Add maximum and minimum torch wrappers (for fixed two-arg type promot… (asmeurer, Sep 30, 2024)
9d2e283  Add dlpack xfails to numpy-dev (asmeurer, Sep 30, 2024)
1 change: 1 addition & 0 deletions .github/workflows/array-api-tests.yml
@@ -74,6 +74,7 @@ jobs:
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
ARRAY_API_TESTS_VERSION: 2023.12
# This enables the NEP 50 type promotion behavior (without it a lot of
# tests fail on bad scalar type promotion behavior)
NPY_PROMOTION_STATE: weak
102 changes: 56 additions & 46 deletions array_api_compat/common/_aliases.py
@@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

-from ._helpers import array_namespace, _check_device
+from ._helpers import array_namespace, _check_device, device, is_torch_array

# These functions are modified from the NumPy versions.

@@ -264,6 +264,38 @@ def var(
) -> ndarray:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

+# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
+# argument
+
+def cumulative_sum(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    axis: Optional[int] = None,
+    dtype: Optional[Dtype] = None,
+    include_initial: bool = False,
+    **kwargs
+) -> ndarray:
+    wrapped_xp = array_namespace(x)
+
+    # TODO: The standard is not clear about what should happen when x.ndim == 0.
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+        axis = 0
+
+    res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
+
+    # np.cumsum does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = xp.concatenate(
+            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            axis=axis,
+        )
+    return res

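For illustration, a minimal usage sketch of the new wrapper (assuming the NumPy-backed namespace, array_api_compat.numpy, as the entry point):

import array_api_compat.numpy as xp

x = xp.asarray([1, 2, 3])
xp.cumulative_sum(x)                        # array([1, 3, 6])
# include_initial=True prepends a zero along the cumulated axis:
xp.cumulative_sum(x, include_initial=True)  # array([0, 1, 3, 6])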
# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
@@ -281,10 +313,11 @@ def _isscalar(a):
        return isinstance(a, (int, float, type(None)))
    min_shape = () if _isscalar(min) else min.shape
    max_shape = () if _isscalar(max) else max.shape
-    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

    wrapped_xp = array_namespace(x)

+    result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
+
    # np.clip does type promotion but the array API clip requires that the
    # output have the same dtype as x. We do this instead of just downcasting
    # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +338,26 @@ def _isscalar(a):

    # At least handle the case of Python integers correctly (see
    # https://github.com/numpy/numpy/pull/26892).
-    if type(min) is int and min <= xp.iinfo(x.dtype).min:
+    if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
        min = None
-    if type(max) is int and max >= xp.iinfo(x.dtype).max:
+    if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
        max = None

    if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
+        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
+                                 copy=True, device=device(x))
    if min is not None:
-        a = xp.broadcast_to(xp.asarray(min), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
+            # Avoid loss of precision due to torch defaulting to float32
+            min = wrapped_xp.asarray(min, dtype=xp.float64)
+        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
        ia = (out < a) | xp.isnan(a)
        # torch requires an explicit cast here
        out[ia] = wrapped_xp.astype(a[ia], out.dtype)
    if max is not None:
-        b = xp.broadcast_to(xp.asarray(max), result_shape)
+        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
+            max = wrapped_xp.asarray(max, dtype=xp.float64)
+        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
        ib = (out > b) | xp.isnan(b)
        out[ib] = wrapped_xp.astype(b[ib], out.dtype)
    # Return a scalar for 0-D
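A short sketch of the corner case this wrapper handles (NumPy backend; the out-of-range Python int mirrors the numpy#26892 workaround above):

import array_api_compat.numpy as xp

x = xp.asarray([0, 127], dtype=xp.int8)
# 200 exceeds iinfo(int8).max, so it is treated as "no upper bound" instead
# of being promoted or overflowing; the result keeps x's dtype:
xp.clip(x, max=200)   # array([0, 127], dtype=int8)
xp.clip(x, min=1)     # array([1, 127], dtype=int8)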
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
        raise ValueError("nonzero() does not support zero-dimensional arrays")
    return xp.nonzero(x, **kwargs)

-# sum() and prod() should always upcast when dtype=None
-def sum(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    # `xp.sum` already upcasts integers, but not floats or complexes
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
-
-def prod(
-    x: ndarray,
-    /,
-    xp,
-    *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    dtype: Optional[Dtype] = None,
-    keepdims: bool = False,
-    **kwargs,
-) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
-    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

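The deleted upcasting reflects the 2023.12 standard, which no longer requires sum() and prod() to accumulate float32/complex64 inputs in the default floating-point dtype. A sketch of the resulting behavior, assuming NumPy as the wrapped backend:

import array_api_compat.numpy as xp

x = xp.asarray([1.0, 2.0], dtype=xp.float32)
xp.sum(x).dtype                     # float32 (the old wrapper forced float64)
xp.sum(x, dtype=xp.float64).dtype   # float64; an explicit dtype still upcasts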
# ceil, floor, and trunc return integers for integer inputs

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
@@ -521,10 +524,17 @@ def isdtype(
        # array_api_strict implementation will be very strict.
        return dtype == kind

+# unstack is a new function in the 2023.12 array API standard
+def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+    if x.ndim == 0:
+        raise ValueError("Input array must be at least 1-d.")
+    return tuple(xp.moveaxis(x, axis, 0))
+
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
-           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
+           'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
+           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
+           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
+           'unstack']
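Since unstack is new in 2023.12, a brief usage sketch (again assuming the NumPy-backed namespace):

import array_api_compat.numpy as xp

x = xp.asarray([[1, 2], [3, 4]])
a, b = xp.unstack(x)           # the rows: array([1, 2]), array([3, 4])
c, d = xp.unstack(x, axis=1)   # the columns: array([1, 3]), array([2, 4])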
5 changes: 0 additions & 5 deletions array_api_compat/common/_linalg.py
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
    return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
-    if dtype is None:
-        if x.dtype == xp.float32:
-            dtype = xp.float64
-        elif x.dtype == xp.complex64:
-            dtype = xp.complex128
    return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
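As with sum() and prod(), trace() now follows the input dtype when dtype=None. A minimal sketch with the NumPy backend:

import array_api_compat.numpy as xp

x = xp.ones((2, 2), dtype=xp.float32)
xp.linalg.trace(x).dtype   # float32; the wrapper no longer forces float64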
20 changes: 14 additions & 6 deletions array_api_compat/cupy/_aliases.py
@@ -5,6 +5,8 @@
from ..common import _aliases
from .._internal import get_xp

+from ._info import __array_namespace_info__
+
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
@@ -47,14 +49,13 @@
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
+cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
-sum = get_xp(cp)(_aliases.sum)
-prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
@@ -121,14 +122,21 @@ def sign(x: ndarray, /) -> ndarray:
    vecdot = cp.vecdot
else:
    vecdot = get_xp(cp)(_aliases.vecdot)

if hasattr(cp, 'isdtype'):
    isdtype = cp.isdtype
else:
    isdtype = get_xp(cp)(_aliases.isdtype)

-__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
-                              'acosh', 'asin', 'asinh', 'atan', 'atan2',
-                              'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                              'bitwise_right_shift', 'concat', 'pow', 'sign']
+if hasattr(cp, 'unstack'):
+    unstack = cp.unstack
+else:
+    unstack = get_xp(cp)(_aliases.unstack)
+
+__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
+                              'acos', 'acosh', 'asin', 'asinh', 'atan',
+                              'atan2', 'atanh', 'bitwise_left_shift',
+                              'bitwise_invert', 'bitwise_right_shift',
+                              'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
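__array_namespace_info__ is the 2023.12 inspection entry point. A usage sketch, shown with the NumPy-backed namespace so it runs without a GPU (the CuPy namespace follows the same pattern):

import array_api_compat.numpy as xp

info = xp.__array_namespace_info__()
info.default_dtypes()   # e.g. {'real floating': float64, 'integral': int64, ...}
info.devices()          # devices supported by the namespace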