Skip to content

Commit

Permalink
split scalar getitem tests into correctness and optimization [pr] (ti…
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin authored Dec 10, 2024
1 parent 7436ebe commit 6d33da0
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions test/imported/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,9 +1062,9 @@ def test_getitem_scalars(self):
numpy_testing_assert_equal_helper(a[0, one], a[zero, 1])

# indexing by a scalar should slice (not copy)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int32)])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])

# scalar indexed with scalar
r = Tensor.randn()
Expand Down Expand Up @@ -1105,6 +1105,20 @@ def test_setitem_scalars(self):
np.testing.assert_allclose(9.9, r, rtol=1e-7)
'''

def test_getitem_casted_scalars_folding(self):
Tensor.manual_seed(0)
# cast of const is just another const, don't need extra kernels for this
a = Tensor.randn(2, 3)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))

def test_getitem_scalars_simple_folding(self):
a = Tensor.randn(2, 3)
zero = Tensor(0, dtype=dtypes.int64)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))

def test_basic_advanced_combined(self):
# From the NumPy indexing example
x = Tensor.arange(0, 12).reshape(4, 3)
Expand Down

0 comments on commit 6d33da0

Please sign in to comment.