Multi-matmul scheduler: add test and schedule smem operand store #2913

Merged
merged 23 commits into from
Sep 23, 2024

Conversation

jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Sep 5, 2024

This PR follows up on #2719, working toward the eventual goal of merging #2458. Here I introduce a parametrized test suite that I will use to check that the generated code matches the current scheduler across a variety of scenarios. I am also scheduling the first consumer tensors in the fusion: the smem stores of the operands. This shows how we can selectively check that the schedule is correct before everything is implemented, since the test checks only these tensors and nothing else in the fusion.

In the next PRs, I will schedule the rest of the prologue, as well as the mma result and the epilogue tensors. In each case, I will change the test to check more and more of the fusion for correctness.

This PR introduces AbstractMatmulTensor, an AbstractTensor in which each dimension can be tagged with a role. This lets us track the role of each dimension during abstract tensor scheduling. It is currently used only in blockTileTensors and mma_utils::makeTile, but in the future we can imagine using it more.

//! domain must be set as loop domain. For the case of new swizzle, this domain
//! must be set as allocation domain.
template <bool legacy = true>
AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) {
Collaborator Author

Copied straight from matmul.cpp

@jacobhinkle
Collaborator Author

!build

Comment on lines +367 to +368
// Take the consumer of each input, which is the smem store
tvs.push_back(v->uses().at(0)->output(0)->as<TensorView>());
Collaborator Author

Currently this checks only the smem store. Because compareTvs is recursive, we can extend this to add the smem loads next, then the MmaOp outputs, and eventually the fusion outputs, in order to check every tensor in the fusion.

@jacobhinkle
Collaborator Author

!build

@jacobhinkle jacobhinkle marked this pull request as ready for review September 6, 2024 14:00
ASSERT_FALSE(testing::Test::HasFailure()) << suffix;
}

void compareSchedules() {
Collaborator

What is the future plan for this check? I don't think we want to keep the legacy scheduler in the long term just for testing purposes. Will we completely remove those tests?

Collaborator Author

Yes. Once we remove the old scheduler we will remove this file.

jacobhinkle added a commit that referenced this pull request Sep 11, 2024
This generalizes `AbstractTensor` to the templated struct `AbstractTensorWithInfo<Info>`.

Related to #2913.
jacobhinkle added a commit that referenced this pull request Sep 18, 2024
This generalizes `AbstractTensor` to the templated struct
`AbstractTensorWithInfo<Info>` and introduces a special case subclass
`TaggedAbstractTensor<Tag>`. This can be used by passing an enum class
for `Tag`, and holds an `unordered_set<Tag>` for each dimension. Merging
and swizzling unions these sets, and split duplicates the set.

Note that a lot of code had to be moved out of the cpp into the header
because of templatization. However, there are no changes to the
`Dispatch*` classes. The `AbstractTensorWithInfo` methods like split,
merge, swizzle, etc. are just changed to add calls to `Info::merge`.

Related to #2913, which specializes this as `using AbstractMatmulTensor
= TaggedAbstractTensor<MatmulDimRole>`.
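
The tag propagation rule described in this commit message (merge unions the tag sets, split duplicates the set) can be sketched with a toy model. This is only an illustration under stated assumptions: `ToyTaggedTensor`, `DimRole`, and its enumerators are hypothetical names invented here, the struct holds no IterDomains, and it is not the nvFuser `TaggedAbstractTensor` implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for MatmulDimRole; the enumerators are assumed.
enum class DimRole { M, N, K, Batch };

// Toy model: one tag set per dimension and nothing else.
struct ToyTaggedTensor {
  std::vector<std::unordered_set<DimRole>> tags;

  // split: the new inner dimension is placed right after `pos` and
  // receives a copy of the tags of the dimension being split.
  void split(std::size_t pos) {
    std::unordered_set<DimRole> copy = tags.at(pos);
    tags.insert(tags.begin() + pos + 1, std::move(copy));
  }

  // merge: dimension `b` is folded into dimension `a` (a != b, a < b
  // assumed); the merged dimension carries the union of both tag sets.
  void merge(std::size_t a, std::size_t b) {
    tags.at(a).insert(tags.at(b).begin(), tags.at(b).end());
    tags.erase(tags.begin() + b);
  }
};
```

Starting from two dimensions tagged `{M}` and `{K}`, calling `split(0)` and then `merge(1, 2)` leaves a dimension tagged `{M}` followed by one tagged `{M, K}`.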
Comment on lines +77 to +78
checkConcreteStaticDim(swizzle_domain[-2]);
checkConcreteStaticDim(swizzle_domain[-1]);
Collaborator Author

Implemented checkConcreteStaticDim so we don't need to convert to IterDomain* for this call now.

@jacobhinkle
Collaborator Author

!build

I will investigate switching to BROADCAST separately, but for now that is
breaking mma_utils::canonicalDimOrdering, so I'm disabling it.
@jacobhinkle
Collaborator Author

!build

@@ -1128,6 +1210,28 @@ std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
return roles;
}

void mergeConsecutiveAxesWithSameRole(
Collaborator

Will we also be interested in merging non-consecutive axes for more flexibility?

Collaborator Author

This will pick up that case too. The dims only need to be consecutive after a reordering, so the original order can be anything. After we merge the refactor, I want to add some tests with exotic combinations of multiple dims to start exercising this code more.
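
A minimal sketch of the idea in this reply: reorder so that axes with the same role become adjacent, then merge neighbors. Everything here is assumed for illustration. `ToyAxis`, `DimRole`, and `reorderAndMergeByRole` are hypothetical names, and the role order used for sorting is arbitrary rather than the one produced by mma_utils::canonicalDimOrdering.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical roles; the declaration order doubles as the sort order here.
enum class DimRole { Batch, M, N, K };

struct ToyAxis {
  DimRole role;
  int64_t extent;
};

// Reorder axes so equal roles are adjacent, then merge each run into one axis.
std::vector<ToyAxis> reorderAndMergeByRole(std::vector<ToyAxis> axes) {
  // A stable sort keeps the relative order of axes that share a role.
  std::stable_sort(
      axes.begin(), axes.end(), [](const ToyAxis& a, const ToyAxis& b) {
        return a.role < b.role;
      });
  std::vector<ToyAxis> merged;
  for (const ToyAxis& ax : axes) {
    if (!merged.empty() && merged.back().role == ax.role) {
      // Same role as the previous axis: fold it in by multiplying extents.
      merged.back().extent *= ax.extent;
    } else {
      merged.push_back(ax);
    }
  }
  return merged;
}
```

With input roles M, K, M, N, K in that order, the two M axes and the two K axes are not adjacent at first, but each pair ends up merged into a single axis after the reordering.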

csrc/scheduler/multi_matmul.cpp
@@ -98,25 +534,22 @@ class MultipleMatmulScheduler {
// and dimension roles for all tensors in the fusion
findPatterns();
translatePatterns();
// translatePatterns changes the TensorView graph, so we build the IdModel
// afterward
buildIdModel();
findRoles();
Collaborator

Does findRoles need IdModel, or is it already updated somewhere before this?

Collaborator Author

findRoles() does use it, but it is now updated at the end of translatePatterns(). I moved it there to make the logic in this scope a bit clearer.

}

// Recursively compare scalar values
void compareScalars(Val* v_orig, Val* v_new) {
Collaborator

Can this be simplified as:

return cloner_->clone(v_orig)->sameAs(v_new);

If the above doesn't just work, then don't bother investing more time to make it work. This is a test that we will throw out in the future, so I don't really care about having a clean and elegant implementation.

Collaborator Author

This would work if we had cloned the scheduled fusion, but here I have cloned the unscheduled fusion then scheduled them both. That means that some of these scalars are new in both the original and the clone. When I clone those over using cloner_->clone(v_orig) we wind up with a new undefined scalar, which fails the sameAs check if the original had a definition.
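
To illustrate why a recursive comparison is needed here, the following toy sketch compares two scalar expressions by walking their definitions instead of relying on a clone mapping. `ToyVal`, `leaf`, `node`, and `sameStructure` are hypothetical names made up for this example. They do not mirror the nvFuser `Val` or `IrCloner` APIs, and matching leaves by name is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy scalar: a named leaf, or an op applied to input scalars.
struct ToyVal {
  std::string name;  // meaningful for leaves only
  std::string op;    // empty means the value has no definition
  std::vector<std::shared_ptr<ToyVal>> inputs;
};

std::shared_ptr<ToyVal> leaf(std::string name) {
  return std::make_shared<ToyVal>(ToyVal{std::move(name), "", {}});
}

std::shared_ptr<ToyVal> node(
    std::string op,
    std::vector<std::shared_ptr<ToyVal>> inputs) {
  return std::make_shared<ToyVal>(
      ToyVal{"", std::move(op), std::move(inputs)});
}

// Two values match if both are the same leaf, or if they have the same op
// and all corresponding inputs match recursively.
bool sameStructure(const ToyVal& a, const ToyVal& b) {
  if (a.op != b.op || a.inputs.size() != b.inputs.size()) {
    return false;
  }
  if (a.op.empty()) {
    return a.name == b.name;
  }
  for (std::size_t i = 0; i < a.inputs.size(); ++i) {
    if (!sameStructure(*a.inputs[i], *b.inputs[i])) {
      return false;
    }
  }
  return true;
}
```

In this model, two independently built `mul(x, y)` expressions compare equal, while a fresh scalar with no definition does not match a defined one, which is the failure mode described above.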

@jacobhinkle jacobhinkle merged commit bc3ddae into main Sep 23, 2024
5 checks passed
@jacobhinkle jacobhinkle deleted the multi_matmul_schedule_smem_store branch September 23, 2024 11:49
jacobhinkle added a commit that referenced this pull request Sep 23, 2024
`main` is currently not compiling due to the changes introduced in
#2913. There was a change on main that I didn't see between when CI
passed and when I merged. This PR fixes that.