diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e915e621fc3c..f5caaf74b61d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -145,28 +145,23 @@ def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable # *** List[LazyBuffer] lowering to ScheduleItem *** -def _recursive_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r + if buf is not buf.base: + cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, cache).view(buf.st) + return ret dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype - - # buffer ops define ShapeTracker - # if it's realized, it's a load and we add it to the inputs if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs: - if buf.op is MetaOps.CONST: return ubuf.view(buf.st) + if buf.op is MetaOps.CONST: return ubuf if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf) - return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop())) - - # only reduceop changes shape - src: List[UOp] = [] - for x in buf.srcs: - u = _recursive_uop(x.base, outputs, inputs, buf_uops, cache) - src.append(u if x is x.base else u.view(x.st)) + return UOp.load(ubuf, buf.st.to_uop(), dtype=dtype) + src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0])) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1])) - elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype) - elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype) - else: ret = UOp(UOps.ALU, dtype, tuple(src), buf.op) + elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src) + elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src) + else: ret = UOp(UOps.ALU, dtype, src, buf.op) cache[buf] = ret return ret @@ -180,7 +175,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val ast: List[UOp] = [] inputs: List[LazyBuffer] = [] for out in outs: - src = _recursive_uop(out, outs, inputs, buf_uops, cache) + src = to_uop(out, outs, inputs, buf_uops, cache) if out.op is MetaOps.ASSIGN and out.arg: assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}" output_st = out.arg[0] @@ -193,7 +188,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets): raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) - return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x in cache if x.metadata and x not in inputs]))) + return LBScheduleItem(sink, tuple(outs+inputs), + tuple(dedup([x.metadata for x in cache if x.metadata is not None and (x.base in outs or x.base.buffer not in buf_uops)]))) # *** DAG creation: decide which LazyBuffers should realize ***