Skip to content

Commit

Permalink
TST: sinc: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Nov 14, 2024
1 parent 28bff59 commit 7baaa4a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
Parameters
----------
x : array
x : array of floats
Array (possibly multi-dimensional) of values for which to calculate
``sinc(x)``.
``sinc(x)``. Should have a floating point dtype.
Returns
-------
Expand Down Expand Up @@ -423,5 +423,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
-3.89817183e-17], dtype=array_api_strict.float64)
"""
if not xp.isdtype(x.dtype, "real floating"):
err_msg = "`x` must have a real floating data type."
raise ValueError(err_msg)
y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x)
return xp.sin(y) / y
15 changes: 14 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from array_api_extra import atleast_nd, cov, expand_dims, kron
from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
Expand Down Expand Up @@ -224,3 +224,16 @@ def test_positive_negative_repeated(self):
a = xp.empty((2, 3, 4, 5))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3), xp=xp)


class TestSinc:
def test_simple(self):
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
w = sinc(xp.linspace(-1, 1, 100), xp=xp)
# check symmetry
assert_allclose(w, xp.flip(w, axis=0))

@pytest.mark.parametrize("x", [0, 1 + 3j])
def test_dtype(self, x):
with pytest.raises(ValueError, match="real floating data type"):
sinc(xp.asarray(x), xp=xp)

0 comments on commit 7baaa4a

Please sign in to comment.