diff --git a/docs/api-reference.md b/docs/api-reference.md index a459743..6eddb76 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -8,6 +8,7 @@ atleast_nd cov + create_diagonal expand_dims kron ``` diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 3126420..c07a365 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from ._typing import Array, ModuleType -__all__ = ["atleast_nd", "cov", "expand_dims", "kron"] +__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron"] def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: @@ -141,6 +141,42 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: + """ + Construct a diagonal array. + + Parameters + ---------- + x : array + A 1-D array + offset : int, optional + Offset from the leading diagonal (default is ``0``). + Use positive ints for diagonals above the leading diagonal, + and negative ints for diagonals below the leading diagonal. + + Returns + ------- + res : array + A 2-D array with `x` on the diagonal (offset by `offset`). + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([2, 4, 8]) + + >>> xpx.create_diagonal(x, xp=xp) + Array([[2, 0, 0], + [0, 4, 0], + [0, 0, 8]], dtype=array_api_strict.int64) + + >>> xpx.create_diagonal(x, offset=-2, xp=xp) + Array([[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [2, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) + + """ if x.ndim != 1: err_msg = "`x` must be 1-dimensional." raise ValueError(err_msg)