diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index 82a3050b1fed..23ad967d07f9 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -1062,9 +1062,9 @@ def test_getitem_scalars(self): numpy_testing_assert_equal_helper(a[0, one], a[zero, 1]) # indexing by a scalar should slice (not copy) - self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one])) - self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)])) - self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)])) + numpy_testing_assert_equal_helper(a[0, 1], a[zero, one]) + numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int32)]) + numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)]) # scalar indexed with scalar r = Tensor.randn() @@ -1105,6 +1105,20 @@ def test_setitem_scalars(self): np.testing.assert_allclose(9.9, r, rtol=1e-7) ''' + def test_getitem_casted_scalars_folding(self): + Tensor.manual_seed(0) + # cast of const is just another const, don't need extra kernels for this + a = Tensor.randn(2, 3) + one = Tensor(1, dtype=dtypes.int64) + self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)])) + self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)])) + + def test_getitem_scalars_simple_folding(self): + a = Tensor.randn(2, 3) + zero = Tensor(0, dtype=dtypes.int64) + one = Tensor(1, dtype=dtypes.int64) + self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one])) + def test_basic_advanced_combined(self): # From the NumPy indexing example x = Tensor.arange(0, 12).reshape(4, 3)