small changes to refine the delete_lazy diff (tinygrad#8134)
* _view -> view

* const_arg: read an unrealized const's value from const_arg instead of base.arg
Qazalin authored Dec 10, 2024
1 parent 6d33da0 commit 3a2658e
Showing 3 changed files with 16 additions and 16 deletions.
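
Both items in the commit message are mechanical renames at the call sites: the movement helper LazyBuffer._view becomes the public view, and the value of an unrealized const is read from const_arg instead of base.arg. A minimal sketch of the post-commit usage, assuming the tinygrad API at this commit (the concrete shape, strides, and dtype are illustrative only):

from tinygrad import Tensor, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

# view() (formerly _view) re-wraps a realized base with a new ShapeTracker, as in set_ below
ref = Tensor([[1.0, 2.0], [3.0, 4.0]]).realize()
strided = Tensor(ref.lazydata.view(ShapeTracker((View.create(shape=(2, 2), strides=(1, 2), offset=0),))))
assert strided.lazydata.st.real_strides() == (1, 2)

# const_arg (formerly base.arg) exposes the value of an unrealized const, as in test_const_dtype below
lb = Tensor([1], dtype=dtypes.int).lazydata
assert lb.const_like(1).const_arg == 1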
test/imported/test_indexing.py: 2 changes (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ def consec(shape, start=1):
def set_(reference: Tensor, shape, strides, offset):
  if reference.lazydata.base.realized is None: reference.realize()
  assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
-  strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
+  strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
  assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
  return strided

test/test_lazybuffer.py: 8 changes (4 additions, 4 deletions)
@@ -63,12 +63,12 @@ def test_shrink_const_then_cast(self):

  def test_const_dtype(self):
    lb: LazyBuffer = Tensor([1], dtype=dtypes.int).lazydata
-    assert lb.const_like(1).base.arg == 1
-    assert type(lb.const_like(1).base.arg) is int
+    assert lb.const_like(1).const_arg == 1
+    assert type(lb.const_like(1).const_arg) is int

    lb: LazyBuffer = Tensor([1], dtype=dtypes.float).lazydata
-    assert lb.const_like(1).base.arg == 1.0
-    assert type(lb.const_like(1).base.arg) is float
+    assert lb.const_like(1).const_arg == 1.0
+    assert type(lb.const_like(1).const_arg) is float

  def test_forced_realized_alu(self):
    a = Tensor.randn(2, 2).realize()
tinygrad/engine/lazy.py: 22 changes (11 additions, 11 deletions)
@@ -110,7 +110,7 @@ def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyB
      new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
      # TODO: applying this makes gpt2 slower
-      return self.base.cast(dtype, bitcast)._view(self.st)
+      return self.base.cast(dtype, bitcast).view(self.st)
    cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,))

@@ -135,21 +135,21 @@ def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> Lazy

    # const doesn't have to be copied (issues with disk tensor)
    if self.is_unrealized_const():
-      return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
+      return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg).view(self.st)

    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
    if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)

    # copy the base and apply the shapetracker on the new device
-    return self.base._copy(device)._view(self.st)
+    return self.base._copy(device).view(self.st)

  def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)

  def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
    srcs: List[LazyBuffer] = []
    for s in (self,)+in_srcs:
      if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
-        srcs.append(root._view(s.base.contiguous_child[1]))
+        srcs.append(root.view(s.base.contiguous_child[1]))
      else:
        srcs.append(s)
    if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
@@ -207,15 +207,15 @@ def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:

  # *** movement ops ***

-  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
+  def view(self, new_st:ShapeTracker) -> LazyBuffer:
    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
      return self.const_with_shape(0, new_st.shape)
    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
    return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

-  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
-  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
-  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
-  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
-  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
-  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
+  def reshape(self, arg:Tuple[sint, ...]): return self.view(self.st.reshape(arg))
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.pad(arg))
+  def expand(self, arg:Tuple[sint, ...]): return self.view(self.st.expand(arg))
+  def permute(self, arg:Tuple[int, ...]): return self.view(self.st.permute(arg))
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.shrink(arg))
+  def stride(self, arg:Tuple[int, ...]): return self.view(self.st.stride(arg))
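
After the rename, every movement op is a thin wrapper that builds a new ShapeTracker and hands it to the now-public view(), which either collapses back to the base (contiguous and same shape) or creates a new LazyBuffer sharing that base. A minimal sketch of that flow, assuming the API at this commit (the tensor and permutation are illustrative only):

from tinygrad import Tensor

a = Tensor.ones(2, 2).contiguous().realize()
lb = a.lazydata

# permute builds lb.st.permute((1, 0)) and passes it to view()
p = lb.permute((1, 0))
assert p.base is lb.base  # the permuted view shares the realized base buffer
assert p.shape == (2, 2) and p.st.real_strides() == (1, 2)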
