Skip to content

Commit

Permalink
merge where alu when one branch is identity element const
Browse files Browse the repository at this point in the history
this is another case that does not increase alu count
  • Loading branch information
chenyuxyz committed Dec 16, 2024
1 parent 9789a83 commit 56c1ef2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,9 @@ def test_where_combine(self):
ba = cond.where(b, a)
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")

# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
# combining if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a*2) if (x<2) else b)")
self.helper_test_variable(bb*ab, 0, 9, "((a*b) if (x<2) else b)")

def test_symbolic_div(self):
# from symbolic arange
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
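# alu of two where with same conds can combine, only do if both true branches or both false branches are const,
# or if the alu is MAX/ADD/MUL and any branch is its identity element const (neither case increases alu count)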
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST or \
(alu.op in (Ops.MAX, Ops.ADD, Ops.MUL) and any(b.op is Ops.CONST and b.arg==identity_element(alu.op, b.dtype) for b in (t,tt,f,ff))) else None),
# ALU min==max -> CONST (slow!)
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
Expand Down

0 comments on commit 56c1ef2

Please sign in to comment.