Very naive and stupid CGA support (#3557)
This PR adds super-naive CGA support. It is by no means how we should
design CGA support, nor even an incremental step toward that design, but
it is simple enough and it gives us one more parameter to tune.

Perf on H100:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     33.4           134047          1  134047.0  134047.0    134047    134047          0.0  <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…
     22.9            92031          1   92031.0   92031.0     92031     92031          0.0  nvjet_hsh_128x256_64x4_2x1_v_bz_coopA_NTN
```

nvFuser/cuBLAS performance ratio: 92031 / 134047 ≈ 68.7%
zasdfgbnm authored Dec 10, 2024
1 parent 456b319 commit 214d598
Showing 2 changed files with 17 additions and 1 deletion.
11 changes: 10 additions & 1 deletion csrc/codegen.cpp
```diff
@@ -273,7 +273,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

   // Generates the kernel function declaration
   void genDeclaration(const std::string& kernel_name) {
-    code_ << "__global__ void " << kernel_name << "(";
+    code_ << "__global__ void ";
+    if (kernel_->hasManaged("cluster_dims")) {
+      auto cluster_dims =
+          kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
+              "cluster_dims");
+      code_ << "__cluster_dims__(" << std::get<0>(cluster_dims) << ", "
+            << std::get<1>(cluster_dims) << ", " << std::get<2>(cluster_dims)
+            << ") ";
+    }
+    code_ << kernel_name << "(";

     std::unordered_set<Val*> unique_args;
```
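In isolation, the codegen change amounts to: optionally emit a `__cluster_dims__(x, y, z)` attribute between `__global__ void` and the kernel name. A minimal stand-alone sketch of that string-building logic (a hypothetical free function, not nvFuser's actual `CudaKernelGenerator`; `std::optional` stands in for the `hasManaged`/`getManaged` pair):

```cpp
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>

// Sketch of the declaration emitter: when cluster dims are set, prepend the
// __cluster_dims__ attribute to the kernel declaration, otherwise emit the
// plain __global__ signature as before.
std::string genDeclaration(
    const std::string& kernel_name,
    std::optional<std::tuple<int64_t, int64_t, int64_t>> cluster_dims) {
  std::stringstream code;
  code << "__global__ void ";
  if (cluster_dims.has_value()) {
    code << "__cluster_dims__(" << std::get<0>(*cluster_dims) << ", "
         << std::get<1>(*cluster_dims) << ", " << std::get<2>(*cluster_dims)
         << ") ";
  }
  code << kernel_name << "(";
  return code.str();
}
```

With dims `{2, 1, 1}` this produces `__global__ void __cluster_dims__(2, 1, 1) kernel_name(`, which is the compile-time CUDA attribute form of specifying a thread block cluster shape.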
7 changes: 7 additions & 0 deletions tests/cpp/test_matmul.cpp
```diff
@@ -3664,6 +3664,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   const int64_t cta_m = 2 * getM(macro);
   const int64_t cta_n = 1 * getN(macro);

+  constexpr std::tuple<int64_t, int64_t, int64_t> cluster_dims{2, 1, 1};
+
   auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
   auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
   fusion.addInput(tv0);
@@ -3679,6 +3681,11 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   auto tv3 = castOp(DataType::Half, tv2);
   fusion.addOutput(tv3);

+  if constexpr (
+      cluster_dims != std::tuple<int64_t, int64_t, int64_t>{1, 1, 1}) {
+    fusion.manage("cluster_dims", cluster_dims);
+  }
+
   auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
   NVF_CHECK(
       1 == mma_ops.size(),
```
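The test side relies on the fusion's managed-data mechanism: the scheduling code stashes a typed value under a string key, and codegen retrieves it later by the same key. A rough sketch of that pattern using `std::any` (a hypothetical `ManagedStore` illustrating the interface shape, not nvFuser's real `Fusion` class):

```cpp
#include <any>
#include <cstdint>
#include <string>
#include <tuple>
#include <unordered_map>

// Sketch of a manage/getManaged store: values of arbitrary type are kept by
// name; readers ask whether a key exists, then retrieve it with its type.
struct ManagedStore {
  std::unordered_map<std::string, std::any> data;

  template <typename T>
  void manage(const std::string& key, T value) {
    data[key] = std::move(value);
  }

  bool hasManaged(const std::string& key) const {
    return data.count(key) > 0;
  }

  template <typename T>
  T getManaged(const std::string& key) const {
    return std::any_cast<T>(data.at(key));  // throws on a type mismatch
  }
};
```

This mirrors the flow in the diff: the test calls `manage("cluster_dims", ...)` only for a non-trivial shape, and `genDeclaration` checks `hasManaged("cluster_dims")` before reading the tuple back, so kernels without clusters are generated exactly as before.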
