diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index bb9c9ca..c19ecb7 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -193,7 +193,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array: err_msg = "`x` must be 1-dimensional." raise ValueError(err_msg) n = x.shape[0] + abs(offset) - diag = xp.zeros(n**2, dtype=x.dtype) + diag = xp.zeros(n**2, dtype=x.dtype, device=x.device) i = offset if offset >= 0 else abs(offset) * n diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x return xp.reshape(diag, (n, n)) @@ -516,6 +516,6 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array: raise ValueError(err_msg) # no scalars in `where` - array-api#807 y = xp.pi * xp.where( - x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype) + x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device) ) return xp.sin(y) / y diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 4599740..c5303db 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -85,6 +85,11 @@ def test_5D(self): y = atleast_nd(x, ndim=9, xp=xp) assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) + def test_device(self): + device = xp.Device("device1") + x = xp.asarray([1, 2, 3], device=device) + assert atleast_nd(x, ndim=2, xp=xp).device == device + class TestCov: def test_basic(self): @@ -120,6 +125,11 @@ def test_combination(self): assert_allclose(cov(x, xp=xp), xp.asarray(11.71)) assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6) + def test_device(self): + device = xp.Device("device1") + x = xp.asarray([1, 2, 3], device=device) + assert cov(x, xp=xp).device == device + class TestCreateDiagonal: def test_1d(self): @@ -156,6 +166,11 @@ def test_2d(self): with pytest.raises(ValueError, match="1-dimensional"): create_diagonal(xp.asarray([[1]]), xp=xp) + def test_device(self): + device = xp.Device("device1") + x = xp.asarray([1, 2, 3], device=device) + assert create_diagonal(x, xp=xp).device == device + class TestExpandDims: def test_functionality(self): @@ -205,6 +220,11 @@ def test_positive_negative_repeated(self): with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(3, -3), xp=xp) + def test_device(self): + device = xp.Device("device1") + x = xp.asarray([1, 2, 3], device=device) + assert expand_dims(x, axis=0, xp=xp).device == device + class TestKron: def test_basic(self): @@ -270,6 +290,12 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]): k = kron(a, b, xp=xp) assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron") + def test_device(self): + device = xp.Device("device1") + x1 = xp.asarray([1, 2, 3], device=device) + x2 = xp.asarray([4, 5], device=device) + assert kron(x1, x2, xp=xp).device == device + class TestSetDiff1D: def test_setdiff1d(self): @@ -298,6 +324,12 @@ def test_assume_unique(self): actual = setdiff1d(x1, x2, assume_unique=True, xp=xp) assert_array_equal(actual, expected) + def test_device(self): + device = xp.Device("device1") + x1 = xp.asarray([3, 8, 20], device=device) + x2 = xp.asarray([2, 3, 4], device=device) + assert setdiff1d(x1, x2, xp=xp).device == device + class TestSinc: def test_simple(self): @@ -316,3 +348,8 @@ def test_3d(self): expected = xp.zeros((3, 3, 2)) expected[0, 0, 0] = 1.0 assert_allclose(sinc(x, xp=xp), expected, atol=1e-15) + + def test_device(self): + device = xp.Device("device1") + x = xp.asarray(0.0, device=device) + assert sinc(x, xp=xp).device == device diff --git a/tests/test_utils.py b/tests/test_utils.py index a34ec56..797b9a6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -22,3 +22,9 @@ def test_no_invert_assume_unique(self, x2: Array): expected = xp.asarray([True, True, False]) actual = in1d(x1, x2, xp=xp) assert_array_equal(actual, expected) + + def test_device(self): + device = xp.Device("device1") + x1 = xp.asarray([3, 8, 20], device=device) + x2 = xp.asarray([2, 3, 4], device=device) + assert in1d(x1, x2, xp=xp).device == device