Skip to content

Commit

Permalink
TST: create_diagonal: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Nov 14, 2024
1 parent 57e26a8 commit 0a53e57
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd, cov, expand_dims, kron
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron

__version__ = "0.1.2.dev0"

__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron"]
__all__ = ["__version__", "atleast_nd", "cov", "create_diagonal", "expand_dims", "kron"]
3 changes: 3 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:


def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
if x.ndim != 1:
err_msg = "`x` must be 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype)
i = offset if offset >= 0 else abs(offset) * n
Expand Down
26 changes: 25 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from array_api_extra import atleast_nd, cov, expand_dims, kron
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
Expand Down Expand Up @@ -112,6 +112,30 @@ def test_combination(self):
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)


class TestCreateDiagonal:
def test_1d(self):
vals = 100 * xp.arange(5, dtype=xp.float64)
b = xp.zeros((5, 5))
for k in range(5):
b[k, k] = vals[k]
assert_array_equal(create_diagonal(vals, xp=xp), b)
b = xp.zeros((7, 7))
c = xp.asarray(b, copy=True)
for k in range(5):
b[k, k + 2] = vals[k]
c[k + 2, k] = vals[k]
assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b)
assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c)

def test_0d(self):
with pytest.raises(ValueError, match="1-dimensional"):
print(create_diagonal(xp.asarray(1), xp=xp))

def test_2d(self):
with pytest.raises(ValueError, match="1-dimensional"):
print(create_diagonal(xp.asarray([[1]]), xp=xp))


class TestKron:
def test_basic(self):
# Using 0-dimensional array
Expand Down

0 comments on commit 0a53e57

Please sign in to comment.