Skip to content

Commit

Permalink
more test examples to merge views [pr] (tinygrad#8277)
Browse files Browse the repository at this point in the history
these have masks in self and masks in the merged views
  • Loading branch information
chenyuxyz authored Dec 17, 2024
1 parent 6e2e56c commit 3195bd0
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions test/unit/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,32 @@ def test_with_mask_2(self):
self.assertIsNotNone(v)
self.assertEqual(v, View(shape=(3, 3, 2, 2), strides=(27, 9, 3, 1), offset=3, mask=None, contiguous=False))

def test_with_mask_3(self):
# from test/test_ops.py::TestOps::test_pad_reflect_mode
# has a mask in the final view
v0 = View(shape=(3, 3, 4, 4), strides=(27, 9, 3, 1), offset=-5, mask=((0, 3), (0, 3), (2, 4), (0, 2)), contiguous=False)
v1 = View(shape=(3, 3, 4, 2), strides=(48, 16, 4, 1), offset=0, mask=None, contiguous=False)
v = v0 + v1
self.assertIsNotNone(v)
self.assertEqual(v, View(shape=(3, 3, 4, 2), strides=(27, 9, 3, 1), offset=-5, mask=((0, 3), (0, 3), (2, 4), (0, 2)), contiguous=False))

def test_with_mask_4(self):
# from test/test_ops.py::TestOps::test_pad_reflect_mode
# has a mask in the final view
v0 = View(shape=(3, 3, 5, 3), strides=(27, 9, -3, 1), offset=6, mask=((0, 3), (0, 3), (0, 2), (1, 3)), contiguous=False)
v1 = View(shape=(3, 3, 3, 3), strides=(45, 15, 3, 1), offset=6, mask=None, contiguous=False)
v = v0 + v1
self.assertIsNotNone(v)
self.assertEqual(v, View(shape=(3, 3, 3, 3), strides=(0, 0, 0, 0), offset=0, mask=((0, 0), (0, 0), (0, 0), (0, 0)), contiguous=False))

def test_with_mask_5(self):
# from test/test_ops.py::TestOps::test_pad_reflect_mode
# has a mask in the final view
v0 = View(shape=(1, 1, 6, 5), strides=(0, 0, 5, 1), offset=-5, mask=((0, 1), (0, 1), (1, 6), (0, 5)), contiguous=False)
v1 = View(shape=(1, 1, 6, 3), strides=(0, 0, 5, -1), offset=3, mask=None, contiguous=False)
v = v0 + v1
self.assertIsNotNone(v)
self.assertEqual(v, View(shape=(1, 1, 6, 3), strides=(0, 0, 5, -1), offset=-2, mask=((0, 1), (0, 1), (1, 6), (0, 3)), contiguous=False))

if __name__ == '__main__':
unittest.main()

0 comments on commit 3195bd0

Please sign in to comment.