Skip to content

Commit

Permalink
Update indexing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Jan 19, 2025
1 parent 52b8052 commit 5ba2f41
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import dpnp
from dpnp.dpnp_array import dpnp_array

from .helper import get_all_dtypes, get_integer_dtypes, has_support_aspect64
from .helper import (
get_all_dtypes,
get_array,
get_integer_dtypes,
has_support_aspect64,
)
from .third_party.cupy import testing


Expand Down Expand Up @@ -441,16 +446,15 @@ class TestPut:
)
@pytest.mark.parametrize("ind_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize(
"vals",
"ivals",
[0, [1, 2], (2, 2), dpnp.array([1, 2])],
ids=["0", "[1, 2]", "(2, 2)", "dpnp.array([1,2])"],
)
@pytest.mark.parametrize("mode", ["clip", "wrap"])
def test_input_1d(self, a_dt, indices, ind_dt, vals, mode):
def test_input_1d(self, a_dt, indices, ind_dt, ivals, mode):
a = numpy.array([-2, -1, 0, 1, 2], dtype=a_dt)
b = numpy.copy(a)
ia = dpnp.array(a)
ib = dpnp.array(b)
b, vals = numpy.copy(a), get_array(numpy, ivals)
ia, ib = dpnp.array(a), dpnp.array(b)

ind = numpy.array(indices, dtype=ind_dt)
if ind_dt == dpnp.bool and ind.all():
Expand All @@ -459,18 +463,18 @@ def test_input_1d(self, a_dt, indices, ind_dt, vals, mode):

if numpy.can_cast(ind_dt, numpy.intp, casting="safe"):
numpy.put(a, ind, vals, mode=mode)
dpnp.put(ia, iind, vals, mode=mode)
dpnp.put(ia, iind, ivals, mode=mode)
assert_array_equal(ia, a)

b.put(ind, vals, mode=mode)
ib.put(iind, vals, mode=mode)
ib.put(iind, ivals, mode=mode)
assert_array_equal(ib, b)
else:
assert_raises(TypeError, numpy.put, a, ind, vals, mode=mode)
assert_raises(TypeError, dpnp.put, ia, iind, vals, mode=mode)
assert_raises(TypeError, dpnp.put, ia, iind, ivals, mode=mode)

assert_raises(TypeError, b.put, ind, vals, mode=mode)
assert_raises(TypeError, ib.put, iind, vals, mode=mode)
assert_raises(TypeError, ib.put, iind, ivals, mode=mode)

@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize(
Expand Down Expand Up @@ -637,7 +641,7 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
ia, iind = dpnp.array(a), dpnp.array(ind)

for axis in range(ndim):
numpy.put_along_axis(a, ind, values, axis)
numpy.put_along_axis(a, ind, get_array(numpy, values), axis)
dpnp.put_along_axis(ia, iind, values, axis)
assert_array_equal(ia, a)

Expand Down

0 comments on commit 5ba2f41

Please sign in to comment.