Skip to content

Commit

Permalink
derive COPY nbytes late in realize [pr] (tinygrad#8137)
Browse files Browse the repository at this point in the history
* derive COPY arg later in realize [pr]

* can assume no implicit casts or movement ops here
  • Loading branch information
Qazalin authored Dec 10, 2024
1 parent 2d26b01 commit 56c84ce
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tinygrad/engine/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def const_arg(self) -> ConstType:

def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
- return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+ return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, srcs=(self,), enable_cache=False)

def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
# no COPY
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/engine/realize.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
- out, arg = si.outputs[0], si.ast.arg
+ out = si.outputs[0]
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
- return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
+ return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")
Expand Down

0 comments on commit 56c84ce

Please sign in to comment.