Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix compile error in multi-matmul scheduler #2990

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions csrc/scheduler/multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ class MultipleMatmulScheduler {
void swizzleBlockTiles(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles) {
if (params_.grid_swizzle_factor != 1) {
if (params_->grid_swizzle_factor != 1) {
// Find position of outer M and N dims in schedule_.tiled
int64_t Mo_pos = -1, No_pos = -1;
for (size_t i : c10::irange(outer_dim_roles.size())) {
Expand All @@ -804,8 +804,8 @@ class MultipleMatmulScheduler {
}
}

int factor = std::max(1, params_.grid_swizzle_factor); // must be >=1
switch (params_.cta_order) {
int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1
switch (params_->cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
// split [I1, I2/factor, factor]
// reorder [I1, factor, I2/factor]
Expand Down Expand Up @@ -898,8 +898,8 @@ class MultipleMatmulScheduler {
//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_.tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_.grid_swizzle_factor, if the TV has
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
Expand Down Expand Up @@ -960,17 +960,17 @@ class MultipleMatmulScheduler {
// then apply it (with "forwarding") to each TV instead. We already cache
// a vector<ValGroup> as canonical_dim_ordering_ so AbstractTensor
// scheduling is the next step in this modernization.
mma_utils::makeTile(tv, params_.tile_sizes.cta_tile, merged_roles);
mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);

swizzleBlockTiles(tv, merged_roles);

all_merged_roles.push_back(merged_roles);

if (params_.splitk_factor > 1) {
if (params_->splitk_factor > 1) {
// Outer K dimension in tv is in same position found in merged_roles
for (size_t i : c10::irange(merged_roles.size())) {
if (merged_roles[i] == MatmulDimRole::K) {
tv->split((int64_t)i, params_.splitk_factor, /*inner*/ false);
tv->split((int64_t)i, params_->splitk_factor, /*inner*/ false);
}
}
}
Expand All @@ -988,7 +988,7 @@ class MultipleMatmulScheduler {
const int64_t vec_size) {
blockTileTensors(smem_operands);
for (TensorView* tv : smem_operands) {
if (params_.promote_prologue_smem_reuse) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
}
mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv);
Expand All @@ -999,11 +999,11 @@ class MultipleMatmulScheduler {
// NOTE: this splits and parallelizes the inner dimension as
// TIDz, TIDy, TIDx, V
mma_utils::scheduleContiguousVectorLoad(
tv, params_.tile_sizes, vec_size, /*vectorize=*/vec_size > 1);
tv, params_->tile_sizes, vec_size, /*vectorize=*/vec_size > 1);
}
};
scheduleBranch(as_, acw_smems_, params_.supported_vec_size.a);
scheduleBranch(bs_, bcw_smems_, params_.supported_vec_size.b);
scheduleBranch(as_, acw_smems_, params_->supported_vec_size.a);
scheduleBranch(bs_, bcw_smems_, params_->supported_vec_size.b);
}

void scheduleMmaResults() {
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_multi_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ class MultiMatmulSchedulerMatchTest
cloner_ = std::make_unique<IrCloner>(Fusion::copy(fusion, &new_fusion));

// Schedule fusion with original matmul scheduler
scheduleMatmul(fusion, params);
scheduleMatmul(fusion, &params);

// Schedule cloned fusion with new scheduler
scheduleMultipleMatmuls(&new_fusion, params);
scheduleMultipleMatmuls(&new_fusion, &params);

// Find tensors to compare. Note that these, and all producer tensors will
// be checked.
Expand Down
Loading