Skip to content

Commit

Permalink
pre merge_views+ops_folding refactor [pr] (tinygrad#8140)
Browse files Browse the repository at this point in the history
* simple start

* valid early

* more dumb things removed

* don't ever use base

* cleaner
  • Loading branch information
Qazalin authored Dec 10, 2024
1 parent 56c84ce commit 0840527
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,9 @@ def group_realizes(ctx:ScheduleContext) -> List[List[UOp]]:

# ** ops in the big graph can either be pre-realized or scheduled (fused/realized)

class UPatRealized(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),))
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
UPat(*args, **{**kwargs,"name":"to_store"})))
UPat(*args, **{"name":"to_store",**kwargs})))

# ** this is schedule level const folding

Expand Down Expand Up @@ -364,20 +362,19 @@ def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:

# ** this decides which ops get realized

def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, **kwargs) -> None:
if to_store.op not in {Ops.CONST, Ops.BIND}: ctx.update([(b, to_store)])

def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
if to_store.op in {Ops.CONST, Ops.BIND}: return None
base_shape = unwrap(base.st).shape
def realize_view(ctx:Dict[UOp, UOp], view:UOp, src:UOp, b:UOp, **kwargs) -> None:
if src.st is None: return None
st = unwrap(view.st)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])):
return None if can_pad(base, ctx, set()) else realize(ctx, b, to_store, base)
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx, set()) else realize(ctx, b, src)
# early realize before expand
if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base)
if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base)
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx, set())) else realize(ctx, b, src)

def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
Expand All @@ -394,7 +391,7 @@ def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
# always realize meta ops
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
# realize before expand or unsafe pad ops
(UPatScheduled().view(name="view"), realize_view),
(UPatScheduled(name="src").view(name="view"), realize_view),
# don't realize image to image casts
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before COPY or BUFFER_VIEW
Expand All @@ -410,8 +407,8 @@ def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
return UOp.const_with_shape(base.dtype, val, unwrap(base.st).shape)

def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), append_op(ctx, b, to_store))
return UOp(Ops.LOAD, base.dtype, (b, st.to_uop()))
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(base.shape).to_uop(), append_op(ctx, b, to_store))
return UOp(Ops.LOAD, base.dtype, (b, unwrap(base.st).to_uop()))

def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
if (m:=ctx.lazybufs[b].metadata) is not None: ctx.ops_metadata[to_store] = m
Expand All @@ -420,10 +417,10 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
break_sched = PatternMatcher([
# consts are always fused and generated
(UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid),
# everything else is a VIEW of BUFFER that either realizes or fuses
# view of realized buffer just loads
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, v.st.to_uop()))),
# all other views either fold or realize with a store
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
])

@track_rewrites(named=True)
Expand Down

0 comments on commit 0840527

Please sign in to comment.