Skip to content

Commit

Permalink
Merge pull request #82 from asmeurer/2024-draft
Browse files Browse the repository at this point in the history
Add preliminary support for the draft 2024.12 version of the standard
  • Loading branch information
asmeurer authored Nov 11, 2024
2 parents 6afcfe1 + 61b3c90 commit 838f7f4
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 134 deletions.
12 changes: 8 additions & 4 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@
minimum,
multiply,
negative,
nextafter,
not_equal,
positive,
pow,
real,
reciprocal,
remainder,
round,
sign,
Expand Down Expand Up @@ -240,10 +242,12 @@
"minimum",
"multiply",
"negative",
"nextafter",
"not_equal",
"positive",
"pow",
"real",
"reciprocal",
"remainder",
"round",
"sign",
Expand All @@ -258,9 +262,9 @@
"trunc",
]

from ._indexing_functions import take
from ._indexing_functions import take, take_along_axis

__all__ += ["take"]
__all__ += ["take", "take_along_axis"]

from ._info import __array_namespace_info__

Expand Down Expand Up @@ -305,9 +309,9 @@

__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]

from ._utility_functions import all, any
from ._utility_functions import all, any, diff

__all__ += ["all", "any"]
__all__ += ["all", "any", "diff"]

from ._array_object import Device
__all__ += ["Device"]
Expand Down
25 changes: 25 additions & 0 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array:
return Array._new(np.negative(x._array), device=x.device)


@requires_api_version('2024.12')
def nextafter(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.nextafter <numpy.nextafter>`.
See its docstring for more information.
"""
if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in nextafter")
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.nextafter(x1._array, x2._array), device=x1.device)

def not_equal(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
Expand Down Expand Up @@ -858,6 +872,17 @@ def real(x: Array, /) -> Array:
return Array._new(np.real(x._array), device=x.device)


@requires_api_version('2024.12')
def reciprocal(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reciprocal <numpy.reciprocal>`.
See its docstring for more information.
"""
if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in reciprocal")
return Array._new(np.reciprocal(x._array), device=x.device)

def remainder(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
Expand Down
11 changes: 8 additions & 3 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"2023.12",
)

draft_version = "2024.12"

API_VERSION = default_version = "2023.12"

BOOLEAN_INDEXING = True
Expand Down Expand Up @@ -70,8 +72,8 @@ def set_array_api_strict_flags(
----------
api_version : str, optional
The version of the standard to use. Supported versions are:
``{supported_versions}``. The default version number is
``{default_version!r}``.
``{supported_versions}``, plus the draft version ``{draft_version}``.
The default version number is ``{default_version!r}``.
Note that 2021.12 is supported, but currently gives the same thing as
2022.12 (except that the fft extension will be disabled).
Expand Down Expand Up @@ -134,10 +136,12 @@ def set_array_api_strict_flags(
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS

if api_version is not None:
if api_version not in supported_versions:
if api_version not in [*supported_versions, draft_version]:
raise ValueError(f"Unsupported standard version {api_version!r}")
if api_version == "2021.12":
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
if api_version == draft_version:
warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.")
API_VERSION = api_version
array_api_strict.__array_api_version__ = API_VERSION

Expand Down Expand Up @@ -169,6 +173,7 @@ def set_array_api_strict_flags(
supported_versions=supported_versions,
default_version=default_version,
default_extensions=default_extensions,
draft_version=draft_version,
)

def get_array_api_strict_flags():
Expand Down
12 changes: 12 additions & 0 deletions array_api_strict/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ._array_object import Array
from ._dtypes import _integer_dtypes
from ._flags import requires_api_version

from typing import TYPE_CHECKING

Expand All @@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
if x.device != indices.device:
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)

@requires_api_version('2024.12')
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
"""
Array API compatible wrapper for :py:func:`np.take_along_axis <numpy.take_along_axis>`.
See its docstring for more information.
"""
if x.device != indices.device:
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device)
Loading

0 comments on commit 838f7f4

Please sign in to comment.