Fix double-scheduling of dc in HSS hopper matmuls #3590
!test
We would have hit this eventually but the HSS tests are still guarded against Hopper. I'm posting this now to unblock some internal heuristics work.
// TODO: Currently we use stmatrix whenever this is true. We cannot do that
// when the dtype is not 16 bits.
@protonu we need to handle all possible dtypes in the epilogue.
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}
Are we still using `stmatrix` if the output is fp32 and there wasn't a cast?
Yes, if you enable `use_smem_epilogue` in the included test we hit an error in `scheduleStMatrixForMmaOutput`.
So we need an if/else here that checks the dtype of `d_smem` and schedules with vectorized stores instead if it is not 16-bit.
Fuser/csrc/scheduler/hopper_multi_matmul.cpp
Lines 584 to 588 in 230f633
MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
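A minimal, standalone sketch of the if/else suggested above, for illustration only. The `DType` enum, `dtypeSizeBits`, `EpilogueStore`, and `chooseEpilogueStore` are hypothetical stand-ins and not nvFuser APIs; the real change would branch around the `scheduleStMatrixForMmaOutput` call based on the dtype of `d_smem`.

```cpp
#include <cstddef>

// Hypothetical stand-in for the element type of d_smem (not nvFuser's DataType).
enum class DType { Half, BFloat16, Float };

// Size of one element in bits for the stand-in types above.
std::size_t dtypeSizeBits(DType t) {
  switch (t) {
    case DType::Half:
    case DType::BFloat16:
      return 16;
    case DType::Float:
      return 32;
  }
  return 0;  // unreachable for valid enum values
}

// Hypothetical tag naming the two epilogue store strategies discussed here.
enum class EpilogueStore { StMatrix, Vectorized };

// Use stmatrix only for 16-bit data; otherwise fall back to vectorized stores.
EpilogueStore chooseEpilogueStore(DType smem_dtype) {
  if (dtypeSizeBits(smem_dtype) == 16) {
    return EpilogueStore::StMatrix;
  }
  return EpilogueStore::Vectorized;
}
```

With this shape, an fp32 output with no cast would take the vectorized path instead of reaching the stmatrix scheduling code.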
When we do not have an epilogue (not even a cast), it might be the case that the original `MmaOp` has an output that is a Fusion output. In this case the cached output, which we often call `dc`, is actually an `mma_result`. Currently this causes us to schedule that tensor once in `scheduleMmaResults` and then again in `scheduleEpilogue`, leading to an esoteric error (see included test). This PR simply skips scheduling those tensors directly if they are already known to be mma results.
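The skip described above can be sketched in isolation as follows. This is an illustrative model, not the code in this PR: `Tensor` and `epilogueTensorsToSchedule` are hypothetical names, and the actual scheduler works on nvFuser's own tensor types.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a scheduled tensor (not nvFuser's TensorView).
struct Tensor {
  std::string name;
};

// Return the cached outputs that the epilogue still needs to schedule,
// skipping any tensor that is already known to be an mma result and was
// therefore scheduled earlier.
std::vector<const Tensor*> epilogueTensorsToSchedule(
    const std::vector<const Tensor*>& cached_outputs,
    const std::vector<const Tensor*>& mma_results) {
  std::vector<const Tensor*> tvs_to_schedule;
  for (const Tensor* dc : cached_outputs) {
    const bool already_scheduled =
        std::find(mma_results.begin(), mma_results.end(), dc) !=
        mma_results.end();
    if (!already_scheduled) {
      tvs_to_schedule.push_back(dc);
    }
  }
  return tvs_to_schedule;
}
```

Comparing by pointer identity keeps the check cheap and avoids confusing two distinct tensors that happen to share a name.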