Limit unrolling of all circular buffered loops to depth equal to prefetch #3627

Merged
jacobhinkle merged 6 commits into main from limit_static_main_loop_unrolling on Dec 24, 2024

Conversation

@jacobhinkle (Collaborator) commented Dec 20, 2024

Currently, for dynamic shapes with circular-buffered loops, we unroll the following loops to different depths:

  • epilogue: nominally stages - 1, but often emitted as a bare #pragma unroll, probably due to the use of ensureStaticIndexing in the indexing pass, since this loop always has constant extent.
  • main loop: unrolled as #pragma unroll stages.
  • prologue: fully unrolled with a bare #pragma unroll, like the epilogue.

This PR instead unrolls each of these loops explicitly with #pragma unroll <prefetch>, where prefetch is the circular-buffering prefetch distance, which is usually set to stages - 1.
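For illustration, here is a minimal sketch of the loop structure this produces, assuming stages = 4 and prefetch = 3 (the loop bounds, stage arithmetic, and variable names are illustrative only, not nvFuser's actual codegen):

  // Prologue: prefill the first `prefetch` stages with async copies.
  #pragma unroll 3
  for (nvfuser_index_t i = 0; i < 3; ++i) {
    // cp.async load into stage i
  }
  // Main loop: steady state; load one stage ahead while computing the current one.
  #pragma unroll 3
  for (nvfuser_index_t i = 0; i < N; ++i) {
    // cp.async load into stage (i + 3) % 4, then compute on stage i % 4
  }
  // Epilogue: drain the remaining in-flight stages (no further loads).
  #pragma unroll 3
  for (nvfuser_index_t i = 0; i < 3; ++i) {
    // compute on the remaining stages
  }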

Motivation

When using static shapes, as in the Fusions we receive from Thunder, I noticed that our matmul main loops are being fully unrolled (at least, full unrolling is requested; the compiler likely does not actually unroll fully). For example, I have seen this:

  #pragma unroll
  for(nvfuser_index_t i68 = 0; i68 < 160; ++i68)

This particular kernel took 35 seconds to compile. After this change, we will instead do the following:

  #pragma unroll 3
  for(nvfuser_index_t i68 = 0; i68 < 160; ++i68)

and the compile time is under 400 ms with no change to kernel runtime.
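For context, a minimal sketch of how a depth-limited unroll pragma could be emitted (this is a hypothetical helper for illustration, not nvFuser's actual codegen; note that a bare #pragma unroll requests full unrolling, while #pragma unroll 1 suppresses unrolling):

  #include <cstdint>
  #include <optional>
  #include <ostream>

  // Hypothetical emitter: a finite depth prints "#pragma unroll <depth>",
  // while no depth prints a bare "#pragma unroll" (a full-unroll request).
  void emitUnrollPragma(std::ostream& os, std::optional<int64_t> depth) {
    if (depth.has_value()) {
      os << "#pragma unroll " << *depth << "\n";  // e.g. "#pragma unroll 3"
    } else {
      os << "#pragma unroll\n";
    }
  }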

@jacobhinkle (Collaborator, Author):

!test

@rdspring1 (Collaborator):

I think the unroll factor should be the prefetch stage for the prologue and epilogue loops, and maybe for the main loop too.

  const auto& opt =
      GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
          circular_buffer_loop_->iter_domain());
  int64_t prologue_unroll = opt.prefetch;

@jacobhinkle changed the title from "Limit unrolling of static circ buffered main loops" to "Limit unrolling of all circular buffered loops to depth equal to prefetch" on Dec 20, 2024
@jacobhinkle marked this pull request as ready for review on Dec 20, 2024 17:54
@jacobhinkle (Collaborator, Author):

!test --diff

@jacobhinkle marked this pull request as draft on Dec 20, 2024 18:08
@jacobhinkle (Collaborator, Author):

Marking as draft. Apparently getCircularBufferOptionsFor requires an active GpuLower, which we do not have at this point.

Commit: …uring lowering

If there's no lowering, it means we're looking up the circ buffer
options after lowering, so this is already being called on a ForLoop ID.
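For illustration, the guard that commit describes might look roughly like this sketch (the hasCurrent() check mirrors the comment above, but getConcreteLoopID and options_map_ are assumed names, not the actual nvFuser code):

  // Sketch: map the ID through the active lowering only while lowering is in
  // progress; afterwards the argument is already a concrete ForLoop iter
  // domain and needs no further mapping.
  const CircularBufferOptions& CircularBufferInfo::getCircularBufferOptionsFor(
      IterDomain* id) {
    if (GpuLower::hasCurrent()) {
      id = getConcreteLoopID(id);  // assumed mapping helper
    }
    return options_map_.at(id);  // assumed storage map
  }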
@jacobhinkle (Collaborator, Author):

!test --diff

@jacobhinkle (Collaborator, Author):

!test --diff

@jacobhinkle marked this pull request as ready for review on Dec 23, 2024 16:03
@jacobhinkle (Collaborator, Author):

!test --diff

@jacobhinkle (Collaborator, Author):

It looks like the generated code is often better in terms of register usage. Spills are often reduced, though sometimes a stack frame is introduced. For example:

--- 02ffc838
+++ c05b6c17
@@ -166,11 +166,11 @@
     );
   }
   asm volatile("cp.async.commit_group;\n");
   asm volatile("cp.async.wait_all;\n");
   __syncthreads();
-  #pragma unroll 2
+  #pragma unroll 1
   for(nvfuser_index_t i66 = 0; i66 < i0; ++i66) {
     nvfuser_index_t i67;
     i67 = 32 * i66;
     __half* ptr68;
     ptr68 = ptr21 + i67;
@@ -247,11 +247,11 @@
          "=r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[0]))[1]),
          "=r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[0]))[2]),
          "=r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[0]))[3])
         :"r"((uint32_t)((i74 + i81)))
       );
-      #pragma unroll
+      #pragma unroll 1
       for(nvfuser_index_t i61 = 0; i61 < 3; ++i61) {
         nvfuser_index_t i86;
         i86 = 8 * (i61 % 2);
         nvfuser_index_t i87;
         i87 = 32 * i61;
@@ -331,11 +331,11 @@
           "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
           :"=r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T8[0]))[0]),
            "=r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T8[0]))[1])
           :"r"((uint32_t)(i84))
         );
-        #pragma unroll
+        #pragma unroll 1
         for(nvfuser_index_t i63 = 0; i63 < 7; ++i63) {
           nvfuser_index_t i108;
           i108 = 4 * (i63 % 2);
           asm volatile(
             "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
@@ -569,11 +569,11 @@
         "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
         :"=r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T8[0]))[0]),
          "=r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T8[0]))[1])
         :"r"((uint32_t)(i84))
       );
-      #pragma unroll
+      #pragma unroll 1
       for(nvfuser_index_t i63 = 0; i63 < 7; ++i63) {
         nvfuser_index_t i150;
         i150 = 4 * (i63 % 2);
         asm volatile(
           "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"

@jacobhinkle (Collaborator, Author):

There are other test cases where this increases spilling:

--- 02ffc838
+++ c05b6c17
@@ -117,11 +117,11 @@
     #pragma unroll
     for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
       ((*reinterpret_cast<Array<float, 4, 1>*>(&T7[(i54 + (4 * i55))]))).set(0);
     }
   }
-  #pragma unroll
+  #pragma unroll 3
   for(nvfuser_index_t i56 = 0; i56 < 3; ++i56) {
     nvfuser_index_t i57;
     i57 = 32 * i56;
     __half* ptr58;
     ptr58 = ptr9 + i57;
@@ -167,11 +167,11 @@
     }
     asm volatile("cp.async.commit_group;\n");
   }
   asm volatile("cp.async.wait_group %0;\n"::"n"(2LL));
   __syncthreads();
-  #pragma unroll 4
+  #pragma unroll 3
   for(nvfuser_index_t i66 = 0; i66 < i0; ++i66) {
     nvfuser_index_t i67;
     i67 = 32 * i66;
     __half* ptr68;
     ptr68 = ptr22 + i67;


@jacobhinkle merged commit e214d37 into main on Dec 24, 2024
53 of 55 checks passed
@jacobhinkle deleted the limit_static_main_loop_unrolling branch on Dec 24, 2024 00:58
jacobhinkle added a commit that referenced this pull request on Jan 2, 2025: …3663)

In #3627 we reduced unrolling. This doesn't lead to a reduction in
syncing, but it does reduce the number of sync instructions we see in
the compiled SASS. That change broke a test which I'm updating in this
PR.

Fixes #3661