Skip to content

Commit

Permalink
Merge pull request #166 from asmeurer/more-2023
Browse files Browse the repository at this point in the history
More fixes for 2023.12 support
  • Loading branch information
asmeurer authored Sep 30, 2024
2 parents 6f9edc7 + 9d2e283 commit b30a59e
Show file tree
Hide file tree
Showing 20 changed files with 1,591 additions and 94 deletions.
1 change: 1 addition & 0 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jobs:
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
ARRAY_API_TESTS_VERSION: 2023.12
# This enables the NEP 50 type promotion behavior (without it a lot of
# tests fail on bad scalar type promotion behavior)
NPY_PROMOTION_STATE: weak
Expand Down
102 changes: 56 additions & 46 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

from ._helpers import array_namespace, _check_device
from ._helpers import array_namespace, _check_device, device, is_torch_array

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -264,6 +264,38 @@ def var(
) -> ndarray:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
# argument

def cumulative_sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    include_initial: bool = False,
    **kwargs
) -> ndarray:
    """Cumulative sum along ``axis`` (array API ``cumulative_sum``).

    Wraps ``xp.cumsum`` (the pre-standard name) and emulates the
    ``include_initial`` keyword, which ``np.cumsum`` does not support,
    by prepending a slab of zeros along ``axis``.
    """
    wrapped_xp = array_namespace(x)

    # TODO: The standard is not clear about what should happen when x.ndim == 0.
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
        axis = 0

    res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)

    if not include_initial:
        return res

    # np.cumsum does not support include_initial: prepend zeros of the
    # result dtype, on the same device as the result.
    zeros_shape = list(x.shape)
    zeros_shape[axis] = 1
    leading = wrapped_xp.zeros(shape=zeros_shape, dtype=res.dtype, device=device(res))
    return xp.concatenate([leading, res], axis=axis)

# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
Expand All @@ -281,10 +313,11 @@ def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

wrapped_xp = array_namespace(x)

result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
Expand All @@ -305,20 +338,26 @@ def _isscalar(a):

# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if type(min) is int and min <= xp.iinfo(x.dtype).min:
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
min = None
if type(max) is int and max >= xp.iinfo(x.dtype).max:
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
max = None

if out is None:
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
copy=True, device=device(x))
if min is not None:
a = xp.broadcast_to(xp.asarray(min), result_shape)
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
# Avoid loss of precision due to torch defaulting to float32
min = wrapped_xp.asarray(min, dtype=xp.float64)
a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
ia = (out < a) | xp.isnan(a)
# torch requires an explicit cast here
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
if max is not None:
b = xp.broadcast_to(xp.asarray(max), result_shape)
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
max = wrapped_xp.asarray(max, dtype=xp.float64)
b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
ib = (out > b) | xp.isnan(b)
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
# Return a scalar for 0-D
Expand Down Expand Up @@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)

# sum() and prod() should always upcast when dtype=None
def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``sum()``: like ``xp.sum`` but upcasting when ``dtype=None``.

    ``xp.sum`` already upcasts integers, but not floats or complexes, so
    float32 accumulates in float64 and complex64 in complex128.
    """
    # Explicit dtype: defer entirely to the wrapped namespace.
    if dtype is not None:
        return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

    accum_dtype = None
    if x.dtype == xp.float32:
        accum_dtype = xp.float64
    elif x.dtype == xp.complex64:
        accum_dtype = xp.complex128
    return xp.sum(x, axis=axis, dtype=accum_dtype, keepdims=keepdims, **kwargs)

def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``prod()``: like ``xp.prod`` but upcasting when ``dtype=None``.

    Mirrors ``sum()`` above: float32 accumulates in float64 and
    complex64 in complex128 when no explicit dtype is requested.
    """
    if dtype is not None:
        return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

    accum_dtype = None
    if x.dtype == xp.float32:
        accum_dtype = xp.float64
    elif x.dtype == xp.complex64:
        accum_dtype = xp.complex128
    return xp.prod(x, dtype=accum_dtype, axis=axis, keepdims=keepdims, **kwargs)

# ceil, floor, and trunc return integers for integer inputs

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
Expand Down Expand Up @@ -521,10 +524,17 @@ def isdtype(
# array_api_strict implementation will be very strict.
return dtype == kind

# unstack is a new function in the 2023.12 array API standard
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
    """Split ``x`` into a tuple of views along ``axis``.

    Moving ``axis`` to the front and iterating yields one sub-array per
    index of that axis; 0-d input has no axis to unstack over.
    """
    if x.ndim == 0:
        raise ValueError("Input array must be at least 1-d.")
    rolled = xp.moveaxis(x, axis, 0)
    return tuple(rolled)

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
'unstack']
5 changes: 0 additions & 5 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
    """Sum along the ``offset`` diagonal of the last two axes of ``x``.

    When ``dtype=None``, lower-precision floats are accumulated in
    double precision, matching the upcasting done by ``sum()``/``prod()``.
    The result is wrapped with ``asarray`` since ``xp.trace`` may return
    a scalar for 2-d input.
    """
    if dtype is None:
        if x.dtype == xp.float32:
            dtype = xp.float64
        elif x.dtype == xp.complex64:
            dtype = xp.complex128
    summed = xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
    return xp.asarray(summed)

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
Expand Down
20 changes: 14 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ..common import _aliases
from .._internal import get_xp

from ._info import __array_namespace_info__

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
Expand Down Expand Up @@ -47,14 +49,13 @@
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
sum = get_xp(cp)(_aliases.sum)
prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
Expand Down Expand Up @@ -121,14 +122,21 @@ def sign(x: ndarray, /) -> ndarray:
vecdot = cp.vecdot
else:
vecdot = get_xp(cp)(_aliases.vecdot)

if hasattr(cp, 'isdtype'):
isdtype = cp.isdtype
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'sign']
if hasattr(cp, 'unstack'):
unstack = cp.unstack
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
Loading

0 comments on commit b30a59e

Please sign in to comment.