Add support for stmatrix in the unit test HopperMatmulTest/HSH_NT_128BSwizzle #3411

Merged: 17 commits, Nov 23, 2024

@protonu protonu commented Nov 15, 2024

This demonstrates the use of stmatrix in a multi-tile Hopper matmul.

@protonu protonu requested a review from zasdfgbnm November 15, 2024 01:49
@protonu protonu force-pushed the pbasu_wip_stmatrix_HSH_NT_128BSwizzle branch from aad4d2d to a7eb8fd Compare November 15, 2024 20:30
@protonu protonu changed the title prototype Add support for stmatrix in the unit test HopperMatmulTest/HSH_NT_128BSwizzle Nov 15, 2024
@protonu protonu marked this pull request as ready for review November 15, 2024 20:32
tests/cpp/test_matmul.cpp (outdated)
fusion.addOutput(tv4);

// We'll use stmatrix.x4 to store from reg to shared memory
fusion.manage("st_matrix_m_tile", 16);
Just to be clear, does this mean we must use a macro such that getN(macro) is a multiple of 16?

@protonu protonu Nov 15, 2024

Yes, but I'd phrase it as: use a tile that divides the macro.
For the m-dimension this shouldn't be a problem, since it's always 64 on Hopper.
For the n-dimension, if it's not a multiple of 16, it has to be a multiple of 8, and we'll have to use stmatrix.x2, where st_matrix_n_tile will be 8.
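
To make the tile rule above concrete, here is a minimal sketch. `StMatrixChoice` and `chooseStMatrix` are hypothetical names for illustration, not nvFuser APIs. The sketch assumes the m tile stays 16 for both variants and that stmatrix.x4 uses an n tile of 16.

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper, not part of nvFuser: picks an stmatrix variant for an
// MMA macro of shape macro_m x macro_n, following the rule discussed above.
struct StMatrixChoice {
  int vec;     // 4 for stmatrix.x4, 2 for stmatrix.x2
  int m_tile;  // rows covered by one store (assumed 16 for both variants)
  int n_tile;  // columns covered by one store
};

inline std::optional<StMatrixChoice> chooseStMatrix(int macro_m, int macro_n) {
  const int kMTile = 16;
  // The tile must divide the macro in the m-dimension (64 on Hopper).
  if (macro_m % kMTile != 0) {
    return std::nullopt;
  }
  // Prefer stmatrix.x4 when n is a multiple of 16.
  if (macro_n % 16 == 0) {
    return StMatrixChoice{4, kMTile, 16};
  }
  // Otherwise fall back to stmatrix.x2, which needs n to be a multiple of 8.
  if (macro_n % 8 == 0) {
    return StMatrixChoice{2, kMTile, 8};
  }
  // No stmatrix tile divides this macro.
  return std::nullopt;
}
```

For example, a 64x256 macro would select the x4 variant, while a 64x24 macro would fall back to x2.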

@zasdfgbnm zasdfgbnm left a comment

Could you also post some perf data here with smem epilogue on? The command I use is posted here: #3279

@protonu protonu force-pushed the pbasu_wip_stmatrix_HSH_NT_128BSwizzle branch from 9fdca30 to c088f18 Compare November 19, 2024 18:05
protonu commented Nov 19, 2024

Observations:

  1. Running getBankConflictInfo shows zero bank conflicts in both cases (mma with and without stmatrix). The code I used was:
  KernelExecutor ke;
  auto launch_constraints = LaunchParams();
  ke.compile(
      &fusion,
      {inputs.first, inputs.second},
      launch_constraints,
      matmul_cparams);

  auto bank_conflict_info =
      getBankConflictInfo(ke.kernel(), launch_constraints);

  if (bank_conflict_info.empty()) {
    debug() << "===== No bank confliction =====" << std::endl;
  } else {
    debug() << "======= Bank confliction =======" << std::endl;
    for (auto info : bank_conflict_info) {
      debug() << "Expr: " << info.first->toString() << std::endl;
      auto conflict = info.second;
      if (conflict.first > 1) {
        debug() << "input conflict: " << conflict.first << " way, ";
      }
      if (conflict.second > 1) {
        debug() << "output conflict: " << conflict.second << " way";
      }
      debug() << std::endl;
    }
    debug() << "================================" << std::endl;
  }
  fusion.printKernel();
  2. Performance degrades slightly:

Without stmatrix:

| Time (%) | Total Time (ns) | Instances | Avg (ns) | Med (ns) | Min (ns) | Max (ns) | StdDev (ns) | Name |
|---|---|---|---|---|---|---|---|---|
| 36.8 | 156830 | 1 | 156830.0 | 156830.0 | 156830 | 156830 | 0.0 | `<unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…` |
| 23.1 | 98654 | 1 | 98654.0 | 98654.0 | 98654 | 98654 | 0.0 | `nvjet_hsh_256x128_64x4_1x2_h_bz_coopA_NTT` |

Without stmatrix, the nvFuser kernel runs at about 63% of the speed of the nvjet kernel (98654 ns vs. 156830 ns).

With stmatrix, this falls to about 59-60%.
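
For reference, a small sketch of how the percentages can be reproduced from the profiler output, assuming they are the nvjet kernel time divided by the nvFuser kernel time. `relativeSpeed` is a hypothetical helper, not an nvFuser function.

```cpp
#include <cassert>

// Hypothetical helper: speed of a kernel relative to a reference kernel,
// computed as reference time over kernel time. A value of 1.0 means the
// kernel matches the reference; lower values mean the kernel is slower.
inline double relativeSpeed(double reference_ns, double kernel_ns) {
  assert(kernel_ns > 0.0);
  return reference_ns / kernel_ns;
}
```

Calling `relativeSpeed(98654.0, 156830.0)` with the timings from the table gives roughly 0.63, which matches the figure quoted for the run without stmatrix.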

TODO: Use Nsight to inspect bank conflict data.

Base automatically changed from pbasu_new_wip_mma_stmatrix to main November 22, 2024 20:49
protonu commented Nov 23, 2024

!build

@protonu protonu merged commit caa7f07 into main Nov 23, 2024
17 checks passed
@protonu protonu deleted the pbasu_wip_stmatrix_HSH_NT_128BSwizzle branch November 23, 2024 23:16