diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index c8f40c8ea3414..10707e40067c8 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -19,6 +19,7 @@ def vec(self, sz:int): assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) + def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self) def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self # dependent typing? @@ -29,6 +30,7 @@ class ImageDType(DType): local: bool = False # images are never local def scalar(self) -> DType: return self.base def vec(self, sz:int): return self.base.vec(sz) + def ptr(self) -> Union[PtrDType, ImageDType]: return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" # @dataclass(frozen=True, init=False, repr=False, eq=False) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 948af071e8c08..68e93d8fb1c55 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -122,7 +122,7 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp.define_global(x.dtype, ctx.bufs.index(x.arg)))]) +enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype.ptr(), (), ctx.bufs.index(x.arg)))]) def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp: if not AST_REWRITE: return sink diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d02690c28405f..734616d187945 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -246,8 +246,6 @@ def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): @functools.lru_cache(None) def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @staticmethod - def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg) - @staticmethod def range(dtype:DType, start:ConstType, end:ConstType, idx:int): return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,)) def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)