diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 8f54e700d000..44adb7538b01 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -67,25 +67,23 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
     buf.buffer.dtype = dtype
     buf.buffer.options = None
   # base is a VIEW of (BUFFER, (optional) op)
-  if buf.is_realized:
-    # TODO: this is the same underlying Buffer in all schedules
-    buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
-    op = None
+  # TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
+  if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
   # ASSIGN uses the target buffer, otherwise we create a new buffer
   else:
     src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
     buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
     op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
-  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
-  # keep track of ops outside the big graph
-  buffers[buf_uop] = buf.buffer
-  if op is not None:
+    ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
+    # keep track of scheduled ops
     buf.buffer.ref(1)
     ctx.lazybufs[buf_uop] = buf
     ctx.allbufs[buf_uop] = ret
     if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
     for x in op.src:
       if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+  cache[buf] = ret
+  buffers[ret.buf_uop] = buf.buffer
   return ret
 
 # **** AST graph rewrite
@@ -425,7 +423,7 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
   # everything else is a VIEW of BUFFER that either realizes or fuses
   (UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
   # just load realized buffers
-  (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
+  (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
 ])
 
 @track_rewrites(named=True)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 1effd76f8745..aacc7b7f66ac 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -356,11 +356,11 @@ def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and al
   @property
   def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
   def view(self, new_st:ShapeTracker) -> UOp:
-    assert self.st is not None and self.base.st is not None, f"must have shape {self}"
+    if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
     ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
     # instant folding rules
     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
-    if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
+    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return ret
   def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
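
The tinygrad/ops.py half of this diff is what lets a bare BUFFER UOp (whose st is None) be wrapped by .view(), which the new UOp.new_buffer(...).view(buf.st) path in to_uop relies on; the second folding rule now compares against self.base.shape instead of self.base.st.shape. Below is a minimal sketch of the resulting behavior, assuming a tinygrad checkout at this revision; the device string and the 4x4 shape are illustrative, not taken from the PR.

# Sketch only: exercises the updated UOp.view()/reshape() folding rules from this diff.
# The device string "CLANG" and the shapes are assumptions for illustration.
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.shape.shapetracker import ShapeTracker

# A realized buffer is now modeled as a VIEW wrapping a BUFFER, mirroring the new
# UOp.new_buffer(...).view(buf.st) path in to_uop. Since a bare BUFFER has no
# ShapeTracker, view() takes the new early-return branch instead of asserting.
base = UOp.new_buffer("CLANG", 16, dtypes.float).view(ShapeTracker.from_shape((4, 4)))
assert base.op is Ops.VIEW and base.src[0].op is Ops.BUFFER

# Instant folding: a contiguous view with the base's shape folds back to the base...
assert base.reshape((4, 4)) is base
# ...while a genuinely different shape still produces a fresh VIEW node.
assert base.reshape((16,)) is not base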