Create ElectSync predicate type #2923
Why not follow exactly what CUTLASS does? |
I'd drop I don't think it's necessary at this moment, but to make it more future proof, we may also want to consider how this predicate could be extended to support electing one execution entity within each parallel type. Again, just a thought, not important right now. |
Cutlass already knows which dimensions of the block are zero, so they can skip the extra predicates.
I agree. Cutlass does select warps based on
I don't have a strong opinion about changing the name to
|
I think my biggest concern is, I don't think The reason I dislike having this I believe a better way to do this is: run through all the existing predicate analyses we already have in our system. If these analyses indicate that only one thread is selected, then optimize it as elect-sync. What do you think? @naoyam |
I generally agree with @zasdfgbnm. Ideally, we should automatically use That said, I think at this point we should prioritize getting things done as quickly as possible, so I'd prefer to have something working first and then reconsider the design. |
That's true. Actually, I don't think this is worth worrying about too much. It's just a Kernel IR type, so it's not exposed to the user. Never mind. |
I agree with this, but I am still uncomfortable with silent wrong results. Today, TMA, just like any other expression in nvFuser, fully supports being launched by multiple threads. It will generate predicates like

But suddenly we introduced this arbitrary limitation that does not match what has been documented and could produce silent wrong results.

In the long term, I believe we should have as few special cases as possible, which means TMA should continue to support being launched by multiple threads, and an arbitrary

In the short term, I am not suggesting that this is the most important thing, but I still believe that at least what is documented should match what is implemented, which should also match what is validated. That means I don't think we can just add a new predicate type and use it. We should at least add some validation that the schedule is compatible with this predicate type, and update the documentation about the limitation. |
We can have a check so that we're not launching multiple threads with any block parallel dimension. |
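For illustration, such a check could look roughly like the sketch below. It is a standalone approximation, not the code in this PR: `ParallelType`, `BlockDims`, `checkElectSyncCompatible`, and the error strings are hypothetical stand-ins. It only encodes the two conditions raised in this thread: `blockDim.x` covers at least one full warp (the default member mask is `0xFFFFFFFF`), and the TMA expression is not parallelized over any thread dimension.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration; not nvFuser's actual types.
enum class ParallelType { Serial, TIDx, TIDy, TIDz, BIDx, BIDy, BIDz };

struct BlockDims {
  int64_t x;
  int64_t y;
  int64_t z;
};

constexpr int64_t kWarpSize = 32;

// True for the thread (block-local) parallel dimensions.
bool isThreadDim(ParallelType pt) {
  return pt == ParallelType::TIDx || pt == ParallelType::TIDy ||
      pt == ParallelType::TIDz;
}

// Returns an empty string when the schedule is compatible with an
// ElectSync-style predicate, otherwise a human-readable reason.
std::string checkElectSyncCompatible(
    const BlockDims& bdim,
    const std::vector<ParallelType>& tma_loop_types) {
  // With member mask 0xFFFFFFFF, all 32 lanes of the warp are expected to
  // participate, so blockDim.x must cover at least one full warp.
  if (bdim.x < kWarpSize) {
    return "blockDim.x must be at least one warp (32 threads)";
  }
  // The TMA expression must not be launched by multiple threads.
  for (ParallelType pt : tma_loop_types) {
    if (isThreadDim(pt)) {
      return "TMA expression is parallelized over a thread dimension";
    }
  }
  return "";
}
```

Returning a reason string rather than a bool is only a convenience for surfacing the limitation in an error message; a real implementation would presumably use the project's own assertion mechanism.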
Since TMA is an async operation, I do wonder if

    if (threadIdx.x < 5) {
      // launch tma simultaneously
    }

is better than

    for (size_t i : irange(5)) {
      if (warp_idx == 0 && electSync()) {
        // launch tma individually
      }
    }

Or vice versa? Probably not much difference. This ^^^ isn't related to the PR. Just general curiosity. |
Honestly, I don't know. I don't think there will be any first-order difference, but I won't be surprised if there are some second-order effects. For example, one variant might use more registers than the other, and those extra registers could bring the occupancy down from 2 to 1. I just want to have the flexibility to easily experiment with different things. |
Yeah, adding a check is sufficient for now. Please also add a warning message to the doc mentioning that if you do circular buffering, you must launch TMA using one thread. |
!build |
I added the compatibility check for the Summary: In predicate lowering pass, for any If-Then-Else expression with There isn't an expression that uses the |
@@ -349,6 +349,11 @@ the TMA domain can be completely inferred from the schedule.
> We do not have validation on shared memory schedule yet.
> If you scheduled something invalid, likely you will see misaligned address error or silent wrong result.

> [!WARNING]
Perhaps there should be a circular buffering section but I added the warning here for now.
!build |
Summary
The PR creates the `ElectSync` predicate type to select a single thread in an if-then-else block.

Why?
The standard predicate `threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0` selects a specific thread from a CTA, which can create a peeling loop. The `elect.sync` PTX instruction selects an arbitrary thread from a given warp.

Lowering Details
- `ElectSync` PredicateType
- `ElectSync` as a unary op.
- `blockDim.x` has at least one warp because the default `membermask` is `0xFFFFFFFF`.

NvFuser's ElectSync Predicate
Note: It could be more efficient to use `__shfl_sync` to get `warp_idx` like CUTLASS.

How CUTLASS selects a leader thread to issue TMA instructions