Fix double-scheduling of dc in HSS hopper matmuls #3590

Merged: jacobhinkle merged 1 commit into main from fix_hopper_hss_epilogue on Dec 13, 2024

jacobhinkle
Collaborator

When there is no epilogue (not even a cast), the output of the original MmaOp can itself be a Fusion output. In that case the cached output, which we often call dc, is actually an mma_result. Currently this causes us to schedule that tensor twice, once in scheduleMmaResults and again in scheduleEpilogue, which leads to an esoteric error (see included test). This PR skips scheduling those tensors directly if they are already known to be mma results.
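
A minimal sketch of the skip described above, not the actual diff. `Tensor` and `epilogueTensorsToSchedule` are hypothetical stand-ins rather than nvFuser classes or functions, and the sketch assumes the filtering happens while the epilogue tensors are being collected:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor handle; NOT the nvFuser TensorView class.
struct Tensor {
  std::string name;
};

// Hypothetical helper: returns the cached outputs that still need to be
// scheduled as epilogue tensors. Any cached output that is already an mma
// result is skipped, so it is not scheduled a second time.
std::vector<Tensor*> epilogueTensorsToSchedule(
    const std::vector<Tensor*>& cached_outputs,
    const std::vector<Tensor*>& mma_results) {
  std::vector<Tensor*> tvs_to_schedule;
  for (Tensor* dc : cached_outputs) {
    const bool is_mma_result =
        std::find(mma_results.begin(), mma_results.end(), dc) !=
        mma_results.end();
    if (is_mma_result) {
      // Assumed to have been handled already with the other mma results.
      continue;
    }
    tvs_to_schedule.push_back(dc);
  }
  return tvs_to_schedule;
}
```

With no epilogue, every cached output is also an mma result, so the returned list is empty and nothing is scheduled twice.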

@jacobhinkle
Collaborator Author

!test

@jacobhinkle
Collaborator Author

We would have hit this eventually but the HSS tests are still guarded against Hopper. I'm posting this now to unblock some internal heuristics work.

Comment on lines +2708 to +2709
// TODO: Currently we use stmatrix whenever this is true. We cannot do that
// when the dtype is not 16 bits.
Collaborator Author

@protonu we need to handle all possible dtypes in the epilogue.

// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
}

Collaborator

Are we still using stmatrix if the output is fp32 and there wasn't a cast?

Collaborator Author

Yes, if you enable use_smem_epilogue in the included test we hit an error in scheduleStMatrixForMmaOutput.

Collaborator Author

So we need an if/else here that checks the dtype of d_smem and, if it is not 16-bit, schedules with vectorized stores instead.

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
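
The if/else suggested in the thread above could look roughly like the following sketch. It is not nvFuser code: `DType`, `dtypeSizeBits`, `EpilogueStore`, and `chooseEpilogueStore` are hypothetical names introduced only for illustration, and the only rule encoded is the one stated in the discussion, namely that stmatrix is usable when the shared-memory dtype is 16 bits and vectorized stores are the fallback otherwise.

```cpp
#include <cstddef>

// Hypothetical stand-in for the scheduler's data types; NOT the nvFuser enum.
enum class DType { Half, BFloat16, Float };

// Hypothetical helper: element width in bits for the stand-in dtypes.
std::size_t dtypeSizeBits(DType dtype) {
  switch (dtype) {
    case DType::Half:
    case DType::BFloat16:
      return 16;
    case DType::Float:
      return 32;
  }
  return 0;  // unreachable for the enumerators above
}

// Hypothetical tag for how the epilogue writes to shared memory.
enum class EpilogueStore { StMatrix, Vectorized };

// Choose stmatrix only for 16-bit elements; otherwise fall back to
// vectorized stores, as proposed in the review comment.
EpilogueStore chooseEpilogueStore(DType smem_dtype) {
  if (dtypeSizeBits(smem_dtype) == 16) {
    return EpilogueStore::StMatrix;
  }
  return EpilogueStore::Vectorized;
}
```

In the real scheduler the two branches would call the corresponding scheduling routines instead of returning a tag; the tag just keeps the sketch self-contained.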

jacobhinkle merged commit cbd628f into main on Dec 13, 2024
38 of 39 checks passed
jacobhinkle deleted the fix_hopper_hss_epilogue branch on December 13, 2024 21:20

3 participants