TMA: Don't merge dims when proven wrong (#3472)
In our TMA support, there is a mechanism, "Define box by compositing", which is used to reduce the dimensionality of TMA when possible. The reason is that our hardware has a limitation: the tensor cannot have more than 5 dimensions. By collapsing some dimensions, we can support tensors higher than 5D by viewing them as tensors of no more than 5D.

While I do believe this mechanism makes great sense in general, mindlessly collapsing as much as possible can cause trouble. For example, if the box size is `(2, 256)`, collapsing it into `(512,)` is a mistake, because our hardware has another limitation: the box size cannot be larger than `256` in any dimension.

This PR makes the "Define box by compositing" mechanism smarter: instead of always collapsing, we check whether collapsing can cause trouble. If so, we do not collapse.

Note that this check cannot guarantee 100% safety. For example, if the box size is a symbolic value `v1`, there is no way to get its value during lowering, so it is not possible to know whether merging it with another dim would make the box size greater than 256. But I believe this is fine in practice because, to my knowledge, we are not interested in symbolic box sizes right now, so the box size is almost always constant.

In the future, assume that we move towards analytical heuristics, that is, the box size is a `Val` that is a function of our input shapes and is represented by our IR nodes. In this world, I think dynamic concretization is necessary to make "Define box by compositing" bug-free.
For example, let's say our box size is `(v1, v2, v3, v4, v5, v6)`. Depending on the values of these variables, we might want to do different things:

- `(v1, v2, v3, v4, v5, v6) = (256, 2, 128, 256, 256, 256)` --> merge `v2` with `v3` and use 5D TMA
- `(v1, v2, v3, v4, v5, v6) = (256, 256, 2, 128, 256, 256)` --> merge `v3` with `v4` and use 5D TMA
- `(v1, v2, v3, v4, v5, v6) = (256, 256, 256, 128, 256, 256)` --> raise an error, not supported
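
The three cases above can be reproduced with a small standalone model. This is only a sketch under stated assumptions, not the code in this PR: the names `merge_proven_wrong` and `plan_box` are hypothetical, box sizes are plain Python integers (or strings standing in for symbolic values), only adjacent dims are merged, and the greedy left-to-right search is an illustrative choice. It also assumes that a merge involving a symbolic size is simply "not proven wrong", which is my reading of the note on symbolic box sizes.

```python
from typing import Sequence, Tuple, Union

MAX_BOX_SIZE = 256  # hardware limit on the box size, as stated in the message
MAX_TMA_DIMS = 5    # hardware limit on the tensor dimensionality

# Hypothetical representation: int = constant size, str = symbolic size ("v1").
Extent = Union[int, str]


def merge_proven_wrong(a: Extent, b: Extent) -> bool:
    """Return True only when merging a and b provably exceeds the box limit."""
    if isinstance(a, int) and isinstance(b, int):
        return a * b > MAX_BOX_SIZE
    # A symbolic size has no known value, so nothing can be proven here.
    return False


def plan_box(box: Sequence[int]) -> Tuple[int, ...]:
    """Greedily merge adjacent dims of a concrete box until it is at most 5D."""
    dims = list(box)
    while len(dims) > MAX_TMA_DIMS:
        for i in range(len(dims) - 1):
            if not merge_proven_wrong(dims[i], dims[i + 1]):
                # Replace the pair with a single merged dim.
                dims[i:i + 2] = [dims[i] * dims[i + 1]]
                break
        else:
            # No adjacent pair can be merged without exceeding the limit.
            raise ValueError(
                f"cannot view box {tuple(box)} as at most {MAX_TMA_DIMS}D"
            )
    return tuple(dims)


print(merge_proven_wrong(2, 256))               # True: 512 > 256
print(plan_box((256, 2, 128, 256, 256, 256)))   # merges v2 with v3
print(plan_box((256, 256, 2, 128, 256, 256)))   # merges v3 with v4
try:
    plan_box((256, 256, 256, 128, 256, 256))
except ValueError as err:
    print("error:", err)                        # not supported
```

Because `plan_box` needs concrete integers to decide which pair to merge, the sketch also shows why symbolic box sizes would require concretization before this decision can be made reliably.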
Showing 2 changed files with 121 additions and 15 deletions.