Generate predicates for cp.async.bulk normally
#1903
Merged
On our current main branch, all predicates for `cp.async.bulk` are skipped. They are skipped not because that is the right behavior, but as a quick hack that let us build out TMA incrementally. Currently, TMA can only be used in a `<<<1, 1>>>` kernel, and only to copy an entire tensor rather than part of one. Under that limitation, skipping the predicates makes sense.

However, skipping predicate generation for TMA no longer makes sense now that we are adding support for non-trivial cases. For example, in #1484, an `if (threadIdx.x == 0 && threadIdx.x == 0 && threadIdx.x == 0)` is manually created in the double-buffering pass as a temporary solution. Also, I just started working on allowing TMA to be used in a non-`<<<1, 1>>>` kernel, where a thread predicate is clearly needed.

In this PR, I am re-enabling predicate generation for TMA. For all code already on the main branch, this PR should be a no-op: I do not expect any change in the generated code for any TMA test. However, #1484 will be affected: the `if (threadIdx.x == 0 && threadIdx.x == 0 && threadIdx.x == 0)` should no longer be created manually in the double-buffering pass. Instead, the double-buffering pass should leave the TMA op as-is and let the predicate generation pass handle it.
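
To illustrate why a thread predicate matters once the kernel has more than one thread, here is a minimal host-side C++ sketch. It is not the code this PR generates. `Dim3`, `is_elected_thread`, `bulk_copy`, and `count_bulk_copies` are hypothetical names, and the sketch assumes the predicate elects the first thread along every block dimension. It runs the threads of one block sequentially and counts how many times a stand-in for the bulk copy is issued.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CUDA's dim3 / threadIdx; not an nvFuser type.
struct Dim3 {
  unsigned x, y, z;
};

// Assumed form of the thread predicate: elect the first thread of the block.
bool is_elected_thread(const Dim3& tid) {
  return tid.x == 0 && tid.y == 0 && tid.z == 0;
}

// Stand-in for one bulk copy: a single call moves the whole buffer.
void bulk_copy(const std::vector<float>& src, std::vector<float>& dst,
               std::size_t& issued) {
  dst = src;
  ++issued;
}

// Runs every thread of one block sequentially and returns how many
// bulk copies were issued in total.
std::size_t count_bulk_copies(const Dim3& block_dim, bool use_predicate) {
  assert(block_dim.x > 0 && block_dim.y > 0 && block_dim.z > 0);
  const std::vector<float> src{1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> dst;
  std::size_t issued = 0;
  for (unsigned z = 0; z < block_dim.z; ++z) {
    for (unsigned y = 0; y < block_dim.y; ++y) {
      for (unsigned x = 0; x < block_dim.x; ++x) {
        const Dim3 tid{x, y, z};
        // Without a predicate, every thread issues the same copy.
        if (!use_predicate || is_elected_thread(tid)) {
          bulk_copy(src, dst, issued);
        }
      }
    }
  }
  return issued;
}
```

In this model, `count_bulk_copies(Dim3{4, 2, 1}, true)` returns 1, while `count_bulk_copies(Dim3{4, 2, 1}, false)` returns 8, one redundant copy per thread. With a `Dim3{1, 1, 1}` block the two cases coincide, which matches the observation that skipping predicates was harmless in a `<<<1, 1>>>` kernel.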