Skip to content

Commit

Permalink
simple UOp lt/ge folding (tinygrad#5657)
Browse files Browse the repository at this point in the history
works if lhs is a DEFINE_VAR.
folds trivial x < -math.inf now, need to change SPECIAL to use DEFINE_VAR to fold more
  • Loading branch information
chenyuxyz authored Jul 23, 2024
1 parent b0fc5a4 commit 199b3bf
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
6 changes: 1 addition & 5 deletions test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this
class TempVar:
def __init__(self, x): self.expr = x
#return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, tuple(), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
Expand Down Expand Up @@ -73,7 +72,6 @@ def test_ge(self):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")

@unittest.expectedFailure
def test_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
Expand Down Expand Up @@ -272,11 +270,9 @@ def test_big_mod(self):
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")

@unittest.expectedFailure
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")

@unittest.expectedFailure
def test_lt_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
Expand Down
3 changes: 3 additions & 0 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce_allow_any_len):
# NOTE: this can be wrong for loaded NaN
(UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0
# lt folding
(UOp.var('x').lt(UOp.cvar('c')),
lambda x,c: UOp.const(dtypes.bool, True) if x.vmax.arg < c.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= c.arg else None),
# ** load/store folding **
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
# ** two stage add/sub folding **
Expand Down
11 changes: 10 additions & 1 deletion tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
import functools, itertools
import functools, itertools, math
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
Expand Down Expand Up @@ -73,6 +73,7 @@ def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const
@staticmethod
@functools.lru_cache(maxsize=None)
def _const(dtype:Optional[DType], b:ConstType|Variable):
# TODO: min/max for const Variable?
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
@staticmethod
Expand All @@ -97,6 +98,14 @@ def divides(self, v):
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
return False # generic false if we aren't sure
@functools.cached_property
def vmax(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[1]
return UOp.const(dtypes.float, math.inf)
@functools.cached_property
def vmin(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[0]
return UOp.const(dtypes.float, -math.inf)

class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
Expand Down

0 comments on commit 199b3bf

Please sign in to comment.