Skip to content

Commit

Permalink
more comments and tests to reshape [pr] (tinygrad#8228)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz authored Dec 13, 2024
1 parent 6d83a96 commit e371a23
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
10 changes: 10 additions & 0 deletions test/unit/test_view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.view import View, merge_dims
# from tinygrad.shape.shapetracker import ShapeTracker

class TestView(unittest.TestCase):
def test_canonicalize_empty_mask(self):
Expand Down Expand Up @@ -60,5 +61,14 @@ def test_pad_reshape(self):
# permute 0 / 1
self.assertEqual(merge_dims((3, 2), (1, 0), ((0, 2), (1, 2))), ((3, 1, 3), (2, 0, 0)))

def test_different_1_pad(self):
# st = ShapeTracker.from_shape((2, 2, 1)).pad(((0, 0), (0, 0), (0, 1)))
# print(f"{st.views[-1]}")
self.assertEqual(merge_dims((2, 2, 2), (2, 1, 0), ((0, 2), (0, 2), (0, 1))), ((4, 1, 4), (2, 0, 0)))

# st = ShapeTracker.from_shape((2, 1, 1)).pad(((0, 0), (0, 1), (0, 1)))
# print(f"{st.views[-1]}")
self.assertEqual(merge_dims((2, 2, 2), (1, 0, 0), ((0, 2), (0, 2), (0, 1))), ((2, 1, 2), (4, 0, 0)))

if __name__ == '__main__':
unittest.main()
19 changes: 10 additions & 9 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:

@functools.lru_cache(maxsize=None)
def merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o expand (zero stride)), ...]
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
if not shape: return ()
assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
Expand Down Expand Up @@ -319,19 +319,20 @@ def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
# all dimensions matched, return the new view directly
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)

strides, r_new_shape = [], reversed(new_shape)
for merged_dim, new_stride, real_dim in reversed(merge_dims(self.shape, self.strides, self.mask)):
r_strides, r_new_shape = [], reversed(new_shape)
for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)):
# TODO: write with get_contraction
acc = 1
# TODO: third resolve shouldn't be needed
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
strides.append(new_stride)
while resolve(acc <= merged_size) and resolve(acc != merged_size) and resolve((new_dim := next(r_new_shape, 0)) > 0):
r_strides.append(new_stride * acc)
acc = acc * new_dim
# TODO: likely a bug, what if expand happened before acc < real_dim happens?
if resolve(new_dim != 1): new_stride *= (new_dim if resolve(acc < real_dim) else 0)
if resolve(acc != merged_dim): return None
# merge dim merges if (1) previous_stride = stride * dim, (2) dim = 1, stride = 0, either padded or not
if not resolve(acc < real_size): new_stride = 0
if resolve(acc != merged_size): return None

if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
new_strides = (0,) * (len(new_shape) - len(strides)) + tuple(strides[::-1])
new_strides = (0,) * (len(new_shape) - len(r_strides)) + tuple(r_strides[::-1])
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
Expand Down

0 comments on commit e371a23

Please sign in to comment.