Skip to content

Commit

Permalink
Merge pull request #3 from lucascolley/atleast_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Sep 24, 2024
2 parents 7415c53 + 89b9c8c commit 241f566
Show file tree
Hide file tree
Showing 12 changed files with 2,045 additions and 249 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ jobs:
with:
pixi-version: v0.30.0
cache: true
- name: Run Pylint
run: pixi run -e lint pylint
- name: Run Pylint & Mypy
run: |
pixi run -e lint pylint
pixi run -e lint mypy
checks:
name: Check ${{ matrix.environment }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/generated/

# PyBuilder
.pybuilder/
Expand Down
9 changes: 0 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ repos:
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.11.1"
hooks:
- id: mypy
files: src|tests
args: []
additional_dependencies:
- pytest

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
hooks:
Expand Down
10 changes: 10 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# API Reference

```{eval-rst}
.. currentmodule:: array_api_extra
.. autosummary::
:nosignatures:
:toctree: generated
atleast_nd
```
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
extensions = [
"myst_parser",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
Expand Down
8 changes: 1 addition & 7 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@
```{toctree}
:maxdepth: 2
:hidden:
api-reference.md
```

```{include} ../README.md
:start-after: <!-- SPHINX-START -->
```

## Indices and tables

- {ref}`genindex`
- {ref}`modindex`
- {ref}`search`
2,090 changes: 1,876 additions & 214 deletions pixi.lock

Large diffs are not rendered by default.

39 changes: 23 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ dependencies = []
test = [
"pytest >=6",
"pytest-cov >=3",
]
dev = [
"pytest >=6",
"pytest-cov >=3",
"pylint",
"array-api-strict",
"numpy",
]
docs = [
"sphinx>=7.0",
Expand Down Expand Up @@ -68,29 +65,30 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }

[tool.pixi.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }

[tool.pixi.feature.lint.dependencies]
pre-commit = "*"
mypy = "*"
pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"

[tool.pixi.feature.lint.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
mypy = { cmd = "mypy", cwd = "." }
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
lint = { depends-on = ["pre-commit", "pylint"] }
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }

[tool.pixi.feature.test.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
array-api-strict = "*"
numpy = "*"

[tool.pixi.feature.test.tasks]
test = { cmd = "pytest" }
test-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" }

[tool.pixi.feature.dev.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
pylint = "*"

[tool.pixi.feature.docs.dependencies]
sphinx = ">=7.0"
furo = ">=2023.08.17"
Expand All @@ -100,6 +98,15 @@ myst_parser = ">=0.13"
sphinx_copybutton = "*"
sphinx_autodoc_typehints = "*"

[tool.pixi.feature.docs.tasks]
docs = { cmd = ["sphinx-build", ".", "build/"], cwd = "docs" }

[tool.pixi.feature.dev.dependencies]
ipython = "*"

[tool.pixi.feature.dev.tasks]
ipython = { cmd = "ipython" }

[tool.pixi.feature.py309.dependencies]
python = "~=3.9.0"

Expand All @@ -109,9 +116,9 @@ python = "~=3.12.0"
[tool.pixi.environments]
default = { solve-group = "default" }
lint = { features = ["lint"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
test = { features = ["test"], solve-group = "default" }
dev = { features = ["dev", "docs"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
dev = { features = ["lint", "test", "docs", "dev"], solve-group = "default" }
ci-py309 = ["py309", "test"]
ci-py312 = ["py312", "test"]

Expand Down
4 changes: 3 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd

__version__ = "0.1.dev0"

__all__ = ["__version__"]
__all__ = ["__version__", "atleast_nd"]
48 changes: 48 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.
Parameters
----------
x : array
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace
The standard-compatible namespace for `x`.
Returns
-------
res : array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.
Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)
>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True
"""
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x
9 changes: 9 additions & 0 deletions src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from types import ModuleType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)

__all__ = ["Array", "ModuleType"]
69 changes: 69 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

# array-api-strict#6
import array_api_strict as xp # type: ignore[import-untyped]
from numpy.testing import assert_array_equal

from array_api_extra import atleast_nd


class TestAtLeastND:
def test_0D(self):
x = xp.asarray(1)

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, xp.ones((1,)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1)))

def test_1D(self):
x = xp.asarray([0, 1])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert_array_equal(y, xp.asarray([[0, 1]]))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))

def test_2D(self):
x = xp.asarray([[3]])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=3, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))

def test_5D(self):
x = xp.ones((1, 1, 1, 1, 1))

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=4, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=6, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))

y = atleast_nd(x, ndim=9, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))

0 comments on commit 241f566

Please sign in to comment.