Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DEV: add numpydoc to pre-commit #67

Merged
merged 2 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ repos:
hooks:
- id: check-dependabot
- id: check-github-workflows

- repo: https://github.com/numpy/numpydoc
rev: "v1.8.0"
hooks:
- id: numpydoc-validation
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Sphinx config."""

import importlib.metadata
from typing import Any

Expand Down
229 changes: 174 additions & 55 deletions pixi.lock

Large diffs are not rendered by default.

26 changes: 20 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pre-commit = "*"
pylint = "*"
basedmypy = "*"
basedpyright = "*"
numpydoc = ">=1.8.0,<2"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"
Expand Down Expand Up @@ -145,13 +146,9 @@ ci-py313 = ["py313", "tests"]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = [
"error",
]
filterwarnings = ["error"]
log_cli_level = "INFO"
testpaths = [
"tests",
]
testpaths = ["tests"]


# Coverage
Expand Down Expand Up @@ -262,3 +259,20 @@ messages_control.disable = [
"missing-function-docstring",
"wrong-import-position",
]


# numpydoc

[tool.numpydoc_validation]
checks = [
"all", # report on all checks, except the below
"EX01",
"SA01",
"ES01",
]
exclude = [ # don't report on objects that match any of these regex
'.*test_funcs.*',
'.*test_utils.*',
'.*test_version.*',
'.*test_vendor.*',
]
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc

__version__ = "0.4.1.dev0"
Expand Down
45 changes: 20 additions & 25 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Public API Functions."""

import warnings

from ._lib import _compat, _utils
Expand All @@ -22,14 +24,15 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
Parameters
----------
x : array
Input array.
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
res : array
array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
Expand All @@ -47,7 +50,6 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True

"""
if xp is None:
xp = array_namespace(x)
Expand Down Expand Up @@ -77,11 +79,11 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
Each row of `m` represents a variable, and each column a single
observation of all those variables.
xp : array_namespace, optional
The standard-compatible namespace for `m`. Default: infer
The standard-compatible namespace for `m`. Default: infer.

Returns
-------
res : array
array
The covariance matrix of the variables.

Examples
Expand All @@ -104,7 +106,6 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
Array([[ 1., -1.],
[-1., 1.]], dtype=array_api_strict.float64)


Note that element :math:`C_{0,1}`, which shows the correlation between
:math:`x_0` and :math:`x_1`, is negative.

Expand All @@ -122,7 +123,6 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:

>>> xpx.cov(y, xp=xp)
Array(2.14413333, dtype=array_api_strict.float64)

"""
if xp is None:
xp = array_namespace(m)
Expand Down Expand Up @@ -161,17 +161,17 @@ def create_diagonal(
Parameters
----------
x : array
A 1-D array
A 1-D array.
offset : int, optional
Offset from the leading diagonal (default is ``0``).
Use positive ints for diagonals above the leading diagonal,
and negative ints for diagonals below the leading diagonal.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
res : array
array
A 2-D array with `x` on the diagonal (offset by `offset`).

Examples
Expand All @@ -191,7 +191,6 @@ def create_diagonal(
[2, 0, 0, 0, 0],
[0, 4, 0, 0, 0],
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)

"""
if xp is None:
xp = array_namespace(x)
Expand Down Expand Up @@ -221,18 +220,19 @@ def expand_dims(
Parameters
----------
a : array
Array to have its shape expanded.
axis : int or tuple of ints, optional
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
If multiple positions are provided, they should be unique (note that a position
given by a positive index could also be referred to by a negative index -
that will also result in an error).
Default: ``(0,)``.
xp : array_namespace, optional
The standard-compatible namespace for `a`. Default: infer
The standard-compatible namespace for `a`. Default: infer.

Returns
-------
res : array
array
`a` with an expanded shape.

Examples
Expand Down Expand Up @@ -270,7 +270,6 @@ def expand_dims(
>>> y
Array([[[1],
[2]]], dtype=array_api_strict.int64)

"""
if xp is None:
xp = array_namespace(a)
Expand Down Expand Up @@ -304,12 +303,13 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
Parameters
----------
a, b : array
Input arrays.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
res : array
array
The Kronecker product of `a` and `b`.

Notes
Expand All @@ -333,7 +333,6 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]


Examples
--------
>>> import array_api_strict as xp
Expand All @@ -352,7 +351,6 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
[0., 0., 1., 1.],
[0., 0., 1., 1.]], dtype=array_api_strict.float64)


>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
>>> c = xpx.kron(a, b, xp=xp)
Expand All @@ -365,7 +363,6 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
>>> c[K] == a[I]*b[J]
Array(True, dtype=array_api_strict.bool)

"""
if xp is None:
xp = array_namespace(a, b)
Expand Down Expand Up @@ -424,11 +421,11 @@ def setdiff1d(
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer
The standard-compatible namespace for `x1` and `x2`. Default: infer.

Returns
-------
res : array
array
1D array of values in `x1` that are not in `x2`. The result
is sorted when `assume_unique` is ``False``, but otherwise only sorted
if the input is sorted.
Expand All @@ -442,7 +439,6 @@ def setdiff1d(
>>> x2 = xp.asarray([3, 4, 5, 6])
>>> xpx.setdiff1d(x1, x2, xp=xp)
Array([1, 2], dtype=array_api_strict.int64)

"""
if xp is None:
xp = array_namespace(x1, x2)
Expand Down Expand Up @@ -476,11 +472,11 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Array (possibly multi-dimensional) of values for which to calculate
``sinc(x)``. Must have a real floating point dtype.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
res : array
array
``sinc(x)`` calculated elementwise, which has the same shape as the input.

Notes
Expand Down Expand Up @@ -528,7 +524,6 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
-5.84680802e-02, -8.90384387e-02,
-8.40918587e-02, -4.92362781e-02,
-3.89817183e-17], dtype=array_api_strict.float64)

"""
if xp is None:
xp = array_namespace(x)
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Modules housing private functions."""
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Acquire helpers from array-api-compat."""
# Allow packages that vendor both `array-api-extra` and
# `array-api-compat` to override the import location

Expand Down
10 changes: 6 additions & 4 deletions src/array_api_extra/_lib/_compat.pyi
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Static type stubs for `_compat.py`."""

from types import ModuleType

from ._typing import Array, Device

# pylint: disable=missing-class-docstring,unused-argument

class ArrayModule(ModuleType):
def device(self, x: Array, /) -> Device: ...
class ArrayModule(ModuleType): # numpydoc ignore=GL08
def device(self, x: Array, /) -> Device: ... # numpydoc ignore=GL08

def array_namespace(
*xs: Array,
api_version: str | None = None,
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
) -> ArrayModule: ... # numpydoc ignore=GL08
def device(x: Array, /) -> Device: ... # numpydoc ignore=GL08
2 changes: 2 additions & 0 deletions src/array_api_extra/_lib/_typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Static typing helpers."""

from types import ModuleType
from typing import Any

Expand Down
10 changes: 6 additions & 4 deletions src/array_api_extra/_lib/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utility functions used by `array_api_extra/_funcs.py`."""

from . import _compat
from ._typing import Array, ModuleType

Expand All @@ -12,9 +14,9 @@ def in1d(
assume_unique: bool = False,
invert: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""Checks whether each element of an array is also present in a
second array.
) -> Array: # numpydoc ignore=PR01,RT01
"""
Check whether each element of an array is also present in a second array.

Returns a boolean array the same length as `x1` that is True
where an element of `x1` is in `x2` and False otherwise.
Expand Down Expand Up @@ -68,7 +70,7 @@ def mean(
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
xp: ModuleType | None = None,
) -> Array:
) -> Array: # numpydoc ignore=PR01,RT01
"""
Complex mean, https://github.com/data-apis/array-api/issues/846.
"""
Expand Down
2 changes: 1 addition & 1 deletion vendor_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# Allow for relative imports in test_vendor.py
"""Allow for relative imports in `test_vendor.py`."""
5 changes: 3 additions & 2 deletions vendor_tests/_array_api_compat_vendor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# This file is a hook imported by src/array_api_extra/_lib/_compat.py
"""This file is a hook imported by `src/array_api_extra/_lib/_compat.py`."""

from .array_api_compat import * # noqa: F403
from .array_api_compat import array_namespace as array_namespace_compat


# Let unit tests check with `is` that we are picking up the function from this module
# and not from the original array_api_compat module.
def array_namespace(*xs, **kwargs):
def array_namespace(*xs, **kwargs): # numpydoc ignore=GL08
return array_namespace_compat(*xs, **kwargs)
Loading