Skip to content

Commit

Permalink
Fix compile error
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Sep 23, 2024
1 parent bc3ddae commit f0b64d7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions csrc/scheduler/multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ class MultipleMatmulScheduler {
void swizzleBlockTiles(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles) {
if (params_.grid_swizzle_factor != 1) {
if (params_->grid_swizzle_factor != 1) {
// Find position of outer M and N dims in schedule_.tiled
int64_t Mo_pos = -1, No_pos = -1;
for (size_t i : c10::irange(outer_dim_roles.size())) {
Expand All @@ -804,8 +804,8 @@ class MultipleMatmulScheduler {
}
}

int factor = std::max(1, params_.grid_swizzle_factor); // must be >=1
switch (params_.cta_order) {
int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1
switch (params_->cta_order) {
case MatmulParams::TileRasterizationOrder::RowMajor:
// split [I1, I2/factor, factor]
// reorder [I1, factor, I2/factor]
Expand Down Expand Up @@ -898,8 +898,8 @@ class MultipleMatmulScheduler {
//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_.tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_.grid_swizzle_factor, if the TV has
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
Expand Down Expand Up @@ -960,17 +960,17 @@ class MultipleMatmulScheduler {
// then apply it (with "forwarding") to each TV instead. We already cache
// a vector<ValGroup> as canonical_dim_ordering_ so AbstractTensor
// scheduling is the next step in this modernization.
mma_utils::makeTile(tv, params_.tile_sizes.cta_tile, merged_roles);
mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);

swizzleBlockTiles(tv, merged_roles);

all_merged_roles.push_back(merged_roles);

if (params_.splitk_factor > 1) {
if (params_->splitk_factor > 1) {
// Outer K dimension in tv is in same position found in merged_roles
for (size_t i : c10::irange(merged_roles.size())) {
if (merged_roles[i] == MatmulDimRole::K) {
tv->split((int64_t)i, params_.splitk_factor, /*inner*/ false);
tv->split((int64_t)i, params_->splitk_factor, /*inner*/ false);
}
}
}
Expand All @@ -988,7 +988,7 @@ class MultipleMatmulScheduler {
const int64_t vec_size) {
blockTileTensors(smem_operands);
for (TensorView* tv : smem_operands) {
if (params_.promote_prologue_smem_reuse) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
}
mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv);
Expand All @@ -999,11 +999,11 @@ class MultipleMatmulScheduler {
// NOTE: this splits and parallelizes the inner dimension as
// TIDz, TIDy, TIDx, V
mma_utils::scheduleContiguousVectorLoad(
tv, params_.tile_sizes, vec_size, /*vectorize=*/vec_size > 1);
tv, params_->tile_sizes, vec_size, /*vectorize=*/vec_size > 1);
}
};
scheduleBranch(as_, acw_smems_, params_.supported_vec_size.a);
scheduleBranch(bs_, bcw_smems_, params_.supported_vec_size.b);
scheduleBranch(as_, acw_smems_, params_->supported_vec_size.a);
scheduleBranch(bs_, bcw_smems_, params_->supported_vec_size.b);
}

void scheduleMmaResults() {
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_multi_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ class MultiMatmulSchedulerMatchTest
cloner_ = std::make_unique<IrCloner>(Fusion::copy(fusion, &new_fusion));

// Schedule fusion with original matmul scheduler
scheduleMatmul(fusion, params);
scheduleMatmul(fusion, &params);

// Schedule cloned fusion with new scheduler
scheduleMultipleMatmuls(&new_fusion, params);
scheduleMultipleMatmuls(&new_fusion, &params);

// Find tensors to compare. Note that these, and all producer tensors will
// be checked.
Expand Down

0 comments on commit f0b64d7

Please sign in to comment.