Skip to content

Commit

Permalink
add dtype.ptr() [pr] (tinygrad#6839)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin authored Oct 2, 2024
1 parent be12409 commit 29363fb
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tinygrad/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def vec(self, sz:int):
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self

# dependent typing?
Expand All @@ -29,6 +30,7 @@ class ImageDType(DType):
local: bool = False # images are never local
def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz)
def ptr(self) -> Union[PtrDType, ImageDType]: return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})"

# @dataclass(frozen=True, init=False, repr=False, eq=False)
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])

enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp.define_global(x.dtype, ctx.bufs.index(x.arg)))])
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype.ptr(), (), ctx.bufs.index(x.arg)))])

def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
if not AST_REWRITE: return sink
Expand Down
2 changes: 0 additions & 2 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
@functools.lru_cache(None)
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@staticmethod
def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
Expand Down

0 comments on commit 29363fb

Please sign in to comment.