Skip to content

Commit

Permalink
Unroll warp-specialized loops (#3547)
Browse files Browse the repository at this point in the history
When used with #3545, this
contribute a speedup of 5% of cuBLAS!

Perf together with #3545 on H100:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name

 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     33.8           136319          1  136319.0  136319.0    136319    136319          0.0  <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…
     22.7            91487          1   91487.0   91487.0     91487     91487          0.0  nvjet_hsh_128x256_64x4_2x1_v_bz_coopA_NTN
```

nvFuser/cuBLAS: 67%

Note that the above test is run with smem epilogue disabled. I will run
a test with everything combined later. Also note that this number is on
H100, which is different from the H200 in
#3279.
  • Loading branch information
zasdfgbnm authored Dec 10, 2024
1 parent 4a897a4 commit 1978cf4
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3019,14 +3019,15 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
if (loop->isUnrolled()) {
indent() << "#pragma unroll\n";
} else if (
loop->circularBufferLoopStage() == CircularBufferLoopStage::Main) {
indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth()
<< "\n";
} else if (
loop->circularBufferLoopStage() == CircularBufferLoopStage::Epilog) {
indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth() - 1
<< "\n";
} else if (
loop->circularBufferLoopStage() !=
CircularBufferLoopStage::NotApplicable) {
indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth()
<< "\n";
} else {
indent() << "#pragma unroll 1\n";
}
Expand Down

0 comments on commit 1978cf4

Please sign in to comment.