Skip to content

Commit

Permalink
TST: add device tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Nov 30, 2024
1 parent ff5c6c2 commit 38105ba
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
err_msg = "`x` must be 1-dimensional."
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype)
diag = xp.zeros(n**2, dtype=x.dtype, device=x.device)
i = offset if offset >= 0 else abs(offset) * n
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
return xp.reshape(diag, (n, n))
Expand Down Expand Up @@ -516,6 +516,6 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
raise ValueError(err_msg)
# no scalars in `where` - array-api#807
y = xp.pi * xp.where(
x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
)
return xp.sin(y) / y
37 changes: 37 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def test_5D(self):
y = atleast_nd(x, ndim=9, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))

def test_device(self):
device = xp.Device("device1")
x = xp.asarray([1, 2, 3], device=device)
assert atleast_nd(x, ndim=2, xp=xp).device == device


class TestCov:
def test_basic(self):
Expand Down Expand Up @@ -120,6 +125,11 @@ def test_combination(self):
assert_allclose(cov(x, xp=xp), xp.asarray(11.71))
assert_allclose(cov(y, xp=xp), xp.asarray(2.144133), rtol=1e-6)

def test_device(self):
device = xp.Device("device1")
x = xp.asarray([1, 2, 3], device=device)
assert cov(x, xp=xp).device == device


class TestCreateDiagonal:
def test_1d(self):
Expand Down Expand Up @@ -156,6 +166,11 @@ def test_2d(self):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray([[1]]), xp=xp)

def test_device(self):
device = xp.Device("device1")
x = xp.asarray([1, 2, 3], device=device)
assert create_diagonal(x, xp=xp).device == device


class TestExpandDims:
def test_functionality(self):
Expand Down Expand Up @@ -205,6 +220,11 @@ def test_positive_negative_repeated(self):
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3), xp=xp)

def test_device(self):
device = xp.Device("device1")
x = xp.asarray([1, 2, 3], device=device)
assert expand_dims(x, axis=0, xp=xp).device == device


class TestKron:
def test_basic(self):
Expand Down Expand Up @@ -270,6 +290,12 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
k = kron(a, b, xp=xp)
assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")

def test_device(self):
device = xp.Device("device1")
x1 = xp.asarray([1, 2, 3], device=device)
x2 = xp.asarray([4, 5], device=device)
assert kron(x1, x2, xp=xp).device == device


class TestSetDiff1D:
def test_setdiff1d(self):
Expand Down Expand Up @@ -298,6 +324,12 @@ def test_assume_unique(self):
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
assert_array_equal(actual, expected)

def test_device(self):
device = xp.Device("device1")
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
assert setdiff1d(x1, x2, xp=xp).device == device


class TestSinc:
def test_simple(self):
Expand All @@ -316,3 +348,8 @@ def test_3d(self):
expected = xp.zeros((3, 3, 2))
expected[0, 0, 0] = 1.0
assert_allclose(sinc(x, xp=xp), expected, atol=1e-15)

def test_device(self):
device = xp.Device("device1")
x = xp.asarray(0.0, device=device)
assert sinc(x, xp=xp).device == device
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ def test_no_invert_assume_unique(self, x2: Array):
expected = xp.asarray([True, True, False])
actual = in1d(x1, x2, xp=xp)
assert_array_equal(actual, expected)

def test_device(self):
device = xp.Device("device1")
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
assert in1d(x1, x2, xp=xp).device == device

0 comments on commit 38105ba

Please sign in to comment.