diff --git a/pixi.lock b/pixi.lock index ad4b53f..26383e3 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1087,7 +1087,7 @@ packages: name: array-api-extra version: 0.1.dev0 path: . - sha256: 3015369d39b509279df91f3537e0e3ddb8cce9065ce1361d4ba7877fdbe46642 + sha256: b82b0f121a8e80e8a7d82ea44bbd2adde87e5b8299f80203eea65f4cb43914ca requires_dist: - array-api-strict ; extra == 'dev' - numpy ; extra == 'dev' diff --git a/pyproject.toml b/pyproject.toml index 368c822..5df5b52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show pylint = "*" [tool.pixi.feature.lint.tasks] -pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" } +pylint = { cmd = "pylint", cwd = "src" } lint = { depends-on = ["pre-commit", "pylint"] } [tool.pixi.feature.test.dependencies] @@ -108,6 +108,9 @@ myst_parser = ">=0.13" sphinx_copybutton = "*" sphinx_autodoc_typehints = "*" +[tool.pixi.feature.docs.tasks] +docs = { cmd = "sphinx-build . build/", cwd = "docs" } + [tool.pixi.feature.py309.dependencies] python = "~=3.9.0" diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 45d381f..72689dd 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -10,27 +10,24 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array: """ - Recursively expand the dimension of an array to have at least `ndim`. + Recursively expand the dimension of an array to at least `ndim`. Parameters ---------- x: array - An array. - ndim: int The minimum number of dimensions for the result. - xp: array_namespace - The array namespace for `x`. + The standard-compatible namespace for `x`. Returns ------- res: array An array with ``res.ndim`` >= `ndim`. If ``x.ndim`` >= `ndim`, `x` is returned. - If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`. + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes + until ``res.ndim`` equals `ndim`. """ - x = xp.asarray(x) if x.ndim < ndim: x = xp.expand_dims(x, axis=0) x = atleast_nd(x, ndim=ndim, xp=xp)