From 6699efb3817399b709f51c1aea47f82fa72a536b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 12 Dec 2024 16:05:14 +0000 Subject: [PATCH] BUG: fix device compat (#63) --- src/array_api_extra/_funcs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 75fa3b1..bbbc59c 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,6 +1,6 @@ import warnings -from ._lib import _utils +from ._lib import _compat, _utils from ._lib._compat import array_namespace from ._lib._typing import Array, ModuleType @@ -200,7 +200,7 @@ def create_diagonal( err_msg = "`x` must be 1-dimensional." raise ValueError(err_msg) n = x.shape[0] + abs(offset) - diag = xp.zeros(n**2, dtype=x.dtype, device=x.device) + diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x)) i = offset if offset >= 0 else abs(offset) * n diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x return xp.reshape(diag, (n, n)) @@ -540,6 +540,6 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: y = xp.pi * xp.where( xp.astype(x, xp.bool), x, - xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), + xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y