diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index a371768..38266ac 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -140,6 +140,14 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: return xp.squeeze(c, axis=axes) +def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: + n = x.shape[0] + abs(offset) + diag = xp.zeros(n**2, dtype=x.dtype) + i = offset if offset >= 0 else abs(offset) * n + diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x + return xp.reshape(diag, (n, n)) + + def _mean( x: Array, /,