From 175434548b5dc09230f5e7e6ffdec259b24b5176 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 9 Dec 2024 18:00:58 +0000 Subject: [PATCH] BUG: fix sinc for torch --- src/array_api_extra/_funcs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 7a9ba40..3d961e2 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -543,6 +543,8 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: raise ValueError(err_msg) # no scalars in `where` - array-api#807 y = xp.pi * xp.where( - x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device) + xp.astype(x, xp.bool), + x, + xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), ) return xp.sin(y) / y