Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Hopper MMA by warp-tile before instruction tile #3642

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Dec 24, 2024

Currently we ignore the warp tile parameter when scheduling Hopper matmuls (see #3636). This PR introduces a test with different CTA, warp, and instruction tiles and modifies the Hopper scheduler to split by warp tile in addition to instruction tile. Note that the instruction tile split results in two serial loop domain so we wind up executing multiple mma instructions in each main loop. In the included example, warp_tile is 64, 128, 16 and the macro is Hopper_64_8_16. In this case, there are 128/8 = 16 instruction tiles per warp tile so the generated main loop looks like this:

  #pragma unroll 3
  for(nvfuser_index_t i33 = 0; i33 < i4; ++i33) {
    nvfuser_index_t i34;
    i34 = 48 + (16 * i33);
    nvfuser_index_t i35;
    i35 = (3 + i33) % 4;
    unsigned i36;
    i36 = i7 + (8192 * i35);
    unsigned i37;
    i37 = i10 + (4096 * i35);
    nvfuser_index_t i38;
    i38 = i33 % 4;
    unsigned i39;
    i39 = i13 + (4096 * i38);
    uint64_t i40;
    i40 = 4611686293305294848ULL | ((262143ULL & (uint64_t)(i39)) >> 4ULL);
    unsigned i41;
    i41 = i15 + (8192 * i38);
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i31 = 0; i31 < 4; ++i31) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i31)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i36 + (2048 * i31)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 4096U);
      #pragma unroll
      for(nvfuser_index_t i32 = 0; i32 < 2; ++i32) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i32)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i37 + (2048 * i32)));
      }
    }
    mbarrier::waitParity(toSmem((&T8[(i33 % 4)])), (uint32_t)(((i33 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) {
      unsigned i42;
      i42 = (i41 + (2048 * (i25 / 8))) + (16 * (i25 % 8));
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %6, 0;\n"
        "  wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 {%0, %1, %2, %3}, %4, %5, p0, %7, %8, %9, %10;\n"
        "}\n"
        :"+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[0]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[1]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[2]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[3])
        :"l"(i40),
         "l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i42)) >> 4ULL))),
         "n"((uint32_t)(true)),
         "n"(1),
         "n"(1),
         "n"(1),
         "n"(1)
      );
    }
    __syncthreads();
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  }

Fixes #3636

@jacobhinkle
Copy link
Collaborator Author

!test

@jacobhinkle
Copy link
Collaborator Author

The bank conflict came from stmatrix scheduling which needs to be updated. I will do that in a separate PR. For now, I've disabled smem epilogue in the included test.

@jacobhinkle jacobhinkle marked this pull request as ready for review December 31, 2024 13:47
@jacobhinkle
Copy link
Collaborator Author

!test

@jacobhinkle
Copy link
Collaborator Author

jacobhinkle commented Dec 31, 2024

When I manually disable stmatrix but keep TMA store, I still hit a bank conflict and misaligned address in the smem read when doing the TMA store. The epilogue looks like this:

  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
    nvfuser_index_t i51;
    i51 = 4 * i50;
    #pragma unroll
    for(nvfuser_index_t i52 = 0; i52 < 2; ++i52) {
      nvfuser_index_t i53;
      i53 = i51 + (2 * i52);
      Array<__half, 2, 2> T6;
      #pragma unroll
      for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
        T6[i54]
           = __float2half(T2[(i53 + i54)]);
      }
      loadGeneric<__half, 2>( &T7[(i17 + (128 * i52))],  &T6[0]);
    }
    __syncthreads();
    asm volatile("fence.proxy.async;\n");
    if (b24) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i20 + (8 * i50)), i21}) }), i18);
    }
    __syncthreads();
    asm volatile("cp.async.bulk.commit_group;\n");
    asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
  }
  asm volatile("cp.async.bulk.commit_group;\n");
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");

The misaligned read happens with i20 = 1152, i50 = 0, i21 = 320, i18 = 3088. Note that we have

  threadIdx.y = 3;
  i11 = ((nvfuser_index_t)threadIdx.y) / 2; // =1
  i12 = 2048 * i11; // =2048
  i14 = ((nvfuser_index_t)threadIdx.y) % 2; // =1
  i18 = (toSmem(T7) + i12) + (16 * i14); // =toSmem(T7) + 2064

CUDA Exception: Warp Misaligned Address

@jacobhinkle
Copy link
Collaborator Author

mma result before this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 8 -> iS59{32}, iS60{8}
  Merge: iS57{2} and iS59{32} -> ithreadIdx.y61{64}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16})

And after this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 128 -> iS61{2}, iS62{128}
  Merge: iS57{2} and iS61{2} -> ithreadIdx.y65{4}
  Split: iS58{64} by factor 64 -> iS59{1}, iS60{64}
  Split: iS62{128} by factor 8 -> iS63{16}, iS64{8}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16})

@jacobhinkle
Copy link
Collaborator Author

Note that I can enable smem epilogue and the test passes if I use Hopper_64_64_16 and I disable stmatrix.

Comment on lines +47 to +49
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdspring1 is this enough or is #3616 still needed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is all that is required for scheduler changes.

// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
if (is_mma_result) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: since there is no code in common between these branches, we should split this into two separate functions.

Copy link
Collaborator

@rdspring1 rdspring1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]

Comment on lines +47 to +49
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is all that is required for scheduler changes.

@rdspring1
Copy link
Collaborator

rdspring1 commented Jan 2, 2025

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

@jacobhinkle
Copy link
Collaborator Author

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

I just checked this by modifying the test to expect 5 warpgroups instead of four and by adding a WarpSpecialized(ParallelType::TIDy) argument to these circularBuffer calls:

acw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage,
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
}
for (TensorView* bcw_smem : bcw_smems_) {
bcw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage,
/*prefetch_distance=*/
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);

The result for me is a passing test but perf drops.

@jacobhinkle
Copy link
Collaborator Author

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]

I might be confused here. The thing is that the K dimension is treated differently from the M and N dimensions in these tile definitions. The instruction tile's K dimension is clear, and the warp tile's K dimension (I think) signifies how much data we should load at a time then we can loop to compute instructions over all the loaded data. The CTA tile's M and N dimensions specify the tiling of the output, but what does the cta_tile.k signify? This is why I was thinking we'd keep this restriction.

Note that this restriction cta_tile.k == warp_tile.k is enforced on Ampere as part of scheduleWarpTileWithReduction.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Split by warp tile in Hopper matmul scheduler
2 participants