Multi-matmul scheduler: add test and schedule smem operand store #2913
Conversation
//! domain must be set as loop domain. For the case of new swizzle, this domain
//! must be set as allocation domain.
template <bool legacy = true>
AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) {
Copied straight from matmul.cpp
!build
// Take the consumer of each input, which is the smem store
tvs.push_back(v->uses().at(0)->output(0)->as<TensorView>());
Currently this just checks the smem store. But `compareTvs` is recursive, so we can change this to add the smem loads next, then the MmaOp outputs, and eventually the fusion outputs, in order to check every tensor in the fusion.
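For illustration, a rough sketch of how that extension might look, building on the snippet above. The extra consumer hop to the smem load is an assumption for illustration, not code from this PR:

```cpp
// Hypothetical extension (not in this PR): seed compareTvs with deeper
// tensors by walking one more consumer hop past each smem store.
for (Val* v : fusion->inputs()) {
  // Consumer of the fusion input: the smem store checked today
  auto* smem_store = v->uses().at(0)->output(0)->as<TensorView>();
  tvs.push_back(smem_store);
  // One hop further: the smem load. Since compareTvs recurses, adding
  // deeper seeds (MmaOp outputs, fusion outputs) widens the check.
  auto* smem_load = smem_store->uses().at(0)->output(0)->as<TensorView>();
  tvs.push_back(smem_load);
}
```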
!build
ASSERT_FALSE(testing::Test::HasFailure()) << suffix;
}

void compareSchedules() {
What is the future plan for this check? I don't think in the long term we want to keep the legacy scheduler just for testing purposes. Will we completely remove those tests?
Yes. Once we remove the old scheduler we will remove this file.
This generalizes `AbstractTensor` to the templated struct `AbstractTensorWithInfo<Info>` and introduces a special-case subclass `TaggedAbstractTensor<Tag>`. This can be used by passing an enum class for `Tag`, and it holds an `unordered_set<Tag>` for each dimension. Merging and swizzling union these sets, and splitting duplicates the set.

Note that a lot of code had to be moved out of the cpp into the header because of templatization. However, there are no changes to the `Dispatch*` classes; the `AbstractTensorWithInfo` methods like split, merge, swizzle, etc. are just changed to add calls to `Info::merge`. Related to #2913, which specializes this as `using AbstractMatmulTensor = TaggedAbstractTensor<MatmulDimRole>`.
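To make the tag semantics concrete, here is a minimal, self-contained toy sketch of the behavior described above. This is not the nvFuser API; the enum and helpers are illustrative stand-ins:

```cpp
#include <unordered_set>
#include <utility>
#include <vector>

// Illustrative stand-ins: each abstract dimension carries a set of tags.
enum class MatmulDimRole { Batch, M, N, K };
using TagSet = std::unordered_set<MatmulDimRole>;

// Merging axes i and i+1 unions their tag sets (models Info::merge).
void mergeTags(std::vector<TagSet>& dims, size_t i) {
  dims[i].insert(dims[i + 1].begin(), dims[i + 1].end());
  dims.erase(dims.begin() + i + 1);
}

// Splitting axis i duplicates its tag set onto both outputs.
void splitTags(std::vector<TagSet>& dims, size_t i) {
  TagSet copy = dims[i];
  dims.insert(dims.begin() + i, std::move(copy));
}

int main() {
  // An [M, K] tensor: merging yields one axis tagged {M, K}; splitting
  // that axis yields two axes, each tagged {M, K}.
  std::vector<TagSet> dims = {{MatmulDimRole::M}, {MatmulDimRole::K}};
  mergeTags(dims, 0); // dims == [{M, K}]
  splitTags(dims, 0); // dims == [{M, K}, {M, K}]
  return 0;
}
```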
checkConcreteStaticDim(swizzle_domain[-2]);
checkConcreteStaticDim(swizzle_domain[-1]);
Implemented `checkConcreteStaticDim` so we don't need to convert to `IterDomain*` for this call now.
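One plausible shape for such a helper, sketched here as an assumption rather than the code from this PR. In particular, `representativeId` is a hypothetical accessor that extracts an `IterDomain*` from an `AbstractId`:

```cpp
// Sketch: validate that a swizzled dimension is concrete with a static
// extent, taking the AbstractId directly instead of an IterDomain*.
void checkConcreteStaticDim(const AbstractId& abs_id) {
  IterDomain* id = representativeId(abs_id); // hypothetical accessor
  NVF_ERROR(
      !id->isBroadcast() && !id->isReduction(),
      "Swizzled dimensions must be concrete: ",
      id->toString());
  NVF_ERROR(
      id->extent()->isConstInt(),
      "Swizzled dimensions must have a static extent: ",
      id->toString());
}
```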
!build
I will investigate switching to BROADCAST separately, but for now that is breaking `mma_utils::canonicalDimOrdering`, so I'm disabling it.
!build
@@ -1128,6 +1210,28 @@ std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
  return roles;
}

void mergeConsecutiveAxesWithSameRole( |
Will we also be interested in merging non-consecutive axes for more flexibility?
This will pick up that case too: the dims are consecutive after a reordering, so the original order can be anything. After we merge the refactor, I want to add some tests with exotic multi-dim combinations to start exercising this code more.
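A toy sketch of the reorder-then-merge idea, illustrative only and not the `mma_utils` implementation:

```cpp
#include <vector>

enum class Role { Batch, M, N, K };

// After a prior reordering groups same-role axes together, one
// left-to-right pass merges each run of consecutive equal roles;
// collapsing a run models repeated tv->merge(i) calls.
std::vector<Role> mergeConsecutiveSameRole(const std::vector<Role>& roles) {
  std::vector<Role> out;
  for (Role r : roles) {
    if (out.empty() || out.back() != r) {
      out.push_back(r); // a new run starts
    } // else: same role as the previous axis, so it merges away
  }
  return out;
}

// Even if the original order was [M, K, M, K], reordering it to
// [M, M, K, K] first makes this pass reduce it to [M, K], so
// non-consecutive same-role axes are picked up too.
```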
@@ -98,25 +534,22 @@ class MultipleMatmulScheduler {
  // and dimension roles for all tensors in the fusion
  findPatterns();
  translatePatterns();
  // translatePatterns changes the TensorView graph, so we build the IdModel
  // afterward
  buildIdModel();
  findRoles();
Does `findRoles` need IdModel, or is it already updated somewhere before?
It is used by `findRoles()`, but it is updated at the end of `translatePatterns()` now. I moved it there to make the logic in this scope a bit clearer.
}

// Recursively compare scalar values
void compareScalars(Val* v_orig, Val* v_new) {
Can this be simplified to:
`return cloner_->clone(v_orig)->sameAs(v_new);`
If the above doesn't just work, then don't bother investing more time to make it work. This is a test that we will throw out in the future, so I don't really care about having a clean and elegant implementation.
This would work if we had cloned the scheduled fusion, but here I have cloned the unscheduled fusion then scheduled them both. That means that some of these scalars are new in both the original and the clone. When I clone those over using `cloner_->clone(v_orig)`, we wind up with a new undefined scalar, which fails the `sameAs` check if the original had a definition.
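A rough sketch of the recursive structural comparison this implies; the actual `compareScalars` body in the diff may differ, and the fallback logic below is an assumption:

```cpp
void compareScalars(Val* v_orig, Val* v_new) {
  // Fast path: the scalar predates scheduling and maps across the clone
  if (cloner_->clone(v_orig)->sameAs(v_new)) {
    return;
  }
  // Otherwise both scalars were created during scheduling, so compare
  // their definitions structurally and recurse on corresponding inputs
  Expr* def_orig = v_orig->definition();
  Expr* def_new = v_new->definition();
  ASSERT_EQ(def_orig == nullptr, def_new == nullptr);
  if (def_orig == nullptr) {
    EXPECT_TRUE(v_orig->sameAs(v_new)); // leaves, e.g. constants
    return;
  }
  ASSERT_TRUE(typeid(*def_orig) == typeid(*def_new)); // same op type
  ASSERT_EQ(def_orig->inputs().size(), def_new->inputs().size());
  for (size_t i = 0; i < def_orig->inputs().size(); ++i) {
    compareScalars(def_orig->input(i), def_new->input(i));
  }
}
```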
`main` is currently not compiling due to the changes introduced in #2913. There was a change on main that I didn't see between when CI passed and when I merged. This PR fixes that.
This PR follows up on #2719 toward the eventual goal of merging #2458. Here I am introducing a parametrized test suite that I will use to check that the generated code matches the current scheduler for a variety of scenarios. I am also scheduling the first consumer tensor in the fusion: the smem store of operands. This shows how we can selectively check that the schedule is correct before we have everything implemented, since the test only checks these tensors and nothing else in the fusion.
In the next PRs, I will schedule the rest of the prologue, as well as the mma result and the epilogue tensors. In each case, I will change the test to check more and more of the fusion for correctness.
This PR introduces `AbstractMatmulTensor`, an `AbstractTensor` for which each dimension can be tagged with a role. This lets us track the roles of each dimension during abstract tensor scheduling. It is currently only used in `blockTileTensors` and `mma_utils::makeTile`, but in the future we can imagine using it more.