From 79203ef8171b405c39b058d00edb65b32b129dab Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 10 Sep 2024 11:14:09 +0800 Subject: [PATCH] test upper bound for special --- tinygrad/ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 8398ff0af9999..f2341444d28c8 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -430,12 +430,12 @@ def vmax(self) -> ConstType: return self._min_max[1] @functools.cached_property def _min_max(self) -> Tuple[ConstType, ConstType]: # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.arg: - return self.arg[1].arg, self.arg[2].arg if self.arg[2].op is UOps.CONST else dtypes.max(cast(DType, self.dtype)) + if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1].vmin, self.arg[2].vmax if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax if self.op is UOps.EXPAND: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) # TODO: UOps.SPECIAL is UOps.DEFINE_VAR - if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(cast(DType, self.dtype)) + # NOTE: arg[1] is int or Variable + if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].max-1 if self.op is UOps.CONST: return self.arg, self.arg if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]