From 56c84cee2934d8b1cd5b07c02bd555121c2d3371 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:04:07 +0200 Subject: [PATCH] derive COPY nbytes late in realize [pr] (#8137) * derive COPY arg later in realize [pr] * can assume no implicit casts or movement ops here --- tinygrad/engine/lazy.py | 2 +- tinygrad/engine/realize.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 868915eed185..3a05622e7006 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -123,7 +123,7 @@ def const_arg(self) -> ConstType: def _copy(self, device:str) -> LazyBuffer: assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}" - return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False) + return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, srcs=(self,), enable_cache=False) def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer: # no COPY diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 2e4ade9bbb73..9434db46b726 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -186,12 +186,12 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem: if si.ast.op is Ops.SINK: runner = get_runner(si.outputs[0].device, si.ast) return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata) - out, arg = si.outputs[0], si.ast.arg + out = si.outputs[0] if si.ast.op is Ops.COPY: kernel_type = BufferCopy if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: kernel_type = BufferXfer - return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs)) + return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs)) if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs)) if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs)) raise RuntimeError(f"don't know how to lower {si.ast}")