Skip to content

Commit

Permalink
always rewrite MAX to CMPLT + WHERE [pr]
Browse files Browse the repository at this point in the history
we are already doing this, and all max folding rules is equivalent in where rules
  • Loading branch information
chenyuxyz committed Dec 16, 2024
1 parent befa864 commit 176964a
Showing 5 changed files with 4 additions and 9 deletions.
1 change: 0 additions & 1 deletion test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
@@ -496,7 +496,6 @@ def test_where_combine(self):
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")

# not combining because it increased total ALU
c = Variable("c", 0, 3)
3 changes: 2 additions & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
@@ -1257,10 +1257,11 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
# ALU min==max -> CONST (slow!)
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
# after possible foldings, rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# ** two stage ALU folding **
*((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
6 changes: 2 additions & 4 deletions tinygrad/renderer/cstyle.py
Original file line number Diff line number Diff line change
@@ -56,8 +56,6 @@
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])

def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -367,9 +365,9 @@ class AMDRenderer(CStyleLanguage):

# language options
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
ocml = [(f"__ocml_{name}_f{n}", dt, dt, atr)
for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
for name, atr in [("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]

kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
kernel_prefix += '\nextern "C" __attribute__((global))'
2 changes: 0 additions & 2 deletions tinygrad/renderer/llvmir.py
Original file line number Diff line number Diff line change
@@ -85,8 +85,6 @@ class LLVMRenderer(Renderer):
(UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
# rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])

def render(self, name: str, uops: List[UOp]) -> str:
1 change: 0 additions & 1 deletion tinygrad/renderer/ptx.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,6 @@ def render_val(x, dtype):
Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \

0 comments on commit 176964a

Please sign in to comment.