Skip to content

Commit

Permalink
clean up wgsl_matcher [pr] (tinygrad#8015)
Browse files Browse the repository at this point in the history
use more UPat syntatic sugar and remove unneeded rules
  • Loading branch information
chenyuxyz authored Dec 3, 2024
1 parent db330a3 commit a5af4e5
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions tinygrad/renderer/wgsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,15 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)

wgsl_matcher = PatternMatcher([
(UPat(Ops.CMPLT, src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
lambda a,b,c: UOp(c.op, c.dtype, (a.cast(dtypes.int), b.cast(dtypes.int)))),
(UPat(Ops.XOR, dtype=dtypes.bool, src=(UPat(name="a"), UPat(name="b")), name="c"),
lambda a,b,c: UOp(c.op, dtypes.int, (a.cast(dtypes.int), b.cast(dtypes.int))).cast(dtypes.bool)),
*[(UPat(a, src=(UPat(name="b", dtype=(dtypes.uint, dtypes.int, dtypes.bool))), name="a"),
lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)],
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
lambda m,a,g,c1,c2: UOp(Ops.WHERE, dtype=m.dtype, src=(g, UOp.const(dtype=dtypes.float, b=float('nan')), a))
if math.isnan(c1.arg) and c2.arg == 1.0 else None),
# TODO: why is this needed, and only for this MUL order
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
]) + extra_pm

type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
Expand Down

0 comments on commit a5af4e5

Please sign in to comment.