Skip to content

Commit

Permalink
Add dependency to array-api-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Dec 6, 2024
1 parent 4360355 commit b06cbcc
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 263 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/test-vendor.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: Test vendoring support

on:
workflow_dispatch:
pull_request:
push:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

env:
# Many color libraries just need this to be set to any value, but at least
# one distinguishes color depth, where "3" -> "256-bit color".
FORCE_COLOR: 3

jobs:
pre-commit-and-lint:
name: Format
runs-on: ubuntu-latest
steps:
- name: Checkout array-api-extra
uses: actions/checkout@v4
with:
path: array-api-extra

- name: Checkout array-api-compat
uses: actions/checkout@v4
with:
repository: data-apis/array-api-compat
path: array-api-compat

- name: Vendor array-api-extra into test package
run: |
cp -a array-api-compat/array_api_compat array-api-extra/vendor_tests/
cp -a array-api-extra/src/array_api_extra array-api-extra/vendor_tests/
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.x"

- name: Install Pixi
uses: prefix-dev/[email protected]
with:
pixi-version: v0.37.0
cache: true

- name: Test package
run: pixi run tests-vendor
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ ENV/
env.bak/
venv.bak/

# Spyder project settings
# IDE project settings
.idea/
.spyderproject
.spyproject
.vscode/

# Rope project settings
.ropeproject
Expand Down Expand Up @@ -160,3 +162,7 @@ Thumbs.db
# pixi environments
.pixi
*.egg-info

# Vendor tests
vendor_tests/array_api_compat/
vendor_tests/array_api_extra/
47 changes: 44 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,57 @@ If you require stability, it is recommended to pin `array-api-extra` to
a specific version, or vendor the library inside your own.
```

```{note}
This library depends on array-api-compat. We aim for compatibility with
the latest released version of array-api-compat, and your mileage may vary
with older or dev versions.
```

(vendoring)=

## Vendoring

To vendor the library, clone
[the repository](https://github.com/data-apis/array-api-extra) and copy it into
the appropriate place in your library, like:
[the array-api-extra repository](https://github.com/data-apis/array-api-extra)
and copy it into the appropriate place in your library, like:

```
cp -R array-api-extra/ mylib/vendored/array_api_extra
cp -a array-api-extra/src/array_api_extra mylib/vendored/
```

`array-api-extra` depends on `array-api-compat`. You may either add a dependency
in your own project to `array-api-compat` or vendor it too:

1. Clone
[the array-api-compat repository](https://github.com/data-apis/array-api-compat)
and copy it next to your vendored array-api-extra:

```
cp -a array-api-compat/array_api_compat mylib/vendored/
```

2. Create a new hook file which array-api-extra will use instead of the
top-level `array-api-compat` if present:

```
echo 'from mylib.vendored.array_api_compat import *' > mylib/vendored/_array_api_compat_vendor.py
```

This also allows overriding `array-api-compat` functions if you so wish. E.g.
your `mylib/vendored/_array_api_compat_vendor.py` could look like this:

```python
from mylib.vendored.array_api_compat import *
from mylib.vendored.array_api_compat import array_namespace as _array_namespace_orig


def array_namespace(*xs, **kwargs):
import mylib

if any(isinstance(x, mylib.MyArray) for x in xs):
return mylib
else:
return _array_namespace_orig(*xs, **kwargs)
```

(usage)=
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = []
dependencies = ["array-api-compat"]

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -96,6 +96,7 @@ numpy = "*"
[tool.pixi.feature.tests.tasks]
tests = { cmd = "pytest" }
tests-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" }
tests-vendor = { cmd = "pytest vendor_tests" }

[tool.pixi.feature.docs.dependencies]
sphinx = ">=7.0"
Expand Down
69 changes: 48 additions & 21 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._lib._typing import Array, ModuleType

from ._lib import _utils
from ._lib._compat import array_namespace

__all__ = [
"atleast_nd",
Expand All @@ -19,7 +20,7 @@
]


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.
Expand All @@ -28,8 +29,8 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
x : array
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace
The standard-compatible namespace for `x`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
Returns
-------
Expand All @@ -53,13 +54,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
True
"""
if xp is None:
xp = array_namespace(x)

if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x


def cov(m: Array, /, *, xp: ModuleType) -> Array:
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix.
Expand All @@ -77,8 +81,8 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
A 1-D or 2-D array containing multiple variables and observations.
Each row of `m` represents a variable, and each column a single
observation of all those variables.
xp : array_namespace
The standard-compatible namespace for `m`.
xp : array_namespace, optional
The standard-compatible namespace for `m`. Default: infer
Returns
-------
Expand Down Expand Up @@ -125,6 +129,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
Array(2.14413333, dtype=array_api_strict.float64)
"""
if xp is None:
xp = array_namespace(m)

m = xp.asarray(m, copy=True)
dtype = (
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
Expand All @@ -150,7 +157,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
return xp.squeeze(c, axis=axes)


def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
def create_diagonal(
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
) -> Array:
"""
Construct a diagonal array.
Expand All @@ -162,8 +171,8 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
Offset from the leading diagonal (default is ``0``).
Use positive ints for diagonals above the leading diagonal,
and negative ints for diagonals below the leading diagonal.
xp : array_namespace
The standard-compatible namespace for `x`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
Returns
-------
Expand All @@ -189,6 +198,9 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(x)

if x.ndim != 1:
err_msg = "`x` must be 1-dimensional."
raise ValueError(err_msg)
Expand All @@ -200,7 +212,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
"""
Expand the shape of an array.
Expand All @@ -220,8 +232,8 @@ def expand_dims(
given by a positive index could also be referred to by a negative index -
that will also result in an error).
Default: ``(0,)``.
xp : array_namespace
The standard-compatible namespace for `a`.
xp : array_namespace, optional
The standard-compatible namespace for `a`. Default: infer
Returns
-------
Expand Down Expand Up @@ -265,6 +277,9 @@ def expand_dims(
[2]]], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(a)

if not isinstance(axis, tuple):
axis = (axis,)
ndim = a.ndim + len(axis)
Expand All @@ -282,7 +297,7 @@ def expand_dims(
return a


def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Kronecker product of two arrays.
Expand All @@ -294,8 +309,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
Parameters
----------
a, b : array
xp : array_namespace
The standard-compatible namespace for `a` and `b`.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer
Returns
-------
Expand Down Expand Up @@ -357,6 +372,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
Array(True, dtype=array_api_strict.bool)
"""
if xp is None:
xp = array_namespace(a, b)

b = xp.asarray(b)
singletons = (1,) * (b.ndim - a.ndim)
Expand Down Expand Up @@ -390,7 +407,12 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:


def setdiff1d(
x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
x1: Array,
x2: Array,
/,
*,
assume_unique: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Find the set difference of two arrays.
Expand All @@ -406,8 +428,8 @@ def setdiff1d(
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace
The standard-compatible namespace for `x1` and `x2`.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer
Returns
-------
Expand All @@ -427,6 +449,8 @@ def setdiff1d(
Array([1, 2], dtype=array_api_strict.int64)
"""
if xp is None:
xp = array_namespace(x1, x2)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
Expand All @@ -436,7 +460,7 @@ def setdiff1d(
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType) -> Array:
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
r"""
Return the normalized sinc function.
Expand All @@ -456,8 +480,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
x : array
Array (possibly multi-dimensional) of values for which to calculate
``sinc(x)``. Must have a real floating point dtype.
xp : array_namespace
The standard-compatible namespace for `x`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
Returns
-------
Expand Down Expand Up @@ -511,6 +535,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
-3.89817183e-17], dtype=array_api_strict.float64)
"""
if xp is None:
xp = array_namespace(x)

if not xp.isdtype(x.dtype, "real floating"):
err_msg = "`x` must have a real floating data type."
raise ValueError(err_msg)
Expand Down
Loading

0 comments on commit b06cbcc

Please sign in to comment.