diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 44adb7538b01..a74f6f4a40d9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -324,11 +324,9 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]: # ** ops in the big graph can either be pre-realized or scheduled (fused/realized) -class UPatRealized(UPat): - def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),)) class UPatScheduled(UPat): def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), - UPat(*args, **{**kwargs,"name":"to_store"}))) + UPat(*args, **{"name":"to_store",**kwargs}))) # ** this is schedule level const folding @@ -364,20 +362,19 @@ def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]: # ** this decides which ops get realized -def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: +def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, **kwargs) -> None: if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)]) -def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None: - if to_store.op in {Ops.CONST, Ops.BIND}: return None - base_shape = unwrap(base.st).shape +def realize_view(ctx:Dict[UOp, UOp], view:UOp, src:UOp, b:UOp, **kwargs) -> None: + if src.st is None: return None st = unwrap(view.st) # fold simple pads - if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])): - return None if can_pad(base, ctx, set()) else realize(ctx, b, to_store, base) + if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])): + return None if can_pad(src, ctx, set()) else realize(ctx, b, src) # early realize before expand - if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base) + if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src) # otherwise safety check pads - return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base) + return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx, set())) else realize(ctx, b, src) def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]: if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None @@ -394,7 +391,7 @@ def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]: # always realize meta ops (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), # realize before expand or unsafe pad ops - (UPatScheduled().view(name="view"), realize_view), + (UPatScheduled(name="src").view(name="view"), realize_view), # don't realize image to image casts (UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast), # realize before COPY or BUFFER_VIEW @@ -410,8 +407,8 @@ def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape) def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp: - ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store)) - return UOp(Ops.LOAD, base.dtype, (b, st.to_uop())) + ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(base.shape).to_uop(), append_op(ctx, b, to_store)) + return UOp(Ops.LOAD, base.dtype, (b, unwrap(base.st).to_uop())) def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp: if (m:=ctx.lazybufs[b].metadata) is not None: ctx.ops_metadata[to_store] = m @@ -420,10 +417,10 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp: break_sched = PatternMatcher([ # consts are always fused and generated (UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid), - # everything else is a VIEW of BUFFER that either realizes or fuses + # view of realized buffer just loads + (UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, v.st.to_uop()))), + # all other views either fold or realize with a store (UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)), - # just load realized buffers - (UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))), ]) @track_rewrites(named=True)