Skip to content

Commit

Permalink
BUG: fix device compat (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Dec 12, 2024
1 parent 7ff0d0a commit 6699efb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

from ._lib import _utils
from ._lib import _compat, _utils
from ._lib._compat import array_namespace
from ._lib._typing import Array, ModuleType

Expand Down Expand Up @@ -200,7 +200,7 @@ def create_diagonal(
err_msg = "`x` must be 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype, device=x.device)
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
i = offset if offset >= 0 else abs(offset) * n
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
return xp.reshape(diag, (n, n))
Expand Down Expand Up @@ -540,6 +540,6 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
y = xp.pi * xp.where(
xp.astype(x, xp.bool),
x,
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y

0 comments on commit 6699efb

Please sign in to comment.