Skip to content

Commit

Permalink
add size of the buffer to the ptr dtype (tinygrad#8322)
Browse files (browse the repository at this point in the history)
  • Loading branch information
geohot authored Dec 18, 2024
1 parent: 52243b2 · commit: 6608ba3
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion test/test_dtype.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_ptr_eq(self):
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")

class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/codegen/uopgraph.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,7 +49,7 @@ def fold_expanded(ex, buf):
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
# generate the folded new_srcs
if is_load:
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
Expand Down
17 changes: 10 additions & 7 deletions tinygrad/dtype.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, prod

ConstType = Union[float, int, bool]

Expand Down Expand Up @@ -38,30 +38,33 @@ def vec(self, sz:int) -> DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
def ptr(self, size=-1, local=False) -> PtrDType:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self

@dataclass(frozen=True, eq=False)
class PtrDType(DType):
_base: DType
local: bool
v: int
size: int = -1 # -1 is unlimited size
@property
def base(self): return self._base
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def vec(self, sz:int) -> DType:
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
def __repr__(self):
return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')

@dataclass(frozen=True, eq=False)
class ImageDType(PtrDType):
shape: Tuple[int, ...] = () # shape of the Image
def ptr(self, local=False) -> PtrDType:
def ptr(self, size=-1, local=False) -> PtrDType:
assert not local, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
Expand Down Expand Up @@ -131,9 +134,9 @@ def fields() -> Dict[str, DType]: return DTYPES_DICT

# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)

default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/engine/schedule.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -152,7 +152,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
assert x.arg[0] != -1, "fake -1 BUFFERS should not make it here"
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(), (), len(ctx.bufs)-1)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.arg[1]), (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])

def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
Expand Down

0 comments on commit 6608ba3

Please sign in to comment.