Skip to content

Commit

Permalink
API: remove dependency on array-api-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 20, 2024
1 parent 30d7dec commit ab8085d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 43 deletions.
36 changes: 1 addition & 35 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = [
"array-api-compat",
]
dependencies = []

[project.optional-dependencies]
test = [
Expand Down
12 changes: 7 additions & 5 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

from typing import TYPE_CHECKING

from array_api_compat import array_namespace # type: ignore[import-not-found]

if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to have at least `ndim`.
Expand All @@ -19,15 +17,19 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
x: array
An array.
ndim: int
The minimum number of dimensions for the result.
xp: array_namespace
The array namespace for `x`.
Returns
-------
res: array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, ``res.ndim`` will equal `ndim`.
"""
xp = array_namespace(x) if xp is None else xp

x = xp.asarray(x)
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
Expand Down

0 comments on commit ab8085d

Please sign in to comment.