Skip to content

Commit

Permalink
ENH: add atleast_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 20, 2024
1 parent e765ddf commit 30d7dec
Show file tree
Hide file tree
Showing 6 changed files with 848 additions and 3 deletions.
784 changes: 783 additions & 1 deletion pixi.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = []
dependencies = [
"array-api-compat",
]

[project.optional-dependencies]
test = [
"pytest >=6",
"pytest-cov >=3",
"array-api-strict",
]
dev = [
"pytest >=6",
"pytest-cov >=3",
"array-api-strict",
"pylint",
]
docs = [
Expand Down Expand Up @@ -83,6 +87,7 @@ lint = { depends-on = ["pre-commit", "pylint"] }
[tool.pixi.feature.test.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
array-api-strict = "*"

[tool.pixi.feature.test.tasks]
test = { cmd = "pytest" }
Expand Down
4 changes: 3 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd

__version__ = "0.1.dev0"

__all__ = ["__version__"]
__all__ = ["__version__", "atleast_nd"]
35 changes: 35 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from array_api_compat import array_namespace # type: ignore[import-not-found]

if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to have at least `ndim`.
Parameters
----------
x: array
An array.
Returns
-------
res: array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`.
"""
xp = array_namespace(x) if xp is None else xp

x = xp.asarray(x)
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x
9 changes: 9 additions & 0 deletions src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from types import ModuleType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)

__all__ = ["Array", "ModuleType"]
12 changes: 12 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import array_api_strict as xp # type: ignore[import-not-found]

from array_api_extra import atleast_nd


class TestAtLeastND:
def test_1d_to_2d(self):
x = xp.asarray([0, 1])
y = atleast_nd(x, ndim=2, xp=xp)
assert y.ndim == 2

0 comments on commit 30d7dec

Please sign in to comment.