From 5ba2f418195132aeb43d99cb9c5a386a50c2e1f7 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Sun, 19 Jan 2025 13:32:01 +0100 Subject: [PATCH] Update indexing tests --- dpnp/tests/test_indexing.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 595151008eb..20a43a09ffa 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -17,7 +17,12 @@ import dpnp from dpnp.dpnp_array import dpnp_array -from .helper import get_all_dtypes, get_integer_dtypes, has_support_aspect64 +from .helper import ( + get_all_dtypes, + get_array, + get_integer_dtypes, + has_support_aspect64, +) from .third_party.cupy import testing @@ -441,16 +446,15 @@ class TestPut: ) @pytest.mark.parametrize("ind_dt", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( - "vals", + "ivals", [0, [1, 2], (2, 2), dpnp.array([1, 2])], ids=["0", "[1, 2]", "(2, 2)", "dpnp.array([1,2])"], ) @pytest.mark.parametrize("mode", ["clip", "wrap"]) - def test_input_1d(self, a_dt, indices, ind_dt, vals, mode): + def test_input_1d(self, a_dt, indices, ind_dt, ivals, mode): a = numpy.array([-2, -1, 0, 1, 2], dtype=a_dt) - b = numpy.copy(a) - ia = dpnp.array(a) - ib = dpnp.array(b) + b, vals = numpy.copy(a), get_array(numpy, ivals) + ia, ib = dpnp.array(a), dpnp.array(b) ind = numpy.array(indices, dtype=ind_dt) if ind_dt == dpnp.bool and ind.all(): @@ -459,18 +463,18 @@ def test_input_1d(self, a_dt, indices, ind_dt, vals, mode): if numpy.can_cast(ind_dt, numpy.intp, casting="safe"): numpy.put(a, ind, vals, mode=mode) - dpnp.put(ia, iind, vals, mode=mode) + dpnp.put(ia, iind, ivals, mode=mode) assert_array_equal(ia, a) b.put(ind, vals, mode=mode) - ib.put(iind, vals, mode=mode) + ib.put(iind, ivals, mode=mode) assert_array_equal(ib, b) else: assert_raises(TypeError, numpy.put, a, ind, vals, mode=mode) - assert_raises(TypeError, dpnp.put, ia, iind, vals, mode=mode) + assert_raises(TypeError, dpnp.put, ia, iind, ivals, mode=mode) assert_raises(TypeError, b.put, ind, vals, mode=mode) - assert_raises(TypeError, ib.put, iind, vals, mode=mode) + assert_raises(TypeError, ib.put, iind, ivals, mode=mode) @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( @@ -637,7 +641,7 @@ def test_values(self, arr_dt, idx_dt, ndim, values): ia, iind = dpnp.array(a), dpnp.array(ind) for axis in range(ndim): - numpy.put_along_axis(a, ind, values, axis) + numpy.put_along_axis(a, ind, get_array(numpy, values), axis) dpnp.put_along_axis(ia, iind, values, axis) assert_array_equal(ia, a)