Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Dec 3, 2024
1 parent 4a1c182 commit 6f0df91
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tinygrad/renderer/wgsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
# TODO: why is this needed, and only for this MUL order
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"), UPat.cvar("c1"), UPat.cvar("c2"))))),
lambda m,a,g,c1,c2: UOp(Ops.WHERE, dtype=m.dtype, src=(g, UOp.const(dtype=dtypes.float, b=float('nan')), a))
if math.isnan(c1.arg) and c2.arg == 1.0 else None),
]) + extra_pm
Expand Down

0 comments on commit 6f0df91

Please sign in to comment.