
Create ElectSync predicate type #2923

Merged: 7 commits into main from elect_sync (Sep 18, 2024)
Conversation

@rdspring1 (Collaborator) commented Sep 9, 2024

Summary

This PR creates the ElectSync predicate type, which selects a single thread to execute an if-then-else block.

Why?

The standard predicate threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 selects a specific thread from a CTA, which can create a peeling loop. The elect.sync PTX instruction instead selects an arbitrary thread from a given warp.

Lowering Details

  1. Create the ElectSync PredicateType.
  2. Add ElectSync as a unary op.
  3. Enforce that blockDim.x contains at least one full warp (32 threads), because the default membermask is 0xFFFFFFFF (a minimal sketch of this check follows below).
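A minimal sketch of the check in item 3, written as a hypothetical standalone helper rather than NvFuser's actual lowering or validation code:

#include <cstdint>
#include <stdexcept>

// The default membermask 0xFFFFFFFF assumes all 32 lanes of a warp
// participate in elect.sync, so the launch must provide at least one
// full warp along TIDx.
void validateElectSyncLaunch(int64_t bdimx) {
  constexpr int64_t kWarpSize = 32;
  if (bdimx < kWarpSize) {
    throw std::runtime_error(
        "ElectSync predicate requires blockDim.x >= 32 (one full warp)");
  }
}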

NvFuser's ElectSync Predicate

int warp_idx_zero = threadIdx.x < 32 && threadIdx.y == 0 && threadIdx.z == 0;
int lane_predicate = hopper::electSync(/*membermask=*/0xFFFFFFFF);
if (warp_idx_zero && lane_predicate) {
  // Do Something
}
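For reference, a helper like hopper::electSync above is typically a thin wrapper around the elect.sync PTX instruction (Hopper, sm_90+). Below is a minimal sketch in the spirit of CUTLASS's elect_one_sync; it is not necessarily NvFuser's actual runtime implementation:

__device__ inline unsigned electSync(unsigned membermask) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  unsigned pred = 0;
  unsigned laneid = 0;
  // elect.sync writes the elected lane id to %rx and sets the predicate %px
  // only in that lane; every thread in membermask must execute it.
  asm volatile(
      "{\n"
      ".reg .b32 %%rx;\n"
      ".reg .pred %%px;\n"
      "     elect.sync %%rx|%%px, %2;\n"
      "@%%px mov.s32 %1, 1;\n"
      "     mov.s32 %0, %%rx;\n"
      "}\n"
      : "+r"(laneid), "+r"(pred)
      : "r"(membermask));
  return pred;  // 1 only in the single elected lane
#else
  return (threadIdx.x % 32) == 0;  // fallback: pick lane 0 on older archs
#endif
}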

Note: It could be more efficient to use __shfl_sync to compute warp_idx, as CUTLASS does.

How CUTLASS selects a leader thread to issue TMA instructions

int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_predicate = hopper::electSync(/*membermask=*/0xFFFFFFFF);

if ((warp_idx == 0) && lane_predicate) {
  // Issue Tma Descriptor Prefetch from a single thread
}

@rdspring1 marked this pull request as ready for review on September 11, 2024 16:55
@naoyam (Collaborator) commented Sep 11, 2024

Why not exactly follow what CUTLASS does?

@naoyam (Collaborator) commented Sep 11, 2024

I'd drop Sync from the name. It implies warp-synchronous instruction execution. Here, it's just electing one thread from a block. Maybe ElectThread?

I don't think it's necessary at this moment, but to make it more future-proof, we may also want to consider how this predicate could be extended to support electing one execution entity within each parallel type. Again, just a thought, not important right now.

Resolved review thread on csrc/device_lower/utils.h (outdated)
@rdspring1 (Collaborator, Author) commented:

> Why not exactly follow what CUTLASS does?

CUTLASS already knows which dimensions of the block are zero, so it can skip the extra predicates.

> consider how this predicate could be extended to support electing one execution entity within each parallel type.

I agree. CUTLASS does select warps based on threadIdx.x and threadIdx.y.

> Change name to ElectThread?

I don't have a strong opinion about changing the name to ElectThread. However, all threads within the member-mask are synchronized, so there is some synchronization with this PTX instruction.

> The mandatory .sync qualifier indicates that elect causes the executing thread to wait until all threads in the member-mask execute the elect instruction before resuming execution.

Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync

@zasdfgbnm (Collaborator) commented Sep 11, 2024

I think my biggest concern is that I don't think ElectSync is a separate predicate type. I would rather consider it an optimization on existing predicate types. For example, if BDIMx is 32 and there is an IterDomain I0{5} parallelized on TIDx, then we will generate the predicate threadIdx.x < 5, which cannot be optimized as elect sync. Similarly, if I0's extent is 1 instead of 5, then we will have threadIdx.x < 1, which can be optimized as elect sync.
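(For illustration, a hedged sketch of the two generated predicates described above; the bodies are placeholders, not actual lowered NvFuser output:)

// Extent 5 on TIDx: five threads each do part of the work, so this
// predicate cannot be rewritten as an elect.sync predicate.
if (threadIdx.x < 5) {
  // work for threads 0..4
}

// Extent 1 on TIDx: exactly one thread does the work, so the predicate
// is equivalent to electing a single thread and could be optimized.
if (threadIdx.x < 1) {
  // work for thread 0 only
}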

The reason I dislike having ElectSync as a separate predicate type is that it effectively skips all our existing analyses on predication and just hard-codes "select one thread", which would generate silently wrong results if the schedule does not imply "only one thread does the work". In the above example, if the extent of I0 is 5, then using the elect sync predicate type is just silently wrong.

I believe a better way to do this is: run through all the existing predicate analyses we already have in our system. If these analyses indicate that only one thread is selected, then optimize it as elect sync.

What do you think? @naoyam

@naoyam (Collaborator) commented Sep 11, 2024

I generally agree with @zasdfgbnm. Ideally, we should automatically use elect_sync whenever possible. We have "thread predicates", which indicate which threads should be allowed to execute. Maybe we could use that to pipe it through.

That said, I think at this point we should prioritize getting things done as quickly as possible, so I'd prefer to have something working first and then reconsider the design.

@naoyam (Collaborator) commented Sep 11, 2024

> I don't have a strong opinion about changing the name to ElectThread. However, all threads within the member-mask are synchronized, so there is some synchronization with this PTX instruction.

That's true. Actually, I don't think this is worth worrying about too much. It's just a Kernel IR type, not exposed to the user, so never mind.

@zasdfgbnm (Collaborator) commented Sep 11, 2024

> That said, I think at this point we should prioritize getting things done as quickly as possible, so I'd prefer to have something working first and then reconsider the design.

I agree with this, but I am still uncomfortable with silently wrong results.

Today, TMA, just like any other expression in nvFuser, fully supports being launched by multiple threads. It will generate predicates like if (threadIdx.x < 5) if the TIDx-parallelized ID has an extent of 5. This behavior is well documented in tma.md and test_tutorial.cpp.

But suddenly we introduced this arbitrary limitation that does not match what has been documented and could produce silently wrong results. In the long term, I believe we should have as few special cases as possible, which means TMA should continue to support being launched by multiple threads, and an arbitrary threadIdx.x < 1 should be capable of being optimized as elect sync. I think achieving this does not need extra coding work, so the final PR should be as short as it is today, but it does require extra work studying how our existing predicate generation works.

In the short term, I am not suggesting that this is the most important thing, but I still believe that at least what is documented should match what is implemented, which should also match what is validated. That means I don't think we can just add a new predicate type and use it. We should at least add some validation that the schedule is compatible with this predicate type, and update the documentation about the limitation.

@rdspring1 (Collaborator, Author) commented:

> We should at least add some validation that the schedule is compatible with this predicate type, and update the documentation about the limitation.

We can add a check so that we're not launching multiple threads in any block-parallel dimension.

@rdspring1 (Collaborator, Author) commented Sep 11, 2024

Since TMA is an async operation, I do wonder if

if (threadIdx.x < 5) {
  // launch TMA simultaneously
}

is better than

for (size_t i : irange(5)) {
  if (warp_idx == 0 && electSync()) {
    // launch tma individually
  }
}

Or vice versa? Probably not much difference.

This ^^^ isn't related to the PR. Just general curiosity.

@naoyam (Collaborator) commented Sep 11, 2024

I just realized this was extracted from #2833, which I haven't yet looked at. This PR itself doesn't have any use of the new predicate type, so it's hard to discuss how safe it would be. Let me review #2833 first.

@zasdfgbnm (Collaborator) commented:

> Since TMA is an async operation, I do wonder if launching TMA simultaneously from multiple threads (if (threadIdx.x < 5) { ... }) is better than launching it individually from one elected thread inside a loop (if (warp_idx == 0 && electSync()) { ... } within for (size_t i : irange(5))). Or vice versa? Probably not much difference.
>
> This ^^^ isn't related to the PR. Just general curiosity.

Honestly, I don't know. I don't think there will be any first-order difference, but I won't be surprised if there are some second-order effects. For example, one variant might use more registers than the other, and the extra registers could bring the occupancy down from 2 to 1. I just want the flexibility to easily experiment with different things.

@zasdfgbnm (Collaborator) commented:

> We can add a check so that we're not launching multiple threads in any block-parallel dimension.

Yeah, adding a check is sufficient for now. Please also add a warning message to the doc mentioning that if you do circular buffering, you must launch TMA using one thread.

@rdspring1 (Collaborator, Author) commented:

!build

@rdspring1 (Collaborator, Author) commented Sep 16, 2024

I added the compatibility check for the ElectSync predicate.

Summary: In the predicate lowering pass, for any If-Then-Else expression with an ElectSync predicate, check that none of its output TensorViews has any block parallelization.

There isn't an expression that uses the ElectSync predicate in this PR, so I added the test to #2833. See TEST_F(NVFuserTest, ElectSyncCompatibility) in f29aa22.
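A hedged sketch of the compatibility check described above; the helper names (getOutputTvs, hasBlockParallelId) and the use of NVF_ERROR are illustrative assumptions, and the actual implementation lives in csrc/device_lower/pass/predicate.cpp:

// Illustrative only: reject an ElectSync-predicated If-Then-Else whose outputs
// are block parallelized, since electing a single thread would silently drop
// the work expected from the other threads.
void validateElectSyncCompatibility(const kir::IfThenElse* ite) {
  if (ite->predicate()->predicate_type() != PredicateType::ElectSync) {
    return;
  }
  for (const Expr* expr : ite->thenBody().exprs()) {
    for (const TensorView* tv : getOutputTvs(expr)) {  // assumed helper
      NVF_ERROR(
          !hasBlockParallelId(tv),                     // assumed helper
          "ElectSync predicate requires outputs without block (TIDx/y/z) parallelization: ",
          tv->toString());
    }
  }
}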

@@ -349,6 +349,11 @@ the TMA domain can be completely inferred from the schedule.
> We do not have validation on shared memory schedule yet.
> If you scheduled something invalid, likely you will see misaligned address error or silent wrong result.

> [!WARNING]
@rdspring1 (Collaborator, Author) commented on the diff above:

Perhaps there should be a circular buffering section, but I added the warning here for now.

Resolved review threads on csrc/device_lower/pass/predicate.cpp and csrc/fusion_executor/executor.cpp (outdated)
@rdspring1 (Collaborator, Author) commented:

!build

@rdspring1 merged commit 21ab9ab into main on Sep 18, 2024
34 of 36 checks passed
@rdspring1 deleted the elect_sync branch on September 18, 2024 16:13