Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-matmul scheduler: add test and schedule smem operand store #2913

Merged
merged 23 commits into main from multi_matmul_schedule_smem_store
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5995c82
Multi-matmul scheduler: add test and schedule smem operand store
jacobhinkle Sep 5, 2024
ecefe56
Use updated swizzleSharedMemory
jacobhinkle Sep 5, 2024
f0f6702
Update comment for swizzleSharedMemory
jacobhinkle Sep 5, 2024
90b95a7
Avoid comparing bias tensors for now
jacobhinkle Sep 6, 2024
64f56ff
clang-tidy mma_utils.cpp
jacobhinkle Sep 6, 2024
f073c79
Merge remote-tracking branch 'origin/main' into multi_matmul_schedule…
jacobhinkle Sep 11, 2024
0813c3c
Revert changes to abstract_tensor.h
jacobhinkle Sep 18, 2024
a1c08ad
Merge remote-tracking branch 'origin/main' into multi_matmul_schedule…
jacobhinkle Sep 18, 2024
3c3aeec
Undo botched merge
jacobhinkle Sep 18, 2024
eb24916
Improve code reuse in rewritten makeTile
jacobhinkle Sep 18, 2024
9e57930
Rename mergeAxesWithSameRole and add comments
jacobhinkle Sep 18, 2024
3bbe314
Move compareSchedules to TearDown()
jacobhinkle Sep 18, 2024
e9e445d
Use std::tie to unpack params
jacobhinkle Sep 18, 2024
e64b63a
Remove unused DisableOptionsGuard
jacobhinkle Sep 18, 2024
ea33ecf
Skip tests with message that this is temporary
jacobhinkle Sep 18, 2024
ae89a97
Remove unused include
jacobhinkle Sep 18, 2024
9bd34fa
Remove .as<IterDomain*>() in call to checkConcreteStaticDim
jacobhinkle Sep 18, 2024
dd6b038
Use broadcast graph. Improve comments
jacobhinkle Sep 18, 2024
b43651a
Improve docstrings
jacobhinkle Sep 18, 2024
48a9079
Switch back to using PERMISSIVE graph
jacobhinkle Sep 19, 2024
7fa7447
Merge remote-tracking branch 'origin/main' into multi_matmul_schedule…
jacobhinkle Sep 19, 2024
dcef0ad
Fix failure on Hopper
jacobhinkle Sep 19, 2024
264c0aa
Merge branch 'main' into multi_matmul_schedule_smem_store
jacobhinkle Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/test_matmul_aten_evaluation.cpp
${NVFUSER_ROOT}/tests/cpp/test_matmul_sass.cpp
${NVFUSER_ROOT}/tests/cpp/test_matmul_scheduler.cpp
${NVFUSER_ROOT}/tests/cpp/test_multi_matmul_scheduler.cpp
${NVFUSER_ROOT}/tests/cpp/test_mma.cpp
)
add_test(test_matmul "${MATMUL_TEST_SRCS}" "")
Expand Down
154 changes: 129 additions & 25 deletions csrc/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,41 +362,55 @@ void scheduleContiguousVectorLoad(
tv->axis(-4)->parallelize(ParallelType::TIDz);
}

void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes) {
void makeTile(
AbstractMatmulTensor& abten,
const std::vector<int64_t>& tile_sizes) {
NVF_CHECK(
tv->getLoopDomain().size() >= tile_sizes.size(),
abten.size() >= tile_sizes.size(),
"Tensor dimension less than tile dimension!");

// Number of inner dimensions we are tiling.
const int64_t tile_dimension_size = (int64_t)tile_sizes.size();

// Split the inner dimensions:
for (int64_t idx : c10::irange(tile_dimension_size)) {
// Using negative indexing to accomodate potential batching
// dimensions on the further left. Eg.:
// 0, 1, 2 -> -3,-2,-1
// [M, N, K] -> [B0, B1, M, N, K]
tv->split(idx - tile_dimension_size, tile_sizes.at(idx));
// Split the inner dimensions
size_t num_split_axes = 0;
for (int64_t i = (int64_t)abten.size() - 1; i >= 0; --i) {
if (num_split_axes > 2) {
break;
}
const std::optional<MatmulDimRole> id_role_opt = abten.getTag(i);
if (!id_role_opt.has_value()) {
continue;
}
const MatmulDimRole id_role = id_role_opt.value();
// Assumes tile_sizes are given in m,n,k order
switch (id_role) {
case MatmulDimRole::M:
abten.split(i, tile_sizes.at(0));
break;
case MatmulDimRole::N:
abten.split(i, tile_sizes.at(1));
break;
case MatmulDimRole::K:
abten.split(i, tile_sizes.at(2));
break;
default:
continue;
}
num_split_axes++;
}

// The transformation happened should look like:
// The transformation above is:
// Before After
// [..., M, N, K] -> [..., Mo, Mi, No, Ni, Ko, Ki]

// Re-order the tiles so that all the outer tiles are
// Now we re-order the tiles so that all the outer tiles are
// on the left of all the inner tiles
std::unordered_map<int64_t, int64_t> reorder_map_old_to_new;

// Number of tiled inner dimensions after we split.
const auto split_tile_dimension_size = 2 * tile_dimension_size;
const auto split_tile_dimension_size = 2 * num_split_axes;
for (auto idx : c10::irange(split_tile_dimension_size)) {
// We want to reorder as follows:
// Before
//
// [..., Mo, Mi, No, Ni, Ko, Ki] ->
// After
// vvv group0 vvv vvv group1 vvv
// [..., Mo, No, Ko, Mi, Ni, Ki]
// Before After
// vvv group0 vvv vvv group1 vvv
// [..., Mo, Mi, No, Ni, Ko, Ki] -> [..., Mo, No, Ko, Mi, Ni, Ki]

// The index offset within group of current
// iterdomain, with grouping specified above.
Expand All @@ -408,16 +422,82 @@ void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes) {

// Calculate the actual index after reordering
auto index_after_reorder =
group_index * tile_dimension_size + index_within_group;
group_index * num_split_axes + index_within_group;

// Add pair {idx_before, idx_after} to re-order map.
reorder_map_old_to_new.insert(std::make_pair(
idx - split_tile_dimension_size,
index_after_reorder - split_tile_dimension_size));
}

// Apply the re-order map to tensor
tv->reorder(reorder_map_old_to_new);
// Apply the re-order map to abstract tensor
abten.reorder(reorder_map_old_to_new);
}

// Concrete-TensorView version of makeTile. Wraps tv's loop domain in an
// AbstractMatmulTensor tagged with M/N/K roles, delegates to the abstract
// makeTile overload above, then writes the transformed domain back to tv.
void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes) {
  // We will create an AbstractMatmulTensor so that we can use the abstract
  // makeTile implementation above.

  // Set tags for the innermost axes corresponding to m,n,k (omitting some
  // axes if tile_sizes.size() < 3). Axes to the left of these receive no tag
  // and are therefore left untouched by the abstract makeTile.
  std::vector<std::unordered_set<MatmulDimRole>> axis_roles(tv->nDims());
  NVF_ERROR(axis_roles.size() >= tile_sizes.size());
  for (size_t i : c10::irange(tile_sizes.size())) {
    // Position within the loop domain of the axis receiving the i-th tile
    // size; tile_sizes applies to the innermost tile_sizes.size() axes.
    size_t pos = axis_roles.size() - tile_sizes.size() + i;
    // tile_sizes is assumed to be given in m,n,k order
    switch (i) {
      case 0:
        axis_roles[pos].insert(MatmulDimRole::M);
        break;
      case 1:
        axis_roles[pos].insert(MatmulDimRole::N);
        break;
      case 2:
        axis_roles[pos].insert(MatmulDimRole::K);
        break;
      default:
        NVF_THROW("Length tile_sizes must be 3 or less");
    }
  }
  AbstractMatmulTensor abten(tv->getLoopDomain(), axis_roles);
  makeTile(abten, tile_sizes);
  // Apply the abstract split/reorder transforms back to the concrete tensor
  tv->setLoopDomain(abten.as<IterDomain*>());
}

//! Overload of makeTile taking a GemmTile plus the role of each loop axis.
//! Permutes the m/n/k tile sizes into the order in which the roles appear in
//! axis_roles, then delegates to the vector-based makeTile above.
void makeTile(
    TensorView* tv,
    const GemmTile& mnk_tile_sizes,
    const std::vector<MatmulDimRole>& axis_roles) {
  NVF_ERROR(
      tv->getLoopDomain().size() == axis_roles.size(),
      "Tensor dimension must equal number of provided axis roles");

  // Each role may appear at most once; a duplicate would make the size
  // permutation ambiguous.
  std::unordered_set<MatmulDimRole> seen_roles(
      axis_roles.begin(), axis_roles.end());
  NVF_ERROR(
      seen_roles.size() == axis_roles.size(),
      "Repeated axis roles are not allowed");

  // Build the per-axis tile sizes in the order the roles appear. For example,
  // if axis_roles is [N, M] the result is {mnk_tile_sizes.n,
  // mnk_tile_sizes.m}. Batch axes contribute no tile size and must precede
  // every tiled axis.
  std::vector<int64_t> ordered_sizes;
  ordered_sizes.reserve(axis_roles.size());
  for (const MatmulDimRole role : axis_roles) {
    if (role == MatmulDimRole::Batch) {
      NVF_ERROR(ordered_sizes.empty(), "Batch dimension must be first");
    } else if (role == MatmulDimRole::M) {
      ordered_sizes.push_back(mnk_tile_sizes.m);
    } else if (role == MatmulDimRole::N) {
      ordered_sizes.push_back(mnk_tile_sizes.n);
    } else if (role == MatmulDimRole::K) {
      ordered_sizes.push_back(mnk_tile_sizes.k);
    }
  }

  makeTile(tv, ordered_sizes);
}

namespace {
Expand Down Expand Up @@ -1082,6 +1162,8 @@ AbstractTensor MmaSwizzler::scheduleMmaOutputAllocation(AbstractTensor t) {
return t;
}

// TODO: Remove this in favor of mergeConsecutiveAxesWithSameRole once
// multi-matmul refactor is finished.
std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
TensorView* tv,
const ValGraph& permissive_graph,
Expand Down Expand Up @@ -1128,6 +1210,28 @@ std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
return roles;
}

void mergeConsecutiveAxesWithSameRole(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we be also interested in merging non-consecutive axes for more flexibility?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will pick up that case too. The consecutive dims are consecutive after a reordering so the original order can be anything. After we merge the refactor I want to add some tests with exotic multiple dim combinations to start exercising this code more.

TensorView* tv,
const DimRolesMap& dim_roles,
const ValGraph* graph) {
const auto getRole = [&](const int64_t pos) {
const ValGroup& vg = graph->toGroup(tv->axis(pos));
auto it = dim_roles.find(vg);
NVF_ERROR(it != dim_roles.end());
return it->second;
};
// Loop from inner to outer, merging when needed
NVF_ERROR(tv->nDims() > 0);
MatmulDimRole prev_role = getRole(-1);
for (int64_t dim = (int64_t)tv->nDims() - 2; dim >= 0; --dim) {
MatmulDimRole role = getRole(dim);
if (role == prev_role) {
tv->merge(dim);
}
prev_role = role;
}
}

namespace {
inline void resolveTvToMatmulDimRolesMapping(
DependenciesMap& deps_map,
Expand Down
24 changes: 24 additions & 0 deletions csrc/scheduler/mma_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ NVF_API void scheduleWarpTileWithNoReduction(
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes);

//! The above call assumes the axes in TV are [(B), M, N, K]. In this version,
//! we provide the dimension roles that are present for this tensor.
void makeTile(
TensorView* tv,
const GemmTile& tile_sizes,
const std::vector<MatmulDimRole>& axis_roles);

//! We model each dimension of every tensor in the Fusion with ID roles
//! described by MatmulDimRole.
using AbstractMatmulTensor = TaggedAbstractTensor<MatmulDimRole>;

//! Abstract version of the above utility. Schedules the provided
//! AbstractMatmulTensor instead of a concrete TensorView.
void makeTile(
AbstractMatmulTensor& canonicalized_abstract_tensor,
const std::vector<int64_t>& tile_sizes);

//! Order the inner tile dimensions as the original order in
//! (maybe allocation) domain. Also putting broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (maybe allocation domain: I1,B,I0)
Expand All @@ -79,6 +96,13 @@ std::vector<MatmulDimRole> canonicalizeMmaTvOrdering(
const DimRolesMap& dim_roles,
const std::vector<ValGroup>& ordering);

//! Given a TensorView matching the canonicalDimOrdering, schedule it by
//! merging dimensions with matching roles.
void mergeConsecutiveAxesWithSameRole(
TensorView* tv,
const DimRolesMap& dim_roles,
const ValGraph* graph);

//! [MmaSwizzler]:
//! This class is used to implement the thread swizzle format
//! required for the mma macros, cf. PTX ISA 9.7.13.4.
Expand Down
Loading
Loading