Skip to content

Commit

Permalink
fix acc folding for NV tensor cores (tinygrad#5658)
Browse files Browse the repository at this point in the history
* fix acc folding for NV tensor cores

* fix correctness of reduce_before_expand
  • Loading branch information
geohot authored Jul 23, 2024
1 parent 01fe00e commit 4d47968
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
26 changes: 26 additions & 0 deletions extra/gemm/tinygrad_nv_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from tinygrad import Tensor, dtypes, Device
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem

# Side length of the square matrices being multiplied.
N = 4096
if __name__ == "__main__":
  # Build an N x N float16 matmul that accumulates in float32.
  A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
  C = A.matmul(B, acc_dtype=dtypes.float32)

  # Only the last scheduled item is benchmarked.
  # NOTE(review): presumably this is the matmul kernel itself -- confirm against the scheduler.
  si = C.schedule()[-1]
  ast = si.ast
  k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)

  # Hand-picked optimization list, applied in this exact order.
  opts = [
    Opt(op=OptOps.TC, axis=0, amt=0),
    Opt(op=OptOps.UPCAST, axis=1, amt=16),
    Opt(op=OptOps.UPCAST, axis=0, amt=2),
    Opt(op=OptOps.LOCAL, axis=0, amt=4),
    Opt(op=OptOps.UNROLL, axis=0, amt=4),
    Opt(op=OptOps.LOCAL, axis=1, amt=2),
  ]
  for opt in opts: k.apply_opt(opt)

  prg = k.to_program()
  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)

  # Run the kernel 5 times and report the mean throughput.
  tflops = []
  for _ in range(5):
    # NOTE(review): tm is treated as elapsed seconds (the 1e-12 scale below relies on it) -- confirm.
    tm = ei.run(wait=True)
    # An N x N x N matmul performs 2*N^3 floating point operations (one mul + one add per term).
    tflops.append((2*N*N*N/tm)*1e-12)
  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")
3 changes: 2 additions & 1 deletion tinygrad/codegen/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
elif opt.op is OptOps.UPCAST: # yellow
check(axis < self.first_reduce, "upcast is for non-reduce")
check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
check(amt <= 8, "don't upcast more than 8")
check(amt <= 16, "don't upcast more than 16")
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op is OptOps.UPCASTMID: # white
Expand Down Expand Up @@ -729,6 +729,7 @@ def linearize(self) -> Kernel:
if DEBUG >= 3:
print(self.name)
print(modified_ast)
print(self.applied_opts)
verify_lazyop(modified_ast)

uop_sink = lazyop_to_uop(modified_ast, self.opts)
Expand Down
10 changes: 6 additions & 4 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def threefry2x32(x: UOp, seed: UOp):
# ***** main rewriter *****

def reduce_before_expand(reduce_allow_any_len, expand, x):
  """Rewrite a REDUCE over an EXPAND of GEPs of `x` into an EXPAND of GEPs over a REDUCE of `x`.

  Args:
    reduce_allow_any_len: the matched REDUCE uop; its src[1:] and arg are reused for the new REDUCE.
    expand: the matched EXPAND uop; its op, dtype and arg are reused for the result.
    x: the uop the GEPs read from; it becomes the first source of the new REDUCE.

  Returns:
    The rebuilt uop, or None when any entry of `expand.arg` also appears in the arg of an
    EXPAND found among the reduce's trailing sources.
  """
  reduce_tail = reduce_allow_any_len.src[1:]
  # if the expand is being reduced, you can't push it through
  # NOTE: could do a partial push here in some cases
  reduced_expand_args = flatten([src.arg for src in reduce_tail if src.op is UOps.EXPAND])
  for expand_arg in expand.arg:
    if expand_arg in reduced_expand_args: return None
  # reduce the whole of x first, keeping the original trailing sources and reduce arg
  pushed_reduce = UOp(UOps.REDUCE, x.dtype, (x,)+reduce_tail, reduce_allow_any_len.arg)
  # then pick out each of the x.dtype.count elements and re-wrap them like the original expand
  elements = tuple(UOp(UOps.GEP, reduce_allow_any_len.dtype, (pushed_reduce,), idx) for idx in range(x.dtype.count))
  return UOp(expand.op, expand.dtype, elements, expand.arg)
Expand Down Expand Up @@ -154,10 +158,8 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce_allow_any_len):
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
# tensor core cleanups
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand),
(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand),
*[(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(j))).name("expand"),))
.name("reduce_allow_any_len"), reduce_before_expand) for j in [2,4,8]],
(UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
Expand Down

0 comments on commit 4d47968

Please sign in to comment.