Skip to content

Commit

Permalink
allow VIEW on BUFFER [pr] (tinygrad#8136)
Browse files Browse the repository at this point in the history
* allow VIEW of BUFFER [pr]

* base it later

* better diff

* base shouldn't exist after anywhere merge_views
  • Loading branch information
Qazalin authored Dec 10, 2024
1 parent 3a2658e commit 2d26b01
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
16 changes: 7 additions & 9 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,23 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
buf.buffer.dtype = dtype
buf.buffer.options = None
# base is a VIEW of (BUFFER, (optional) op)
if buf.is_realized:
# TODO: this is the same underlying Buffer in all schedules
buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
op = None
# TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
# ASSIGN uses the target buffer, otherwise we create a new buffer
else:
src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop,) if op is None else (buf_uop, op.contiguous() if buf.forced_realize else op), buf.st)
# keep track of ops outside the big graph
buffers[buf_uop] = buf.buffer
if op is not None:
ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
# keep track of scheduled ops
buf.buffer.ref(1)
ctx.lazybufs[buf_uop] = buf
ctx.allbufs[buf_uop] = ret
if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
cache[buf] = ret
buffers[ret.buf_uop] = buf.buffer
return ret

# **** AST graph rewrite
Expand Down Expand Up @@ -425,7 +423,7 @@ def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
])

@track_rewrites(named=True)
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,11 @@ def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and al
@property
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
# instant folding rules
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
Expand Down

0 comments on commit 2d26b01

Please sign in to comment.