From 7baaa4af0b151f617ca08d3bf3b8e70fae8f4e6c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 Nov 2024 19:52:07 +0000 Subject: [PATCH] TST: sinc: add tests --- src/array_api_extra/_funcs.py | 7 +++++-- tests/test_funcs.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 05a8d17..5e72d3a 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -367,9 +367,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: Parameters ---------- - x : array + x : array of floats Array (possibly multi-dimensional) of values for which to calculate - ``sinc(x)``. + ``sinc(x)``. Should have a floating point dtype. Returns ------- @@ -423,5 +423,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: -3.89817183e-17], dtype=array_api_strict.float64) """ + if not xp.isdtype(x.dtype, "real floating"): + err_msg = "`x` must have a real floating data type." + raise ValueError(err_msg) y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x) return xp.sin(y) / y diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 556add1..b37d32b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -9,7 +9,7 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal, assert_equal -from array_api_extra import atleast_nd, cov, expand_dims, kron +from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc if TYPE_CHECKING: Array = Any # To be changed to a Protocol later (see array-api#589) @@ -224,3 +224,16 @@ def test_positive_negative_repeated(self): a = xp.empty((2, 3, 4, 5)) with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(3, -3), xp=xp) + + +class TestSinc: + def test_simple(self): + assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + w = sinc(xp.linspace(-1, 1, 100), xp=xp) + # check symmetry + assert_allclose(w, xp.flip(w, axis=0)) + + @pytest.mark.parametrize("x", [0, 1 + 3j]) + def test_dtype(self, x): + with pytest.raises(ValueError, match="real floating data type"): + sinc(xp.asarray(x), xp=xp)