Skip to content

Commit

Permalink
MAINT: atleast_nd: tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 20, 2024
1 parent 7ce7e2a commit ca64ab2
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show
pylint = "*"

[tool.pixi.feature.lint.tasks]
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
pylint = { cmd = "pylint", cwd = "src" }
lint = { depends-on = ["pre-commit", "pylint"] }

[tool.pixi.feature.test.dependencies]
Expand Down Expand Up @@ -108,6 +108,9 @@ myst_parser = ">=0.13"
sphinx_copybutton = "*"
sphinx_autodoc_typehints = "*"

[tool.pixi.feature.docs.tasks]
docs = { cmd = "sphinx-build . build/", cwd = "docs" }

[tool.pixi.feature.py309.dependencies]
python = "~=3.9.0"

Expand Down
9 changes: 3 additions & 6 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,13 @@

def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to have at least `ndim`.
Recursively expand the dimension of an array to at least `ndim`.
Parameters
----------
x: array
An array.
ndim: int
The minimum number of dimensions for the result.
xp: array_namespace
The array namespace for `x`.
Expand All @@ -28,9 +25,9 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
res: array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.
"""
x = xp.asarray(x)
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
Expand Down

0 comments on commit ca64ab2

Please sign in to comment.