Skip to content

Commit

Permalink
match lazy movement ops in uop [pr] (tinygrad#7944)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin authored Nov 28, 2024
1 parent f545085 commit f39e9b4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 2 additions & 2 deletions test/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,10 +1885,10 @@ def test_contiguous_view_simplify(self):
def test_late_fusion_post_permute_simpler(self):
base = ShapeTracker.from_shape((32, 16, 1))
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
r = start.view(start.st.expand((32, 16, 16))).r(Ops.ADD, (2,))
r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.view(add.st.permute((1, 0, 2))).contiguous()
to_store = add.permute((1, 0, 2)).contiguous()
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
self.assertIs(to_store.src[0].op, Ops.VIEW)
Expand Down
7 changes: 6 additions & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,12 @@ def view(self, new_st:ShapeTracker) -> UOp:
return UOp.const_with_shape(self.dtype, 0, new_st.shape)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg))
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).expand(arg))
def permute(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).stride(arg))

# *** uop Buffer stuff ***

Expand Down

0 comments on commit f39e9b4

Please sign in to comment.