diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 71653e49502cf..1be6b68669743 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -61,7 +61,8 @@ def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an
     TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
   ],
   "CUDA": [
-    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
+    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+      if getenv("PTX") else "__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
   ],
 }
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 484098089e9d3..84f145428c28d 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -385,6 +385,6 @@ def flops_mem(self) -> Tuple[sint, sint]:
       elif u.uop is UOps.WMMA:
         if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
         elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-        elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
+        elif "m16n8k16" in u.arg: flops += 2*(8*16*16)//32 * mults
         else: raise Exception("not implemented")
     return flops, mem
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index d7965047d458d..05aefad239b10 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -225,6 +225,14 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
       if lang.load_global:
         dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
         kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
+    elif uop is UOps.WMMA:
+      wmma = []
+      for v in vin[:2]:
+        for i in range(0, len(r[v]), 2): # type: ignore
+          wmma.append(ssa(None, "wmma", "b32"))
+          kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[v][i:i+2])}}};') # type: ignore
+      r[u] = r[vin[2]]
+      kk(f'{args} {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
     else: raise NotImplementedError(f"no code for {uop}")
   return lang.render_kernel(kernel, function_name, bufs, c.items())
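For context, a rough sketch of the PTX the new UOps.WMMA branch emits; this is illustrative only (register names are made up, not taken from a real kernel dump). The path is taken only when the PTX env var is set, since the wmma_func string doubles as the rendered instruction. With the CUDA tensor core's thread_local_sizes=[[2,2,2],[2,2],[2,2]], r[vin[0]] holds 8 half registers and r[vin[1]] holds 4, so the mov.b32 packing yields 4 + 2 b32 operands, and the accumulator registers from r[vin[2]] are reused as both the C and D operands:

    // pack pairs of f16 registers into b32 operands (A fragment: 4 regs, B fragment: 2 regs)
    mov.b32 %wmma0, {%a0, %a1};
    mov.b32 %wmma1, {%a2, %a3};
    mov.b32 %wmma2, {%a4, %a5};
    mov.b32 %wmma3, {%a6, %a7};
    mov.b32 %wmma4, {%b0, %b1};
    mov.b32 %wmma5, {%b2, %b3};
    // D = A*B + C, with D written back over the C accumulator registers
    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%acc0, %acc1, %acc2, %acc3}, {%wmma0, %wmma1, %wmma2, %wmma3}, {%wmma4, %wmma5}, {%acc0, %acc1, %acc2, %acc3};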