Skip to content

Commit

Permalink
TST: setdiff1d: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Nov 29, 2024
1 parent f835502 commit fadf701
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/array_api_extra/_lib/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def in1d(
order = xp.argsort(ar, stable=True)
reverse_order = xp.argsort(order, stable=True)
sar = xp.take(ar, order, axis=0)
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
if sar.size >= 1:
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
else:
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
ret = xp.take(flag, reverse_order, axis=0)

Expand Down
38 changes: 37 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
from array_api_extra import (
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
setdiff1d,
sinc,
)

if typing.TYPE_CHECKING:
from array_api_extra._lib._typing import Array
Expand Down Expand Up @@ -263,6 +271,34 @@ def test_positive_negative_repeated(self):
expand_dims(a, axis=(3, -3), xp=xp)


class TestSetDiff1D:
def test_setdiff1d(self):
x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5])

expected = xp.asarray([6, 7])
actual = setdiff1d(x1, x2, xp=xp)
assert_array_equal(actual, expected)

x1 = xp.arange(21)
x2 = xp.arange(19)
expected = xp.asarray([19, 20])
actual = setdiff1d(x1, x2, xp=xp)
assert_array_equal(actual, expected)

assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0), xp=xp), xp.empty(0))
x1 = xp.empty(0, dtype=xp.uint32)
x2 = x1
assert_equal(setdiff1d(x1, x2, xp=xp).dtype, xp.uint32)

def test_setdiff1d_unique(self):
x1 = xp.asarray([3, 2, 1])
x2 = xp.asarray([7, 5, 2])
expected = xp.asarray([3, 1])
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
assert_array_equal(actual, expected)


class TestSinc:
def test_simple(self):
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
Expand Down

0 comments on commit fadf701

Please sign in to comment.