From 199b3bf02b4626fc1b2e06de073b3624e59b526a Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 23 Jul 2024 14:11:05 -0400 Subject: [PATCH] simple UOp lt/ge folding (#5657) works if lhs is a DEFINE_VAR. folds trivial x < -math.inf now, need to change SPECIAL to use DEFINE_VAR to fold more --- test/unit/test_uop_symbolic.py | 6 +----- tinygrad/codegen/uopgraph.py | 3 +++ tinygrad/codegen/uops.py | 11 ++++++++++- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8392dc0a54de7..2352b7dc32e4a 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -30,8 +30,7 @@ def Variable(expr, nmin, nmax): # TODO: fix DEFINE_VAR to not need this class TempVar: def __init__(self, x): self.expr = x - #return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr)) - return UOp(UOps.DEFINE_VAR, dtypes.int, tuple(), TempVar(expr)) + return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr)) class Node: @staticmethod def sum(ops): return functools.reduce(lambda x,y: x+y, ops) @@ -73,7 +72,6 @@ def test_ge(self): self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1") self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1") - @unittest.expectedFailure def test_lt(self): self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1") self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1") @@ -272,11 +270,9 @@ def test_big_mod(self): self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)") #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") - @unittest.expectedFailure def test_ge_remove(self): self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0") - @unittest.expectedFailure def test_lt_remove(self): self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0") self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 73efa2a45681b..559cfe9c9e172 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -236,6 +236,9 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce_allow_any_len): # NOTE: this can be wrong for loaded NaN (UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), (UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0 + # lt folding + (UOp.var('x').lt(UOp.cvar('c')), + lambda x,c: UOp.const(dtypes.bool, True) if x.vmax.arg < c.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= c.arg else None), # ** load/store folding ** (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), # ** two stage add/sub folding ** diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index a047d5709665f..4cc90ffc74c86 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict -import functools, itertools +import functools, itertools, math from collections import defaultdict from enum import Enum, auto from dataclasses import dataclass @@ -73,6 +73,7 @@ def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const @staticmethod @functools.lru_cache(maxsize=None) def _const(dtype:Optional[DType], b:ConstType|Variable): + # TODO: min/max for const Variable? if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b) return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) @staticmethod @@ -97,6 +98,14 @@ def divides(self, v): if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src) if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src) return False # generic false if we aren't sure + @functools.cached_property + def vmax(self) -> UOp: + if self.op is UOps.DEFINE_VAR: return self.src[1] + return UOp.const(dtypes.float, math.inf) + @functools.cached_property + def vmin(self) -> UOp: + if self.op is UOps.DEFINE_VAR: return self.src[0] + return UOp.const(dtypes.float, -math.inf) class UPat: def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,