From f39e9b428844c87a07c20a73331767273ba0841f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:03:43 -0500 Subject: [PATCH] match lazy movement ops in uop [pr] (#7944) --- test/test_schedule.py | 4 ++-- tinygrad/ops.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index e0ef4afc26dc7..a7ad17ea9ca3d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1885,10 +1885,10 @@ def test_contiguous_view_simplify(self): def test_late_fusion_post_permute_simpler(self): base = ShapeTracker.from_shape((32, 16, 1)) start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop())) - r = start.view(start.st.expand((32, 16, 16))).r(Ops.ADD, (2,)) + r = start.expand((32, 16, 16)).r(Ops.ADD, (2,)) add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1)) self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1))) - to_store = add.view(add.st.permute((1, 0, 2))).contiguous() + to_store = add.permute((1, 0, 2)).contiguous() self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1))) self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2))) self.assertIs(to_store.src[0].op, Ops.VIEW) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e091371b287d0..1e87926f0c520 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -362,7 +362,12 @@ def view(self, new_st:ShapeTracker) -> UOp: return UOp.const_with_shape(self.dtype, 0, new_st.shape) if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base return UOp(Ops.VIEW, self.dtype, (self.base,), new_st) - def reshape(self, arg:Tuple[sint, ...]) -> UOp: return self.view(unwrap(self.st).reshape(arg)) + def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg)) + def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg)) + def expand(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).expand(arg)) + def permute(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).permute(arg)) + def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).shrink(arg)) + def stride(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).stride(arg)) # *** uop Buffer stuff ***