Skip to content

Commit

Permalink
TST: atleast_nd: full tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Sep 20, 2024
1 parent 550331d commit 7ce7e2a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
4 changes: 3 additions & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ test = [
"pytest >=6",
"pytest-cov >=3",
"array-api-strict",
"numpy",
]
dev = [
"pytest >=6",
"pytest-cov >=3",
"array-api-strict",
"numpy",
"pylint",
]
docs = [
Expand Down Expand Up @@ -84,6 +86,7 @@ lint = { depends-on = ["pre-commit", "pylint"] }
pytest = ">=6"
pytest-cov = ">=3"
array-api-strict = "*"
numpy = "*"

[tool.pixi.feature.test.tasks]
test = { cmd = "pytest" }
Expand All @@ -92,6 +95,8 @@ test-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --duratio
[tool.pixi.feature.dev.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
array-api-strict = "*"
numpy = "*"
pylint = "*"

[tool.pixi.feature.docs.dependencies]
Expand Down Expand Up @@ -155,6 +160,7 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "array_api_extra.*"
Expand Down
62 changes: 59 additions & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,68 @@
from __future__ import annotations

import array_api_strict as xp # type: ignore[import-not-found]
import array_api_strict as xp
from numpy.testing import assert_array_equal

from array_api_extra import atleast_nd


class TestAtLeastND:
def test_1d_to_2d(self):
def test_0D(self):
x = xp.asarray(1)

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, xp.ones((1,)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1)))

def test_1D(self):
x = xp.asarray([0, 1])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert_array_equal(y, xp.asarray([[0, 1]]))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))

def test_2D(self):
x = xp.asarray([[3]])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert y.ndim == 2
assert_array_equal(y, x)

y = atleast_nd(x, ndim=3, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))

def test_5D(self):
x = xp.ones((1, 1, 1, 1, 1))

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=4, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=6, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))

y = atleast_nd(x, ndim=9, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))

0 comments on commit 7ce7e2a

Please sign in to comment.