Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add create_diagonal #19

Merged
merged 6 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

atleast_nd
cov
create_diagonal
expand_dims
kron
sinc
Expand Down
12 changes: 10 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

from ._funcs import atleast_nd, cov, expand_dims, kron, sinc
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc

__version__ = "0.1.2.dev0"

__all__ = ["__version__", "atleast_nd", "cov", "expand_dims", "kron", "sinc"]
__all__ = [
"__version__",
"atleast_nd",
"cov",
"create_diagonal",
"expand_dims",
"kron",
"sinc",
]
51 changes: 50 additions & 1 deletion src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd", "cov", "expand_dims", "kron", "sinc"]
__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
Expand Down Expand Up @@ -140,6 +140,55 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
return xp.squeeze(c, axis=axes)


def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
"""
Construct a diagonal array.

Parameters
----------
x : array
A 1-D array
offset : int, optional
Offset from the leading diagonal (default is ``0``).
Use positive ints for diagonals above the leading diagonal,
and negative ints for diagonals below the leading diagonal.
xp : array_namespace
The standard-compatible namespace for `x`.

Returns
-------
res : array
A 2-D array with `x` on the diagonal (offset by `offset`).

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([2, 4, 8])

>>> xpx.create_diagonal(x, xp=xp)
Array([[2, 0, 0],
[0, 4, 0],
[0, 0, 8]], dtype=array_api_strict.int64)

>>> xpx.create_diagonal(x, offset=-2, xp=xp)
Array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[2, 0, 0, 0, 0],
[0, 4, 0, 0, 0],
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)

"""
if x.ndim != 1:
err_msg = "`x` must be 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype)
i = offset if offset >= 0 else abs(offset) * n
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
return xp.reshape(diag, (n, n))


def _mean(
x: Array,
/,
Expand Down
39 changes: 38 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

# array-api-strict#6
import array_api_strict as xp # type: ignore[import-untyped]
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
Expand Down Expand Up @@ -112,6 +113,42 @@ def test_combination(self):
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)


class TestCreateDiagonal:
def test_1d(self):
lucascolley marked this conversation as resolved.
Show resolved Hide resolved
# from np.diag tests
vals = 100 * xp.arange(5, dtype=xp.float64)
b = xp.zeros((5, 5))
for k in range(5):
b[k, k] = vals[k]
assert_array_equal(create_diagonal(vals, xp=xp), b)
b = xp.zeros((7, 7))
c = xp.asarray(b, copy=True)
for k in range(5):
b[k, k + 2] = vals[k]
c[k + 2, k] = vals[k]
assert_array_equal(create_diagonal(vals, offset=2, xp=xp), b)
assert_array_equal(create_diagonal(vals, offset=-2, xp=xp), c)

@pytest.mark.parametrize("n", range(1, 10))
@pytest.mark.parametrize("offset", range(1, 10))
def test_create_diagonal(self, n, offset):
# from scipy._lib tests
rng = np.random.default_rng(2347823)
one = xp.asarray(1.0)
x = rng.random(n)
A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset, xp=xp)
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
assert_array_equal(A, B)

def test_0d(self):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray(1), xp=xp)

def test_2d(self):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray([[1]]), xp=xp)


class TestKron:
def test_basic(self):
# Using 0-dimensional array
Expand Down