Commit: tensor cores
SzymonOzog committed Mar 22, 2024
1 parent 3202cbc commit d4dd28f
Showing 3 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tinygrad/codegen/kernel.py
@@ -61,7 +61,8 @@ def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when an
     TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
   ],
   "CUDA": [
-    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
+    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+      if getenv("PTX") else "__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
   ],
 }
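For context (not part of the diff): the new kernel.py entry keeps a single CUDA TensorCore but picks its wmma_func string at definition time, using the raw PTX mma.sync opcode when the PTX environment variable is set and the C-style __cuda_mma_* intrinsic otherwise. A minimal standalone sketch of that selection follows; it uses os.getenv for self-containment, whereas the diff uses tinygrad's own getenv helper, and cuda_wmma_func is an illustrative name, not a tinygrad function.

# Sketch only, not the tinygrad source: how PTX=1 flips the WMMA op used for CUDA.
import os

def cuda_wmma_func() -> str:
  # The PTX backend emits the mma.sync instruction directly; the C-style backend
  # calls a __cuda_mma_* device function instead.
  return ("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
          if int(os.getenv("PTX", "0")) else "__cuda_mma_m16n8k16_f16_f32")

print(cuda_wmma_func())  # run with PTX=1 to get the mma.sync opcode string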

2 changes: 1 addition & 1 deletion tinygrad/codegen/uops.py
@@ -385,6 +385,6 @@ def flops_mem(self) -> Tuple[sint, sint]:
       elif u.uop is UOps.WMMA:
         if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
         elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-        elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
+        elif "m16n8k16" in u.arg: flops += 2*(8*16*16)//32 * mults
         else: raise Exception("not implemented")
     return flops, mem
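A quick sanity check on the count that line reuses (my arithmetic, not from the diff): one m16n8k16 MMA performs 2*16*8*16 = 4096 floating-point operations per warp-level instruction, and with 32 threads per warp that attributes 128 FLOPs to each thread's WMMA uop, scaled by mults.

# Sketch of the arithmetic behind 2*(8*16*16)//32 in the branch above.
M, N, K, WARP_SIZE = 16, 8, 16, 32
flops_per_mma = 2 * M * N * K                   # one multiply and one add per MAC
flops_per_thread = flops_per_mma // WARP_SIZE   # split across the warp
assert flops_per_thread == 2 * (8 * 16 * 16) // 32 == 128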
8 changes: 8 additions & 0 deletions tinygrad/renderer/assembly.py
@@ -225,6 +225,14 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
       if lang.load_global:
         dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
         kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
+    elif uop is UOps.WMMA:
+      wmma = []
+      for v in vin[:2]:
+        for i in range(0, len(r[v]), 2): # type: ignore
+          wmma.append(ssa(None, "wmma", "b32"))
+          kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[v][i:i+2])}}};') # type: ignore
+      r[u] = r[vin[2]]
+      kk(f'{args} {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
     else: raise NotImplementedError(f"no code for {uop}")

   return lang.render_kernel(kernel, function_name, bufs, c.items())
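To see what the new WMMA branch produces, here is a self-contained sketch of its output for one m16n8k16 mma, assuming the usual per-thread fragment sizes (8 f16 values for A, 4 for B, 4 f32 accumulators, matching thread_local_sizes=[[2,2,2],[2,2],[2,2]]). The register names are invented; in the renderer they come from ssa() and r[...].

# Illustrative only: mimics the packing + mma.sync emission with made-up register names.
a_halves = [f"%h{i}" for i in range(8)]      # stands in for r[vin[0]]: A fragment, 8 x f16
b_halves = [f"%h{i}" for i in range(8, 12)]  # stands in for r[vin[1]]: B fragment, 4 x f16
acc      = [f"%f{i}" for i in range(4)]      # stands in for r[vin[2]]: 4 x f32 accumulators

wmma, lines = [], []
for regs in (a_halves, b_halves):
  for i in range(0, len(regs), 2):           # pack each pair of f16 regs into one b32 reg
    wmma.append(f"%wmma{len(wmma)}")
    lines.append(f'mov.b32 {wmma[-1]}, {{{", ".join(regs[i:i+2])}}};')
lines.append('mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 '
             f'{{{", ".join(acc)}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(acc)}}};')
print("\n".join(lines))

Note that the accumulator registers appear as both the destination and the last operand, which is what r[u] = r[vin[2]] arranges in the diff above.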
