Skip to content

Commit

Permalink
Merge pull request #191 from asmeurer/2023.12
Browse files Browse the repository at this point in the history
Update __array_api_version__ to 2023.12
  • Loading branch information
asmeurer authored Oct 29, 2024
2 parents 522a608 + 273d54e commit d7b4111
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 16 deletions.
13 changes: 7 additions & 6 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def is_torch_namespace(xp) -> bool:
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
"""
Expand Down Expand Up @@ -415,10 +415,11 @@ def is_array_api_strict_namespace(xp):
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
elif api_version is not None and api_version != '2022.12':
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
if api_version in ['2021.12', '2022.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12']:
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")

def array_namespace(*xs, api_version=None, use_compat=None):
"""
Expand All @@ -431,7 +432,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
api_version: str
The newest version of the spec that you need support for (currently
the compat library wrapped APIs support v2022.12).
the compat library wrapped APIs support v2023.12).
use_compat: bool or None
If None (the default), the native namespace will be returned if it is
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'

__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
2 changes: 1 addition & 1 deletion array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
except ImportError:
pass

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
2 changes: 1 addition & 1 deletion array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@

from ..common._helpers import * # noqa: F403

__array_api_version__ = '2022.12'
__array_api_version__ = '2023.12'
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this
requires making backwards incompatible changes in many cases, so this will
take some time.

Currently all libraries here are implemented against the [2022.12
version](https://data-apis.org/array-api/2022.12/) of the standard.
Currently all libraries here are implemented against the [2023.12
version](https://data-apis.org/array-api/2023.12/) of the standard.

## Installation

Expand Down
14 changes: 10 additions & 4 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ._helpers import import_, all_libraries, wrapped_libraries

@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
def test_array_namespace(library, api_version, use_compat):
xp = import_(library)
Expand Down Expand Up @@ -94,14 +94,20 @@ def test_array_namespace_errors_torch():
def test_api_version():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2022.12") == torch_
assert array_namespace(x, api_version="2023.12") == torch_
assert array_namespace(x, api_version=None) == torch_
assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)
assert len(w) == 1
assert "2021.12" in str(w[0].message)

# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2022.12") == torch_
assert len(w) == 1
assert "2022.12" in str(w[0].message)

pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))

Expand Down

0 comments on commit d7b4111

Please sign in to comment.