Accept axis mapping when defining MmaOp (#3391)
This keeps the default interface of `fusedMultiplySum` but also adds an option to provide an `MmaOp::AxisMapping` object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.

This PR does not alter the behavior of `mma_utils::MatmulPattern::translateToMmaOp`, meaning we still have BroadcastOp in translations for Hopper matmuls, but that change should be relatively simple.

Fixes #3372

The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In followup PRs I plan to do the following:

1. Demonstrate scheduling a Hopper matmul with unbroadcasted inputs manually. This should surface any bugs in the lowering of the MmaOp instruction when broadcasts are absent.
2. Ensure that we don't depend on having broadcast dims in the Hopper matmul scheduler. For example, we will handle this case in `moveInnerBroadcastLeft`, and we may also need to adjust the swizzling of the TMA smem load TensorView. At this point our automatic scheduler will be able to schedule a manually defined `MmaOp` without broadcasted inputs.
3. Add an option `MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true)` and enable that in the Hopper matmul scheduler. At this point it will be safe for us to accept `MatmulOp` and `LinearOp` in the Hopper matmul scheduler.
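To make the idea of an axis mapping concrete, here is a minimal standalone sketch. It does not use nvFuser: `AxisMappingSketch` and `resolveOutputExtents` are hypothetical names invented for illustration, and the convention that `-1` marks a missing axis is an assumption, not necessarily what `MmaOp::AxisMapping` does. It only shows how a per-output-dimension mapping lets operands without broadcast axes describe an `[M, N, K]` output.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for an axis mapping (not the nvFuser type).
// For each output dimension, it stores the position of the matching axis
// in operand A and in operand B, or -1 if that operand has no such axis.
struct AxisMappingSketch {
  std::vector<int64_t> a_axes;
  std::vector<int64_t> b_axes;
};

// Resolve the extent of every output dimension from the operand shapes.
// Throws if a dimension maps to neither operand or if the operands disagree.
std::vector<int64_t> resolveOutputExtents(
    const AxisMappingSketch& mapping,
    const std::vector<int64_t>& a_shape,
    const std::vector<int64_t>& b_shape) {
  assert(mapping.a_axes.size() == mapping.b_axes.size());
  std::vector<int64_t> out;
  out.reserve(mapping.a_axes.size());
  for (std::size_t i = 0; i < mapping.a_axes.size(); ++i) {
    const int64_t a_pos = mapping.a_axes[i];
    const int64_t b_pos = mapping.b_axes[i];
    if (a_pos < 0 && b_pos < 0) {
      throw std::invalid_argument("output dimension maps to no operand axis");
    }
    int64_t extent = -1;
    if (a_pos >= 0) {
      extent = a_shape.at(static_cast<std::size_t>(a_pos));
    }
    if (b_pos >= 0) {
      const int64_t b_extent = b_shape.at(static_cast<std::size_t>(b_pos));
      if (extent >= 0 && extent != b_extent) {
        throw std::invalid_argument("operands disagree on a shared extent");
      }
      extent = b_extent;
    }
    out.push_back(extent);
  }
  return out;
}
```

For example, with A of shape `[M, K]` and B of shape `[K, N]`, the mapping `a_axes = {0, -1, 1}` and `b_axes = {-1, 1, 0}` describes an output ordered `[M, N, K]`: M comes only from A, N only from B, and K is shared by both, so neither operand needs a broadcast axis.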
1 parent a5022da, commit 030c2ba
Showing 12 changed files with 269 additions and 320 deletions.