From 09dc648621974f9b84f0eed3b9f32f9fdbec045d Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 9 Dec 2024 13:40:59 +0000 Subject: [PATCH] MAINT: depend on array-api-compat (#47) * Add dependency to array-api-compat * Code review * revert README --- .github/workflows/test-vendor.yml | 55 +++++++ .gitignore | 4 + docs/index.md | 47 +++++- pixi.lock | 35 ++++- pyproject.toml | 4 +- src/array_api_extra/_funcs.py | 69 ++++++--- src/array_api_extra/_lib/_compat.py | 183 +++-------------------- src/array_api_extra/_lib/_compat.pyi | 13 ++ src/array_api_extra/_lib/_utils.py | 9 +- tests/test_funcs.py | 167 +++++++++++++-------- tests/test_utils.py | 11 +- vendor_tests/__init__.py | 1 + vendor_tests/_array_api_compat_vendor.py | 11 ++ vendor_tests/test_vendor.py | 26 ++++ 14 files changed, 374 insertions(+), 261 deletions(-) create mode 100644 .github/workflows/test-vendor.yml create mode 100644 src/array_api_extra/_lib/_compat.pyi create mode 100644 vendor_tests/__init__.py create mode 100644 vendor_tests/_array_api_compat_vendor.py create mode 100644 vendor_tests/test_vendor.py diff --git a/.github/workflows/test-vendor.yml b/.github/workflows/test-vendor.yml new file mode 100644 index 0000000..20be389 --- /dev/null +++ b/.github/workflows/test-vendor.yml @@ -0,0 +1,55 @@ +name: Test vendoring support + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + # Many color libraries just need this to be set to any value, but at least + # one distinguishes color depth, where "3" -> "256-bit color". + FORCE_COLOR: 3 + +jobs: + pre-commit-and-lint: + name: Format + runs-on: ubuntu-latest + steps: + - name: Checkout array-api-extra + uses: actions/checkout@v4 + with: + path: array-api-extra + + - name: Checkout array-api-compat + uses: actions/checkout@v4 + with: + repository: data-apis/array-api-compat + path: array-api-compat + + - name: Vendor array-api-extra into test package + run: | + cp -a array-api-compat/array_api_compat array-api-extra/vendor_tests/ + cp -a array-api-extra/src/array_api_extra array-api-extra/vendor_tests/ + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install Pixi + uses: prefix-dev/setup-pixi@v0.8.1 + with: + pixi-version: v0.39.0 + manifest-path: array-api-extra/pyproject.toml + cache: true + + - name: Test package + run: | + cd array-api-extra/ + pixi run --environment tests tests-vendor diff --git a/.gitignore b/.gitignore index 8937748..123ce99 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,7 @@ Thumbs.db # pixi environments .pixi *.egg-info + +# Vendor tests +vendor_tests/array_api_compat/ +vendor_tests/array_api_extra/ diff --git a/docs/index.md b/docs/index.md index 11c2916..fece21b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -51,16 +51,57 @@ If you require stability, it is recommended to pin `array-api-extra` to a specific version, or vendor the library inside your own. ``` +```{note} +This library depends on array-api-compat. We aim for compatibility with +the latest released version of array-api-compat, and your mileage may vary +with older or dev versions. +``` + (vendoring)= ## Vendoring To vendor the library, clone -[the repository](https://github.com/data-apis/array-api-extra) and copy it into -the appropriate place in your library, like: +[the array-api-extra repository](https://github.com/data-apis/array-api-extra) +and copy it into the appropriate place in your library, like: ``` -cp -R array-api-extra/ mylib/vendored/array_api_extra +cp -a array-api-extra/src/array_api_extra mylib/vendored/ +``` + +`array-api-extra` depends on `array-api-compat`. You may either add a dependency +in your own project to `array-api-compat` or vendor it too: + +1. Clone + [the array-api-compat repository](https://github.com/data-apis/array-api-compat) + and copy it next to your vendored array-api-extra: + + ``` + cp -a array-api-compat/array_api_compat mylib/vendored/ + ``` + +2. Create a new hook file which array-api-extra will use instead of the + top-level `array-api-compat` if present: + + ``` + echo 'from mylib.vendored.array_api_compat import *' > mylib/vendored/_array_api_compat_vendor.py + ``` + +This also allows overriding `array-api-compat` functions if you so wish. E.g. +your `mylib/vendored/_array_api_compat_vendor.py` could look like this: + +```python +from mylib.vendored.array_api_compat import * +from mylib.vendored.array_api_compat import array_namespace as _array_namespace_orig + + +def array_namespace(*xs, **kwargs): + import mylib + + if any(isinstance(x, mylib.MyArray) for x in xs): + return mylib + else: + return _array_namespace_orig(*xs, **kwargs) ``` (usage)= diff --git a/pixi.lock b/pixi.lock index 277959b..bac0a75 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9,6 +9,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -50,6 +51,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . osx-arm64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -85,6 +87,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . win-64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -132,6 +135,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -173,6 +177,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . osx-arm64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -210,6 +215,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . win-64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -259,6 +265,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda @@ -281,6 +288,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . osx-arm64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda @@ -298,6 +306,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . win-64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda @@ -326,6 +335,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.5-py313h78bf25f_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -452,6 +462,7 @@ environments: - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.5-py313h8f79df9_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -573,6 +584,7 @@ environments: - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.5-py313hfa70ccb_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -702,6 +714,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py313h46c70d0_2.conda @@ -771,6 +784,7 @@ environments: - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py313h3579c5c_2.conda @@ -834,6 +848,7 @@ environments: - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py313h5813708_2.conda @@ -906,6 +921,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.5-py313h78bf25f_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -976,6 +992,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - pypi: . osx-arm64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.5-py313h8f79df9_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -1041,6 +1058,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - pypi: . win-64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.5-py313hfa70ccb_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -1115,6 +1133,7 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -1156,6 +1175,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . osx-arm64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -1193,6 +1213,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda - pypi: . win-64: + - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -1266,11 +1287,23 @@ packages: - pkg:pypi/alabaster?source=hash-mapping size: 18522 timestamp: 1722035895436 +- conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda + sha256: 32689f25dd97965043a5ca8a07ae3a9c27278258a16e574b0705bdca7656feff + md5: f2328337441baa8f669d2a830cfd0097 + depends: + - python >=3.8 + license: MIT + license_family: MIT + purls: + - pkg:pypi/array-api-compat?source=hash-mapping + size: 38213 + timestamp: 1730293860305 - pypi: . name: array-api-extra version: 0.3.3 - sha256: 8f949b727c03da7c3dff8d6ffab9361f273ea2a81a30296f0474707aaad1b227 + sha256: da9c3302c24283e43a78bee49b3a08bb2840f07e9b04159f4d3db3601798df5c requires_dist: + - array-api-compat>=1.1.1 - furo>=2023.8.17 ; extra == 'docs' - myst-parser>=0.13 ; extra == 'docs' - sphinx-autodoc-typehints ; extra == 'docs' diff --git a/pyproject.toml b/pyproject.toml index 286eae9..da67405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = [] +dependencies = ["array-api-compat>=1.1.1"] [project.optional-dependencies] tests = [ @@ -64,6 +64,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"] [tool.pixi.dependencies] python = ">=3.10.15,<3.14" +array-api-compat = ">=1.1.1" [tool.pixi.pypi-dependencies] array-api-extra = { path = ".", editable = true } @@ -96,6 +97,7 @@ numpy = "*" [tool.pixi.feature.tests.tasks] tests = { cmd = "pytest" } tests-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" } +tests-vendor = { cmd = "pytest vendor_tests" } [tool.pixi.feature.docs.dependencies] sphinx = ">=7.0" diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index c19ecb7..7a9ba40 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -7,6 +7,7 @@ from ._lib._typing import Array, ModuleType from ._lib import _utils +from ._lib._compat import array_namespace __all__ = [ "atleast_nd", @@ -19,7 +20,7 @@ ] -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: """ Recursively expand the dimension of an array to at least `ndim`. @@ -28,8 +29,8 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: x : array ndim : int The minimum number of dimensions for the result. - xp : array_namespace - The standard-compatible namespace for `x`. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer Returns ------- @@ -53,13 +54,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: True """ + if xp is None: + xp = array_namespace(x) + if x.ndim < ndim: x = xp.expand_dims(x, axis=0) x = atleast_nd(x, ndim=ndim, xp=xp) return x -def cov(m: Array, /, *, xp: ModuleType) -> Array: +def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: """ Estimate a covariance matrix. @@ -77,8 +81,8 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable, and each column a single observation of all those variables. - xp : array_namespace - The standard-compatible namespace for `m`. + xp : array_namespace, optional + The standard-compatible namespace for `m`. Default: infer Returns ------- @@ -125,6 +129,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: Array(2.14413333, dtype=array_api_strict.float64) """ + if xp is None: + xp = array_namespace(m) + m = xp.asarray(m, copy=True) dtype = ( xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) @@ -150,7 +157,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: return xp.squeeze(c, axis=axes) -def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: +def create_diagonal( + x: Array, /, *, offset: int = 0, xp: ModuleType | None = None +) -> Array: """ Construct a diagonal array. @@ -162,8 +171,8 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: Offset from the leading diagonal (default is ``0``). Use positive ints for diagonals above the leading diagonal, and negative ints for diagonals below the leading diagonal. - xp : array_namespace - The standard-compatible namespace for `x`. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer Returns ------- @@ -189,6 +198,9 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) """ + if xp is None: + xp = array_namespace(x) + if x.ndim != 1: err_msg = "`x` must be 1-dimensional." raise ValueError(err_msg) @@ -200,7 +212,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: def expand_dims( - a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None ) -> Array: """ Expand the shape of an array. @@ -220,8 +232,8 @@ def expand_dims( given by a positive index could also be referred to by a negative index - that will also result in an error). Default: ``(0,)``. - xp : array_namespace - The standard-compatible namespace for `a`. + xp : array_namespace, optional + The standard-compatible namespace for `a`. Default: infer Returns ------- @@ -265,6 +277,9 @@ def expand_dims( [2]]], dtype=array_api_strict.int64) """ + if xp is None: + xp = array_namespace(a) + if not isinstance(axis, tuple): axis = (axis,) ndim = a.ndim + len(axis) @@ -282,7 +297,7 @@ def expand_dims( return a -def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: +def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: """ Kronecker product of two arrays. @@ -294,8 +309,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: Parameters ---------- a, b : array - xp : array_namespace - The standard-compatible namespace for `a` and `b`. + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer Returns ------- @@ -357,6 +372,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: Array(True, dtype=array_api_strict.bool) """ + if xp is None: + xp = array_namespace(a, b) b = xp.asarray(b) singletons = (1,) * (b.ndim - a.ndim) @@ -390,7 +407,12 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array: def setdiff1d( - x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType + x1: Array, + x2: Array, + /, + *, + assume_unique: bool = False, + xp: ModuleType | None = None, ) -> Array: """ Find the set difference of two arrays. @@ -406,8 +428,8 @@ def setdiff1d( assume_unique : bool If ``True``, the input arrays are both assumed to be unique, which can speed up the calculation. Default is ``False``. - xp : array_namespace - The standard-compatible namespace for `x1` and `x2`. + xp : array_namespace, optional + The standard-compatible namespace for `x1` and `x2`. Default: infer Returns ------- @@ -427,6 +449,8 @@ def setdiff1d( Array([1, 2], dtype=array_api_strict.int64) """ + if xp is None: + xp = array_namespace(x1, x2) if assume_unique: x1 = xp.reshape(x1, (-1,)) @@ -436,7 +460,7 @@ def setdiff1d( return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] -def sinc(x: Array, /, *, xp: ModuleType) -> Array: +def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: r""" Return the normalized sinc function. @@ -456,8 +480,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: x : array Array (possibly multi-dimensional) of values for which to calculate ``sinc(x)``. Must have a real floating point dtype. - xp : array_namespace - The standard-compatible namespace for `x`. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer Returns ------- @@ -511,6 +535,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: -3.89817183e-17], dtype=array_api_strict.float64) """ + if xp is None: + xp = array_namespace(x) + if not xp.isdtype(x.dtype, "real floating"): err_msg = "`x` must have a real floating data type." raise ValueError(err_msg) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 0c9e1d4..03e47d1 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -1,168 +1,19 @@ -### Helpers borrowed from array-api-compat - -from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 - -import inspect -import sys -import typing - -from ._typing import override - -if typing.TYPE_CHECKING: - from ._typing import Array, Device - -__all__ = ["device"] - - -# Placeholder object to represent the dask device -# when the array backend is not the CPU. -# (since it is not easy to tell which device a dask array is on) -class _dask_device: # pylint: disable=invalid-name - @override - def __repr__(self) -> str: - return "DASK_DEVICE" - - -_DASK_DEVICE = _dask_device() - - -# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray -# or cupy.ndarray. They are not included in array objects of this library -# because this library just reuses the respective ndarray classes without -# wrapping or subclassing them. These helper functions can be used instead of -# the wrapper functions for libraries that need to support both NumPy/CuPy and -# other libraries that use devices. -def device(x: Array, /) -> Device: - """ - Hardware device the array data resides on. - - This is equivalent to `x.device` according to the `standard - `__. - This helper is included because some array libraries either do not have - the `device` attribute or include it with an incompatible API. - - Parameters - ---------- - x: array - array instance from an array API compatible library. - - Returns - ------- - out: device - a ``device`` object (see the `Device Support `__ - section of the array API specification). - - Notes - ----- - - For NumPy the device is always `"cpu"`. For Dask, the device is always a - special `DASK_DEVICE` object. - - See Also - -------- - - to_device : Move array data to a different device. - - """ - if _is_numpy_array(x): - return "cpu" - if _is_dask_array(x): - # Peek at the metadata of the jax array to determine type - try: - import numpy as np # pylint: disable=import-outside-toplevel - - if isinstance(x._meta, np.ndarray): # pylint: disable=protected-access - # Must be on CPU since backed by numpy - return "cpu" - except ImportError: - pass - return _DASK_DEVICE - if _is_jax_array(x): - # JAX has .device() as a method, but it is being deprecated so that it - # can become a property, in accordance with the standard. In order for - # this function to not break when JAX makes the flip, we check for - # both here. - if inspect.ismethod(x.device): - return x.device() - return x.device - if _is_pydata_sparse_array(x): - # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, "device", None) - if x_device is not None: - return x_device - # Everything but DOK has this attr. - try: - inner = x.data - except AttributeError: - return "cpu" - # Return the device of the constituent array - return device(inner) - return x.device - - -def _is_numpy_array(x: Array) -> bool: - """Return True if `x` is a NumPy array.""" - # Avoid importing NumPy if it isn't already - if "numpy" not in sys.modules: - return False - - import numpy as np # pylint: disable=import-outside-toplevel - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array( - x +# Allow packages that vendor both `array-api-extra` and +# `array-api-compat` to override the import location +from __future__ import annotations + +try: + from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] + array_namespace, # pyright: ignore[reportUnknownVariableType] + device, # pyright: ignore[reportUnknownVariableType] + ) +except ImportError: + from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] + array_namespace, # pyright: ignore[reportUnknownVariableType] + device, ) - -def _is_dask_array(x: Array) -> bool: - """Return True if `x` is a dask.array Array.""" - # Avoid importing dask if it isn't already - if "dask.array" not in sys.modules: - return False - - # pylint: disable=import-error, import-outside-toplevel - import dask.array # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] - - return isinstance(x, dask.array.Array) - - -def _is_jax_zero_gradient_array(x: Array) -> bool: - """Return True if `x` is a zero-gradient array. - - These arrays are a design quirk of Jax that may one day be removed. - See https://github.com/google/jax/issues/20620. - """ - if "numpy" not in sys.modules or "jax" not in sys.modules: - return False - - # pylint: disable=import-error, import-outside-toplevel - import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] - import numpy as np # pylint: disable=import-outside-toplevel - - return isinstance(x, np.ndarray) and x.dtype == jax.float0 # pyright: ignore[reportUnknownVariableType] - - -def _is_jax_array(x: Array) -> bool: - """Return True if `x` is a JAX array.""" - # Avoid importing jax if it isn't already - if "jax" not in sys.modules: - return False - - # pylint: disable=import-error, import-outside-toplevel - import jax # pyright: ignore[reportMissingImports] - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) - - -def _is_pydata_sparse_array(x: Array) -> bool: - """Return True if `x` is an array from the `sparse` package.""" - - # Avoid importing jax if it isn't already - if "sparse" not in sys.modules: - return False - - # pylint: disable=import-error, import-outside-toplevel - import sparse # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] - - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) +__all__ = [ + "array_namespace", + "device", +] diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi new file mode 100644 index 0000000..3b4eb43 --- /dev/null +++ b/src/array_api_extra/_lib/_compat.pyi @@ -0,0 +1,13 @@ +from types import ModuleType + +from ._typing import Array, Device + +class ArrayModule(ModuleType): + def device(self, x: Array, /) -> Device: ... + +def array_namespace( + *xs: Array, + api_version: str | None = None, + use_compat: bool | None = None, +) -> ArrayModule: ... +def device(x: Array, /) -> Device: ... diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py index 33f800b..15e33d6 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils.py @@ -17,7 +17,7 @@ def in1d( *, assume_unique: bool = False, invert: bool = False, - xp: ModuleType, + xp: ModuleType | None = None, ) -> Array: """Checks whether each element of an array is also present in a second array. @@ -29,6 +29,8 @@ def in1d( present in numpy: https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758 """ + if xp is None: + xp = _compat.array_namespace(x1, x2) # This code is run to make the code significantly faster if x2.shape[0] < 10 * x1.shape[0] ** 0.145: @@ -71,11 +73,14 @@ def mean( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - xp: ModuleType, + xp: ModuleType | None = None, ) -> Array: """ Complex mean, https://github.com/data-apis/array-api/issues/846. """ + if xp is None: + xp = _compat.array_namespace(x) + if xp.isdtype(x.dtype, "complex floating"): x_real = xp.real(x) x_imag = xp.imag(x) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index c5303db..488636e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -28,91 +28,96 @@ class TestAtLeastND: def test_0D(self): x = xp.asarray(1) - y = atleast_nd(x, ndim=0, xp=xp) + y = atleast_nd(x, ndim=0) assert_array_equal(y, x) - y = atleast_nd(x, ndim=1, xp=xp) + y = atleast_nd(x, ndim=1) assert_array_equal(y, xp.ones((1,))) - y = atleast_nd(x, ndim=5, xp=xp) + y = atleast_nd(x, ndim=5) assert_array_equal(y, xp.ones((1, 1, 1, 1, 1))) def test_1D(self): x = xp.asarray([0, 1]) - y = atleast_nd(x, ndim=0, xp=xp) + y = atleast_nd(x, ndim=0) assert_array_equal(y, x) - y = atleast_nd(x, ndim=1, xp=xp) + y = atleast_nd(x, ndim=1) assert_array_equal(y, x) - y = atleast_nd(x, ndim=2, xp=xp) + y = atleast_nd(x, ndim=2) assert_array_equal(y, xp.asarray([[0, 1]])) - y = atleast_nd(x, ndim=5, xp=xp) + y = atleast_nd(x, ndim=5) assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2))) def test_2D(self): x = xp.asarray([[3]]) - y = atleast_nd(x, ndim=0, xp=xp) + y = atleast_nd(x, ndim=0) assert_array_equal(y, x) - y = atleast_nd(x, ndim=2, xp=xp) + y = atleast_nd(x, ndim=2) assert_array_equal(y, x) - y = atleast_nd(x, ndim=3, xp=xp) + y = atleast_nd(x, ndim=3) assert_array_equal(y, 3 * xp.ones((1, 1, 1))) - y = atleast_nd(x, ndim=5, xp=xp) + y = atleast_nd(x, ndim=5) assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) def test_5D(self): x = xp.ones((1, 1, 1, 1, 1)) - y = atleast_nd(x, ndim=0, xp=xp) + y = atleast_nd(x, ndim=0) assert_array_equal(y, x) - y = atleast_nd(x, ndim=4, xp=xp) + y = atleast_nd(x, ndim=4) assert_array_equal(y, x) - y = atleast_nd(x, ndim=5, xp=xp) + y = atleast_nd(x, ndim=5) assert_array_equal(y, x) - y = atleast_nd(x, ndim=6, xp=xp) + y = atleast_nd(x, ndim=6) assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) - y = atleast_nd(x, ndim=9, xp=xp) + y = atleast_nd(x, ndim=9) assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) def test_device(self): device = xp.Device("device1") x = xp.asarray([1, 2, 3], device=device) - assert atleast_nd(x, ndim=2, xp=xp).device == device + assert atleast_nd(x, ndim=2).device == device + + def test_xp(self): + x = xp.asarray(1) + y = atleast_nd(x, ndim=0, xp=xp) + assert_array_equal(y, x) class TestCov: def test_basic(self): assert_allclose( - cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp), + cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T), xp.asarray([[1.0, -1.0], [-1.0, 1.0]]), ) def test_complex(self): x = xp.asarray([[1, 2, 3], [1j, 2j, 3j]]) res = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]]) - assert_allclose(cov(x, xp=xp), res) + assert_allclose(cov(x), res) def test_empty(self): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning) - assert_array_equal(cov(xp.asarray([]), xp=xp), xp.nan) + assert_array_equal(cov(xp.asarray([])), xp.nan) assert_array_equal( - cov(xp.reshape(xp.asarray([]), (0, 2)), xp=xp), + cov(xp.reshape(xp.asarray([]), (0, 2))), xp.reshape(xp.asarray([]), (0, 0)), ) assert_array_equal( - cov(xp.reshape(xp.asarray([]), (2, 0)), xp=xp), + cov(xp.reshape(xp.asarray([]), (2, 0))), xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]), ) @@ -121,14 +126,20 @@ def test_combination(self): y = xp.asarray([3, 1.1, 0.12]) X = xp.stack((x, y), axis=0) desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]]) - assert_allclose(cov(X, xp=xp), desired, rtol=1e-6) - assert_allclose(cov(x, xp=xp), xp.asarray(11.71)) - assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6) + assert_allclose(cov(X), desired, rtol=1e-6) + assert_allclose(cov(x), xp.asarray(11.71)) + assert_allclose(cov(y), xp.asarray(2.144133), rtol=1e-6) def test_device(self): device = xp.Device("device1") x = xp.asarray([1, 2, 3], device=device) - assert cov(x, xp=xp).device == device + assert cov(x).device == device + + def test_xp(self): + assert_allclose( + cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp), + xp.asarray([[1.0, -1.0], [-1.0, 1.0]]), + ) class TestCreateDiagonal: @@ -138,14 +149,14 @@ def test_1d(self): b = xp.zeros((5, 5)) for k in range(5): b[k, k] = vals[k] - assert_array_equal(create_diagonal(vals, xp=xp), b) + assert_array_equal(create_diagonal(vals), b) b = xp.zeros((7, 7)) c = xp.asarray(b, copy=True) for k in range(5): b[k, k + 2] = vals[k] c[k + 2, k] = vals[k] - assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b) - assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c) + assert_array_equal(create_diagonal(vals, offset=2), b) + assert_array_equal(create_diagonal(vals, offset=-2), c) @pytest.mark.parametrize("n", range(1, 10)) @pytest.mark.parametrize("offset", range(1, 10)) @@ -154,22 +165,27 @@ def test_create_diagonal(self, n: int, offset: int): rng = np.random.default_rng(2347823) one = xp.asarray(1.0) x = rng.random(n) - A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset, xp=xp) + A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset) B = xp.asarray(np.diag(x, offset), dtype=one.dtype) assert_array_equal(A, B) def test_0d(self): with pytest.raises(ValueError, match="1-dimensional"): - create_diagonal(xp.asarray(1), xp=xp) + create_diagonal(xp.asarray(1)) def test_2d(self): with pytest.raises(ValueError, match="1-dimensional"): - create_diagonal(xp.asarray([[1]]), xp=xp) + create_diagonal(xp.asarray([[1]])) def test_device(self): device = xp.Device("device1") x = xp.asarray([1, 2, 3], device=device) - assert create_diagonal(x, xp=xp).device == device + assert create_diagonal(x).device == device + + def test_xp(self): + x = xp.asarray([1, 2]) + y = create_diagonal(x, xp=xp) + assert_array_equal(y, xp.asarray([[1, 0], [0, 2]])) class TestExpandDims: @@ -184,46 +200,51 @@ def _squeeze_all(b: Array) -> Array: s = (2, 3, 4, 5) a = xp.empty(s) for axis in range(-5, 4): - b = expand_dims(a, axis=axis, xp=xp) + b = expand_dims(a, axis=axis) assert b.shape[axis] == 1 assert _squeeze_all(b).shape == s def test_axis_tuple(self): a = xp.empty((3, 3, 3)) - assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3) - assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1) - assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1) - assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3) + assert expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3) + assert expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1) + assert expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1) + assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) def test_axis_out_of_range(self): s = (2, 3, 4, 5) a = xp.empty(s) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=-6, xp=xp) + expand_dims(a, axis=-6) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=5, xp=xp) + expand_dims(a, axis=5) a = xp.empty((3, 3, 3)) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, -6), xp=xp) + expand_dims(a, axis=(0, -6)) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, 5), xp=xp) + expand_dims(a, axis=(0, 5)) def test_repeated_axis(self): a = xp.empty((3, 3, 3)) with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(1, 1), xp=xp) + expand_dims(a, axis=(1, 1)) def test_positive_negative_repeated(self): # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 a = xp.empty((2, 3, 4, 5)) with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(3, -3), xp=xp) + expand_dims(a, axis=(3, -3)) def test_device(self): device = xp.Device("device1") x = xp.asarray([1, 2, 3], device=device) - assert expand_dims(x, axis=0, xp=xp).device == device + assert expand_dims(x, axis=0).device == device + + def test_xp(self): + x = xp.asarray([1, 2, 3]) + y = expand_dims(x, axis=(0, 1, 2), xp=xp) + assert y.shape == (1, 1, 1, 3) class TestKron: @@ -232,36 +253,36 @@ def test_basic(self): a = xp.asarray(1) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[1, 2], [3, 4]]) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray(1) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) # Using 1-dimensional array a = xp.asarray([3]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[3, 6], [9, 12]]) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray([3]) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) # Using 3-dimensional array a = xp.asarray([[[1]], [[2]]]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray([[[1]], [[2]]]) k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) def test_kron_smoke(self): a = xp.ones((3, 3)) b = xp.ones((3, 3)) k = xp.ones((9, 9)) - assert_array_equal(kron(a, b, xp=xp), k) + assert_array_equal(kron(a, b), k) @pytest.mark.parametrize( ("shape_a", "shape_b"), @@ -287,14 +308,20 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]): int(dim) for dim in xp.multiply(normalised_shape_a, normalised_shape_b) ) - k = kron(a, b, xp=xp) + k = kron(a, b) assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron") def test_device(self): device = xp.Device("device1") x1 = xp.asarray([1, 2, 3], device=device) x2 = xp.asarray([4, 5], device=device) - assert kron(x1, x2, xp=xp).device == device + assert kron(x1, x2).device == device + + def test_xp(self): + a = xp.ones((3, 3)) + b = xp.ones((3, 3)) + k = xp.ones((9, 9)) + assert_array_equal(kron(a, b, xp=xp), k) class TestSetDiff1D: @@ -303,53 +330,63 @@ def test_setdiff1d(self): x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5]) expected = xp.asarray([6, 7]) - actual = setdiff1d(x1, x2, xp=xp) + actual = setdiff1d(x1, x2) assert_array_equal(actual, expected) x1 = xp.arange(21) x2 = xp.arange(19) expected = xp.asarray([19, 20]) - actual = setdiff1d(x1, x2, xp=xp) + actual = setdiff1d(x1, x2) assert_array_equal(actual, expected) - assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0), xp=xp), xp.empty(0)) + assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0)), xp.empty(0)) x1 = xp.empty(0, dtype=xp.uint32) x2 = x1 - assert_equal(setdiff1d(x1, x2, xp=xp).dtype, xp.uint32) + assert_equal(setdiff1d(x1, x2).dtype, xp.uint32) def test_assume_unique(self): x1 = xp.asarray([3, 2, 1]) x2 = xp.asarray([7, 5, 2]) expected = xp.asarray([3, 1]) - actual = setdiff1d(x1, x2, assume_unique=True, xp=xp) + actual = setdiff1d(x1, x2, assume_unique=True) assert_array_equal(actual, expected) def test_device(self): device = xp.Device("device1") x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) - assert setdiff1d(x1, x2, xp=xp).device == device + assert setdiff1d(x1, x2).device == device + + def test_xp(self): + x1 = xp.asarray([3, 8, 20]) + x2 = xp.asarray([2, 3, 4]) + expected = xp.asarray([8, 20]) + actual = setdiff1d(x1, x2, xp=xp) + assert_array_equal(actual, expected) class TestSinc: def test_simple(self): - assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) - w = sinc(xp.linspace(-1, 1, 100), xp=xp) + assert_array_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) + w = sinc(xp.linspace(-1, 1, 100)) # check symmetry assert_allclose(w, xp.flip(w, axis=0)) @pytest.mark.parametrize("x", [0, 1 + 3j]) def test_dtype(self, x: int | complex): with pytest.raises(ValueError, match="real floating data type"): - sinc(xp.asarray(x), xp=xp) + sinc(xp.asarray(x)) def test_3d(self): x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) expected = xp.zeros((3, 3, 2)) expected[0, 0, 0] = 1.0 - assert_allclose(sinc(x, xp=xp), expected, atol=1e-15) + assert_allclose(sinc(x), expected, atol=1e-15) def test_device(self): device = xp.Device("device1") x = xp.asarray(0.0, device=device) - assert sinc(x, xp=xp).device == device + assert sinc(x).device == device + + def test_xp(self): + assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 797b9a6..fbab9c4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,11 +20,18 @@ class TestIn1D: def test_no_invert_assume_unique(self, x2: Array): x1 = xp.asarray([3, 8, 20]) expected = xp.asarray([True, True, False]) - actual = in1d(x1, x2, xp=xp) + actual = in1d(x1, x2) assert_array_equal(actual, expected) def test_device(self): device = xp.Device("device1") x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) - assert in1d(x1, x2, xp=xp).device == device + assert in1d(x1, x2).device == device + + def test_xp(self): + x1 = xp.asarray([1, 6]) + x2 = xp.arange(5) + expected = xp.asarray([True, False]) + actual = in1d(x1, x2, xp=xp) + assert_array_equal(actual, expected) diff --git a/vendor_tests/__init__.py b/vendor_tests/__init__.py new file mode 100644 index 0000000..da33f30 --- /dev/null +++ b/vendor_tests/__init__.py @@ -0,0 +1 @@ +# Allow for relative imports in test_vendor.py diff --git a/vendor_tests/_array_api_compat_vendor.py b/vendor_tests/_array_api_compat_vendor.py new file mode 100644 index 0000000..3156691 --- /dev/null +++ b/vendor_tests/_array_api_compat_vendor.py @@ -0,0 +1,11 @@ +# This file is a hook imported by src/array_api_extra/_lib/_compat.py +from __future__ import annotations + +from .array_api_compat import * # noqa: F403 +from .array_api_compat import array_namespace as array_namespace_compat + + +# Let unit tests check with `is` that we are picking up the function from this module +# and not from the original array_api_compat module. +def array_namespace(*xs, **kwargs): + return array_namespace_compat(*xs, **kwargs) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py new file mode 100644 index 0000000..8b00a37 --- /dev/null +++ b/vendor_tests/test_vendor.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import array_api_strict as xp +from numpy.testing import assert_array_equal + + +def test_vendor_compat(): + from ._array_api_compat_vendor import array_namespace + + x = xp.asarray([1, 2, 3]) + assert array_namespace(x) is xp + + +def test_vendor_extra(): + from .array_api_extra import atleast_nd + + x = xp.asarray(1) + y = atleast_nd(x, ndim=0) + assert_array_equal(y, x) + + +def test_vendor_extra_uses_vendor_compat(): + from ._array_api_compat_vendor import array_namespace as n1 + from .array_api_extra._lib._compat import array_namespace as n2 + + assert n1 is n2