From 176964abf5f0f7fed78f42a7f3532229746d49dd Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 4 Dec 2024 20:45:56 -0500
Subject: [PATCH] always rewrite MAX to CMPLT + WHERE [pr]

we are already doing this in the renderers, and all the max folding rules
have equivalent where folding rules
---
 test/unit/test_uop_symbolic.py | 1 -
 tinygrad/ops.py                | 3 ++-
 tinygrad/renderer/cstyle.py    | 6 ++----
 tinygrad/renderer/llvmir.py    | 2 --
 tinygrad/renderer/ptx.py       | 1 -
 5 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 0f182a56653f..839658dc344d 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -496,7 +496,6 @@ def test_where_combine(self):
     self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
     self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
     self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
-    self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")

     # not combining because it increased total ALU
     c = Variable("c", 0, 3)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 810ed858f0c4..18912f4ac24c 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1257,10 +1257,11 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
   # ALU min==max -> CONST (slow!)
   (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding
-  (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
   # TODO: why does this rule break beautiful_mnist?
   #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
   ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
+  # after possible foldings, rewrite MAX to CMPLT + WHERE
+  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
   # ** two stage ALU folding **
   *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
      lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 4d7aaa8c23e4..5431753586da 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -56,8 +56,6 @@
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
    lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
-  # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
-  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
 ])

 def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -367,9 +365,9 @@ class AMDRenderer(CStyleLanguage):

   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
-  ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
+  ocml = [(f"__ocml_{name}_f{n}", dt, dt, atr)
           for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
-          for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
+          for name, atr in [("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]

   kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
   kernel_prefix += '\nextern "C" __attribute__((global))'
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index c90e76da8acb..0155af36ff2b 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -85,8 +85,6 @@ class LLVMRenderer(Renderer):
     (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
     # rewrite cast to bool to CMPNE 0
     (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
-    # rewrite MAX to CMPLT + WHERE
-    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
   ])

   def render(self, name: str, uops: List[UOp]) -> str:
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index 5be5c1c1f3f7..d1289709a11e 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -25,7 +25,6 @@ def render_val(x, dtype):
   Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
   Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
   Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
-  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
   Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
   Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
  Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
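
A quick check of the commit's claim that the deleted max folding rule is
subsumed: once MAX is rewritten to (x < y).where(y, x), the generic
"ALU min==max -> CONST" rule folds the CMPLT whenever the operand ranges don't
overlap, and the constant-gate WHERE folding then collapses the select to the
surviving operand. A minimal sketch, not part of the diff, assuming
UOp.variable, simplify, and render behave as in the tree this patch targets:

  from tinygrad.ops import UOp

  # x in [4, 10] and y in [0, 3]: ranges don't overlap, so max(x, y) is always x
  x = UOp.variable("x", 4, 10)
  y = UOp.variable("y", 0, 3)

  # the deleted rule folded this directly via vmin/vmax; after this patch MAX
  # becomes (x < y).where(y, x), the CMPLT has vmin == vmax == False so it
  # folds to a constant, and the WHERE then collapses to its false branch
  print(x.maximum(y).simplify().render())         # expected: x
  print((x < y).where(y, x).simplify().render())  # expected: x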