
Add warp specialization as a circular buffering type #3511

Merged (42 commits) on Dec 5, 2024


@zasdfgbnm (Collaborator) commented Dec 2, 2024

This PR adds warp specialization as a new type of circular buffering.

Today, we already support pipelined circular buffering, and we can optionally choose whether to use block-sync or mbarrier to handle WAR hazards. If we choose mbarrier to handle WAR hazards, we generate a kernel like the one below:

# Mark buffer[i] as empty and ready to be loaded
for i in range(prefetch + 1):
  arrive war_mbarrier[i]
# Prologue: thanks to the arrives above, these waits return immediately and the loads go through without blocking
for i in range(prefetch):
  wait war_mbarrier[i]
  arrive-expect-tx raw_mbarrier[i]
  load data[i] into buffer[i]
# Main loop:
for i in range(data.size - prefetch):
  if elect-sync:
    wait war_mbarrier[(i + prefetch) % stage]
    arrive-expect-tx raw_mbarrier[(i + prefetch) % stage]
    load data[i + prefetch] to buffer[(i + prefetch) % stage]
  wait raw_mbarrier[i % stage]
  mma on buffer[i % stage] for data[i]
  wait until there are at most stage - prefetch - 1 pending mma
  arrive war_mbarrier[(i + prefetch + 1) % stage]
# Epilogue
for i in range(data.size - prefetch, data.size):
  wait raw_mbarrier[i % stage]
  mma on buffer[i % stage] for data[i]
wait until there are at most 0 pending mma
write result back to gmem
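The slot index in `arrive war_mbarrier[(i + prefetch + 1) % stage]` follows from the wait right before it: once at most `stage - prefetch - 1` MMAs are pending, the MMA on `data[i - (stage - prefetch - 1)]` has finished, so its buffer can be refilled. The Python sketch below is only a bookkeeping check, not nvFuser code, and the helper names are hypothetical. It confirms that the released slot always equals the last finished slot and is exactly the slot the next iteration loads into. For early iterations the index `i - (stage - prefetch - 1)` is negative, which just means that slot has not been filled yet; Python's `%` maps it to the same slot.

```python
def last_finished_slot(i: int, stage: int, prefetch: int) -> int:
    # After "wait until there are at most stage - prefetch - 1 pending mma",
    # the MMA on data[i - (stage - prefetch - 1)] is guaranteed complete.
    return (i - (stage - prefetch - 1)) % stage


def released_slot(i: int, stage: int, prefetch: int) -> int:
    # Slot passed to "arrive war_mbarrier[...]" at the end of iteration i.
    return (i + prefetch + 1) % stage


def next_load_slot(i: int, stage: int, prefetch: int) -> int:
    # Slot that iteration i + 1 waits on before loading data[i + 1 + prefetch].
    return ((i + 1) + prefetch) % stage


for stage in range(1, 8):
    for prefetch in range(stage):  # the scheme needs stage >= prefetch + 1
        for i in range(50):
            s = released_slot(i, stage, prefetch)
            assert s == last_finished_slot(i, stage, prefetch)
            assert s == next_load_slot(i, stage, prefetch)
print("slot bookkeeping is consistent")
```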

The kernel above has the following problems:

  1. The MMA loop is not clean. One thread does the extra work of loading, while the other threads in the warp group just wait for that thread to finish. (Note that mma is a warp-group collective, so all threads in the warp group must reach that instruction before it can start.) Ideally, the loop should contain only the mma and nothing else; extra instructions increase latency.
  2. There is a false dependency between loading data[i + prefetch] and computing on data[i]. These two operations do not touch the same data, so in theory they should not depend on each other: whichever has its mbarrier cleared first should go first. However, simply because code executes from top to bottom, the mma has to wait until the load is issued. This further increases latency.

Given the above problems, it is natural to ask: why not use different warps for load and compute? The load code and the compute code in the main loop are completely independent, and both the RAW and WAR hazards are handled by mbarriers, which live in smem and are accessible across the entire CTA. So all the preconditions for warp specialization are already in place, and we just need to put different IR nodes in different places.

This PR adds warp specialization. The generated code is similar to the pipelined code that uses mbarrier for WAR, but simpler. It looks like the following (assuming warp specialization on TIDy):

if threadIdx.y == blockDim.y - 1:
  # If we use warp specialization on TIDy, then the blockDim.y of the
  # kernel will be (whatever_value_inferred_from_schedule + 1), and the
  # last threadIdx.y will be used as load warp
  for i in range(data.size):
    wait war_mbarrier[i % stage]
    load data[i] to buffer[i % stage]
else:
  # Every threadIdx.y other than the last will be used for compute
  for i in range(prefetch + 1):
    arrive war_mbarrier[i % stage]
  for i in range(data.size):
    wait raw_mbarrier[i % stage]
    compute buffer[i % stage]
    wait until there are at most stage - prefetch - 1 pending mma
    arrive war_mbarrier[(i + prefetch + 1) % stage]
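To see that the two branches coordinate only through the mbarriers, here is a toy Python model of the warp-specialized loop. It is a sketch under several assumptions, not nvFuser codegen: `threading.Semaphore` stands in for an mbarrier (it counts arrivals rather than tracking phases), one thread plays the load warp and another plays the compute warps, and compute is synchronous, so "wait until there are at most stage - prefetch - 1 pending mma" is trivially satisfied. The function name `warp_specialized_sum` and the sum-of-squares "compute" are hypothetical.

```python
import threading


def warp_specialized_sum(data, stage=4, prefetch=2):
    """Toy model of the warp-specialized loop: sums x * x over data."""
    assert stage >= prefetch + 1, "need at least prefetch + 1 buffer slots"
    buffer = [None] * stage
    # Semaphores stand in for mbarriers (they count arrivals, not phases).
    war = [threading.Semaphore(0) for _ in range(stage)]  # "slot is empty"
    raw = [threading.Semaphore(0) for _ in range(stage)]  # "slot is full"
    result = []

    def load_warp():
        for i in range(len(data)):
            war[i % stage].acquire()     # wait war_mbarrier[i % stage]
            buffer[i % stage] = data[i]  # load data[i] to buffer[i % stage]
            # In the real kernel the TMA transaction completes raw_mbarrier;
            # here the loader signals it explicitly.
            raw[i % stage].release()

    def compute_warps():
        for i in range(prefetch + 1):
            war[i % stage].release()     # arrive war_mbarrier[i % stage]
        acc = 0
        for i in range(len(data)):
            raw[i % stage].acquire()     # wait raw_mbarrier[i % stage]
            acc += buffer[i % stage] ** 2  # compute buffer[i % stage]
            # Compute is synchronous here, so the released slot is already free.
            war[(i + prefetch + 1) % stage].release()
        result.append(acc)

    workers = [threading.Thread(target=load_warp),
               threading.Thread(target=compute_warps)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return result[0]


print(warp_specialized_sum(list(range(10))))  # 285
```

Note that the loader never looks at `raw` before loading and the compute side never issues a load: the only shared state is the buffer and the two sets of barriers, which is what makes the split into separate warps possible.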

This new way of doing circular buffering is intended to be computation-agnostic: it should work on whatever kernel we are scheduling, not just matmuls. But note that today there are some strong limitations that make it less applicable:

  1. The computation cannot have a hardcoded blockDim in it, so block reduction will not work. I believe this will be easy to fix, but it is beyond the scope of this PR.
  2. Because the warp-specialized parallel type is no longer exact, thread predicates will be generated for it. Predicate elimination is not yet smart enough to know that code in the compute warp is already predicated and does not need to be predicated again. This limitation also means the computation cannot use tensor core operations (MmaOp), so this PR does not yet work with matmul.

Besides the above limitations, I believe this new circular buffer type is fairly generic, and in the future we should be able to try it with TMA during perf tuning.

@zasdfgbnm zasdfgbnm marked this pull request as ready for review December 4, 2024 00:46
@zasdfgbnm (Collaborator, Author)

!test

@zasdfgbnm zasdfgbnm requested review from naoyam and rdspring1 December 4, 2024 00:46
csrc/device_lower/pass/predicate.cpp (outdated, resolved)
tests/cpp/test_circular_buffering.cpp (outdated, resolved)
//! parallelized IterDomains have extent 16, and there is no TIDz
//! parallelization, then we will have:
//! blockDim = (x=32, y=17, z=1)
//! And this function will return (32 * 16) because the extra one for TIDy is
Collaborator

I think that it is less intuitive to the developer that an extra TIDy is added to the CTA than taking one TIDy for circular buffer loads.

For example, if you want to use all 1024 threads, you have to specify (x=32, y=31, z=1) rather than (x=32, y=32, z=1).
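To make the "+1" concrete: the schedule describes only the compute threads, and one extra slice along the warp-specialized dimension is added at launch for the load warp. The Python sketch below uses hypothetical helpers (`launched_block_dim`, `compute_threads`) and assumes the CUDA limit of 1024 threads per block; it reproduces the numbers from the code comment and the example above.

```python
from math import prod

MAX_THREADS_PER_BLOCK = 1024  # CUDA per-block thread limit


def launched_block_dim(scheduled: dict, ws_dim: str) -> dict:
    """blockDim actually launched: one extra slice along ws_dim for loads."""
    launched = dict(scheduled)
    launched[ws_dim] += 1
    return launched


def compute_threads(block_dim: dict) -> int:
    """Number of threads described by a blockDim."""
    return prod(block_dim.values())


sched = {"x": 32, "y": 16, "z": 1}
print(launched_block_dim(sched, "y"), compute_threads(sched))
# {'x': 32, 'y': 17, 'z': 1} 512

# Using all 1024 threads: the schedule must ask for y=31, not y=32.
full = launched_block_dim({"x": 32, "y": 31, "z": 1}, "y")
assert compute_threads(full) == MAX_THREADS_PER_BLOCK
too_big = launched_block_dim({"x": 32, "y": 32, "z": 1}, "y")
assert compute_threads(too_big) > MAX_THREADS_PER_BLOCK
```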

Collaborator

If you want to toggle warp specialization on and off, you may have to change the schedule when using all 1024 threads.

Collaborator, Author

> I think that it is less intuitive to the developer that an extra TIDy is added to the CTA

I agree with this part. I dislike that the extra parallel dimension is not attached to an IterDomain and instead shows up as a mysterious +1. I believe this is just a temporary solution; I could not come up with a better one that can be implemented in a short amount of time. In the future, we will likely be able to slice on parallel types and select one slice for loads and another for compute.

> taking one TIDy for circular buffer loads.

How would that work? If you are already using all 1024 threads for computation, you cannot take any of them out for loading.

Collaborator

I understand this design choice of adding an extra warp. Since we're cloning loops, we have to schedule the TV from the perspective of the pipelined approach. To get warp specialization, we have to add an extra warp for loads.

@rdspring1 (Collaborator) left a comment

LGTM.

My understanding of the remaining TODOs for warp-specialization.

  • Fix predicates to be warp-specialization aware. [Required for MMA]
  • Fix block-sync and block-reduction to be aware of active threads in if-then-else block. [For reduction and persistent kernels]
  • Make elect-sync pick a thread in the active warps in the if-then-else block. [Necessary for using elect-sync with TMA-store]
  • Decouple warp-specialization from TMA loads. [Nice-to-have]

@zasdfgbnm (Collaborator, Author)

> LGTM.
>
> My understanding of the remaining TODOs for warp-specialization.
>
>   • Fix predicates to be warp-specialization aware.
>   • Fix block-sync and block-reduction to be aware of active threads in if-then-else block.
>   • Make elect-sync pick a thread in the active warps in the if-then-else block.

For the first two, yes.

The first one is the top priority, since it unblocks tensor core usage.

The second one should be a quick fix that I will likely take on after this PR is merged. I don't think we urgently need it now, but I prefer to grind out sharp edges whenever it is quick to do so.

For the third one: with this PR, elect-sync already does the right thing when it is in the load warp or outside a warp-specialized region, even if we are warp-specializing on TIDx. But it will cause a hang if used in the compute warp. I will likely not fix this, because I don't think we will use elect-sync in the compute warp.

@zasdfgbnm (Collaborator, Author)

TMA store is outside the warp-specialized region. I expect it would just work with elect-sync.

@zasdfgbnm (Collaborator, Author)

!test

@rdspring1 (Collaborator)

I tried half-heartedly to apply elect-sync to tma store and got incorrect results from the kernel.

@zasdfgbnm (Collaborator, Author)

> I tried half-heartedly to apply elect-sync to tma store and got incorrect results from the kernel.

Thanks for mentioning it. I'm surprised; we need to look into it.

@rdspring1 (Collaborator)

Do we know if register sharing from tma warp groups to compute warp groups will be necessary?

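// CTA (BDX=128, BDY=3, BDZ=1): threadIdx.y 0 and 1 compute, the last threadIdx.y loads
if (threadIdx.y == blockDim.y - 1) {
    // Load warp group: give registers back
    asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
    tma-load(...);
} else {
    // Compute warp groups: take the freed registers
    asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
    mma-compute(...);
}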

@zasdfgbnm (Collaborator, Author)

> Do we know if register sharing from tma warp groups to compute warp groups will be necessary?

I naively only checked

// CTA (BDX=128, BDY=3, BDZ=1): the last threadIdx.y loads
if (threadIdx.y == blockDim.y - 1) {
    asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
    tma-load(...);
} else {
    // no register increase
    mma-compute(...);
}

and it didn't help much.

@rdspring1 (Collaborator)

> Do we know if register sharing from tma warp groups to compute warp groups will be necessary?

I think it is mostly an occupancy optimization, so only kernels with high register pressure see benefit.

IIUC, you have to use setmaxnreg.dec and setmaxnreg.inc together to move registers from load to compute warp groups on the SM.
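A rough register budget shows why the `.dec`/`.inc` pair matters. The sketch below assumes an H100-class SM with 65,536 32-bit registers, one CTA of three 128-thread warp groups resident per SM, per-thread register counts allocated in multiples of 8, and the PTX constraint that the `setmaxnreg` operand is between 24 and 256 and a multiple of 8; the helper names are hypothetical. The counts 56 and 224 come from the snippet above.

```python
REGS_PER_SM = 65536      # assumption: H100-class SM, 64K 32-bit registers
WARP_GROUP_THREADS = 128


def valid_setmaxnreg(count: int) -> bool:
    # PTX constraint on the setmaxnreg immediate: 24..256 and a multiple of 8.
    return 24 <= count <= 256 and count % 8 == 0


def even_split(num_warp_groups: int) -> int:
    # Largest per-thread count (multiple of 8) if every thread gets the same amount.
    per_thread = REGS_PER_SM // (num_warp_groups * WARP_GROUP_THREADS)
    return per_thread - per_thread % 8


def total_registers(per_group_counts) -> int:
    # Registers held by the CTA when each warp group has the given per-thread count.
    return sum(WARP_GROUP_THREADS * c for c in per_group_counts)


# 1 load warp group at 56 registers, 2 compute warp groups at 224 registers.
split = [56, 224, 224]
assert all(valid_setmaxnreg(c) for c in split)
print(even_split(3), total_registers(split), total_registers(split) <= REGS_PER_SM)
# 168 64512 True
```

Under these assumptions, an even split caps every thread at 168 registers, while moving registers out of the load warp group lets each compute thread use 224 and still fit in the SM's register file. That only matters if the compute code actually needs more than 168 registers, which matches the "high register pressure" point above.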

@naoyam (Collaborator) left a comment

LGTM

@zasdfgbnm zasdfgbnm merged commit 64bc560 into main Dec 5, 2024
48 checks passed
@zasdfgbnm zasdfgbnm deleted the warp-specialization-submit branch December 5, 2024 09:43
@zasdfgbnm (Collaborator, Author)

> Do we know if register sharing from tma warp groups to compute warp groups will be necessary?

Register stealing seems to help a lot. I still need to verify my experiment, but I just got:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----
     30.2           115679          1  115679.0  115679.0    115679    115679          0.0  <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…
     23.9            91647          1   91647.0   91647.0     91647     91647          0.0  nvjet_hsh_128x256_64x4_2x1_v_bz_coopA_NTN

which is 79% of cuBLAS!
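The 79% figure is the cuBLAS kernel time divided by the nvFuser kernel time from the profiler table above. Since lower time is better, this ratio is nvFuser's throughput relative to cuBLAS:

```python
nvfuser_ns = 115679  # nvfuser_none_f0_c0_r0_g0
cublas_ns = 91647    # nvjet_hsh_128x256_64x4_2x1_v_bz_coopA_NTN

relative_perf = cublas_ns / nvfuser_ns
print(f"{relative_perf:.1%}")  # 79.2%
```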

@rdspring1 (Collaborator)

Do we know why register stealing helps? e.g., greater occupancy and active warps.

@zasdfgbnm (Collaborator, Author)

> Do we know why register stealing helps? e.g., greater occupancy and active warps.

I don't know. My first step is to figure out the cause of the wrong-result error (likely #3561), and then, with that problem fixed, see what speedup we get.

@zasdfgbnm (Collaborator, Author)

Looks like after fixing #3561, the perf benefit of register stealing disappears...

@naoyam (Collaborator) commented Dec 10, 2024

Oh, no, don't fix it 😆
