From e0aeb2e9f48c4a37094b1c3fb715c94b02993dc0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 15 Dec 2024 03:29:51 +0200 Subject: [PATCH] unwrap existing buffer from assign src [pr] (#8252) * unwrap existing buffer from assign src [pr] * this is a upat --- tinygrad/engine/schedule.py | 5 ++--- tinygrad/ops.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 910005d326061..28ca361f5136d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -68,9 +68,8 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:Dict[UOp, UOp]) -> UOp: op = buf.src[1].replace(src=tuple(to_uop(x, ctx, cache) for x in buf.src[1].src)) # ASSIGN uses the target buffer, otherwise we create a new buffer else: - src = tuple(to_uop(x, ctx, cache) for x in buf.srcs) - buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) - op = UOp(buf.op, dtype.base, src, buf.arg) + buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) + op = buf.replace(dtype=dtype.base, src=tuple(to_uop(x, ctx, cache) for x in buf.srcs)) # track the underlying tensor uop for this op ctx.tensor_uops[buf_uop] = [buf] # (early) bufferize diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c9305494f89ac..73c980a1e1146 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -509,8 +509,8 @@ def _device(self) -> Optional[str]: @property def buf_uop(self) -> UOp: if self.op is Ops.BUFFER: return self - assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}" - return self.src[0] + assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}" + return self.src[0].buf_uop @property def buffer(self) -> Buffer: if self.base.realized is not None: return self.base.realized