From d726eb6f48aaae7f9c3b9addf7fee1d17950249b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:11:42 +0800 Subject: [PATCH] uop resolve [run_process_replay] (#6826) * uop bool and int and stuff [run_process_replay] * add ne support * can't even be None anymore * BinaryOps.AND support * less compare --- test/external/fuzz_linearizer.py | 2 +- test/test_multitensor.py | 2 +- test/unit/test_uop_resolve.py | 84 ++++++++++++++++++++++++++++++++ tinygrad/codegen/kernel.py | 5 +- tinygrad/codegen/uopgraph.py | 4 +- tinygrad/engine/schedule.py | 4 +- tinygrad/lazy.py | 2 +- tinygrad/ops.py | 31 ++++++++++-- tinygrad/renderer/assembly.py | 2 +- viz/serve.py | 2 +- 10 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 test/unit/test_uop_resolve.py diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 27f2645a34469..06ce78beafb45 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -185,7 +185,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2): def _is_simple(lin: Kernel) -> bool: if len(lin.ast.src) > 1: return False ast:UOp = lin.ast.src[0] - if ast.src[0] and ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is UOps.LOAD: return True + if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True return False if __name__ == "__main__": diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 98b6db14addfd..113826868dfa3 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -618,7 +618,7 @@ def test_broadcast_const(self): ast = si.ast.src[0] assert ast.op is UOps.STORE assert ast.src[2].arg is BinaryOps.ADD - assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0] + assert ast.src[2].src[0].op is UOps.LOAD assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1 t = 2 * t for si in t.schedule(): diff 
--git a/test/unit/test_uop_resolve.py b/test/unit/test_uop_resolve.py new file mode 100644 index 0000000000000..dd4633abcd8f2 --- /dev/null +++ b/test/unit/test_uop_resolve.py @@ -0,0 +1,84 @@ +import unittest +from tinygrad.dtype import dtypes +from tinygrad.ops import UOp + +class TestUOpResolve(unittest.TestCase): + def test_simple_int(self): + u = UOp.const(dtypes.int, 4) + self.assertEqual(int(u), 4) + + def test_int_add(self): + u = UOp.const(dtypes.int, 4) + 7 + self.assertEqual(int(u), 11) + + def test_lt(self): + u = UOp.const(dtypes.int, 4) < 7 + self.assertTrue(u) + + def test_leq(self): + u = UOp.const(dtypes.int, 4) <= 4 + self.assertTrue(u) + + def test_ne(self): + u = UOp.const(dtypes.int, 4).ne(7) + self.assertTrue(u) + + def test_ne_f(self): + u = UOp.const(dtypes.int, 4).ne(4) + self.assertFalse(u) + + def test_ngt(self): + u = UOp.const(dtypes.int, 4) > 7 + self.assertFalse(u) + + def test_float_direct(self): + u = UOp.const(dtypes.float, 4.5) + 7 + self.assertEqual(float(u), 11.5) + + def test_var_cmp_t(self): + u = UOp.define_var("i", dtypes.pyint, 1, 10) < 20 + self.assertTrue(u) + + def test_var_cmp_t2(self): + u = UOp.define_var("i", dtypes.pyint, 1, 10)//2 < 20 + self.assertTrue(u) + + def test_var_cmp_f(self): + u = UOp.define_var("i", dtypes.pyint, 1, 10) < 1 + self.assertFalse(u) + + def test_var_cmp_f2(self): + u = UOp.define_var("i", dtypes.pyint, 1, 10) > 11 + self.assertFalse(u) + + def test_or_true(self): + u = UOp.define_var("b", dtypes.bool, False, True) | True + self.assertTrue(u) + + def test_or_false(self): + with self.assertRaises(ValueError): + u = UOp.define_var("b", dtypes.bool, False, True) | False + self.assertTrue(u) + + def test_and_false(self): + u = UOp.define_var("b", dtypes.bool, False, True) & False + self.assertFalse(u) + + def test_and_true(self): + with self.assertRaises(ValueError): + u = UOp.define_var("b", dtypes.bool, False, True) & True + self.assertFalse(u) + + @unittest.skip("too fancy to be supported 
right now") + def test_var_cmp_range(self): + v = UOp.define_var("i", dtypes.pyint, 1, 10) + u = v > 4 or v < 6 + self.assertTrue(u) + + def test_var_cmp_assert(self): + with self.assertRaises(ValueError): + u = UOp.define_var("i", dtypes.pyint, 1, 10) < 5 + self.assertFalse(u) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index a5b875dbc4ed5..aee854e85a44d 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -385,7 +385,8 @@ def apply_opt(self, opt:Opt, append_opt:bool=True): if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift") else: amt = -1 - if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): + if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \ + (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): acc_sz = self.reduceop.dtype.itemsize upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b]) local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces]) @@ -598,7 +599,7 @@ def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in @functools.cached_property def name(self) -> str: # kernel name (before late upcast) - name = ("r" if self.reduceop else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \ + name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \ (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 48d64446aa167..213696b8c97b6 100644 --- 
a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -39,7 +39,7 @@ def fold_expanded(ex, buf): if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)): load_1 = new_srcs[offsets[o]] new_src = list(load_1.src) - if not new_src[1].divides(fold_length): continue + if new_src[1].divides(fold_length) is None: continue # for images, we rewrite the index. it must evenly divide 4 from the above check if is_image: new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((new_src[1] // 4) % buf.dtype.shape[1], (new_src[1] // (4 * buf.dtype.shape[1])))) @@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): if not drop_stmt and idx.key == start_idx.key: return None new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None - return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx))) + return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx))) # ***** optional patterns ***** diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1005ea87392e1..d29b0acbc7765 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -50,7 +50,7 @@ class ScheduleItemContext: # ** helpers for doing movementops on uops def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: - if (n:=cache.get(u)): return n + if (n:=cache.get(u)) is not None: return n if u.op is UOps.SHAPETRACKER: new_st = apply_to_st(u.arg) return u if u.arg == new_st else UOp(UOps.SHAPETRACKER, dtypes.void, (), new_st) @@ -140,7 +140,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. 
# buffer ops define ShapeTracker # if it's realized, it's a load and we add it to the inputs - if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs: + if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs: unbound_st, st_var_vals = st.simplify().unbind() var_vals.update(st_var_vals) if buf.op is MetaOps.CONST: diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 47bc6803dc9ce..9962b76885147 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -16,7 +16,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[ if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) - if enable_cache and (rret := lazycache.get(cache_key, None)): return rret + if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get()) if enable_cache: lazycache[cache_key] = ret diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 01f2e19a4eb01..b5bfee512f734 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -65,9 +65,13 @@ def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x)) def eq(self, x): return self.ne(x).ne(True) def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x)) def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self) - # TODO: use this one instead def ge(self, x): return self.lt(x).ne(True) - #def ge(self, x): return (-self).lt(-x+1) + def le(self, x): return self.gt(x).ne(True) + # NOTE: __eq__/__ne__ can't be overridden, and means the same thing as is and is not + def __lt__(self, x): return self.lt(x) + def __gt__(self, x): return self.gt(x) + def __ge__(self, x): return self.ge(x) + def __le__(self, x): return self.le(x) def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x)) def min(self, x): return 
-(-self).max(-x) def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y) @@ -166,6 +170,16 @@ def key(self) -> bytes: return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest() def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg + # *** uop evaluation *** + def _eval(self, dtype, expected_type) -> ConstType: + assert self.dtype in dtype, f"eval with wrong dtype {self}" + vmin, vmax = self._min_max + if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax}") + assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}" + return vmin + def __bool__(self): return self._eval((dtypes.bool,), bool) + def __int__(self): return self._eval(dtypes.ints, int) + def __float__(self): return self._eval(dtypes.floats, float) # *** uop syntactic sugar @property def st_arg(self) -> ShapeTracker: @@ -284,8 +298,15 @@ def _min_max(self) -> Tuple[ConstType, ConstType]: if s1.arg < 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg) if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax) if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax) + if self.arg is BinaryOps.CMPNE: + always_ne = (s0.vmax < s1.vmin) or (s1.vmin > s0.vmax) + sometimes_ne = not (s0.vmin == s0.vmax == s1.vmin == s1.vmax) + return (always_ne, sometimes_ne) # float has NAN issue and we use explicit NAN in transcendental if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax) + if self.dtype is dtypes.bool: + if self.arg is BinaryOps.OR: return s0.vmin or s1.vmin, s0.vmax or s1.vmax + if self.arg is BinaryOps.AND: return s0.vmin and s1.vmin, s0.vmax and s1.vmax return dtypes.min(self.dtype), dtypes.max(self.dtype) @dataclass(frozen=True) @@ -563,12 +584,12 @@ def __init__(self, pm, ctx): self.nodes: Dict[Tuple, UOp] 
= {} self.replace: Dict[UOp, UOp] = {} def rewrite(self, n:UOp) -> UOp: - if rn := self.replace.get(n): return rn + if (rn := self.replace.get(n)) is not None: return rn replace_source = (n.op, n.dtype, new_src:=tuple(map(self.rewrite, n.src)), n.arg) - if found := self.nodes.get(replace_source): self.replace[n] = found + if (found := self.nodes.get(replace_source)) is not None: self.replace[n] = found else: x = UOp(*replace_source) if new_src != n.src else n - self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x + self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x return found def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: if TRACK_MATCH_STATS >= 2: diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index cb3f174d1cd32..10e3be047a539 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -144,7 +144,7 @@ def const(x:ConstType, dtype:DType, mov=False): def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False): if atype == dtype or isinstance(atype, PtrDType): - if u: r[u] = a + if u is not None: r[u] = a return a kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast)) return ret diff --git a/viz/serve.py b/viz/serve.py index 1ecb8a62cfe26..fb9269a826052 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -74,7 +74,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: return graph def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: - if (found:=replaces.get(base.key)): return found + if (found:=replaces.get(base.key)) is not None: return found new_srcs = tuple(replace_uop(x, replaces) for x in base.src) replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base return ret