diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp index 8eb9ec11586..9389467b48b 100644 --- a/csrc/scheduler/multi_matmul.cpp +++ b/csrc/scheduler/multi_matmul.cpp @@ -71,32 +71,23 @@ inline void checkConcreteStaticDim(const AbstractId& abs_id) { //! must be set as allocation domain. template AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { - // Set skip to skip all consecutive reduction domains starting from the - // innermost dimension. + NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - int64_t skip = 0; - for (int64_t i = (int64_t)swizzle_domain.size() - 1; i >= 0; --i) { - if (swizzle_domain[i]->isReduction()) { - skip++; - } else { - break; - } - } // Check that the innermost 2 dimensions are concrete and static // sized so that the swizzle function can be defined. NVF_ERROR( - (int64_t)swizzle_domain.size() >= 2 + skip, + (int64_t)swizzle_domain.size() >= 2, "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", shared_mem_tv->toString()); - checkConcreteStaticDim(swizzle_domain[-2 - skip].as()); - checkConcreteStaticDim(swizzle_domain[-1 - skip].as()); + checkConcreteStaticDim(swizzle_domain[-2].as()); + checkConcreteStaticDim(swizzle_domain[-1].as()); // Extract the constant sizes of the swizzled tile const int64_t tile_size_x = - swizzle_domain[-2 - skip]->extent()->evaluate().as(); + swizzle_domain[-2]->extent()->evaluate().as(); const int64_t tile_size_y = - swizzle_domain[-1 - skip]->extent()->evaluate().as(); + swizzle_domain[-1]->extent()->evaluate().as(); // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit @@ -380,12 +371,12 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { // -2 -1 // [row, col] if (repeated_pattern_size > 1) { - swizzle_domain.split(-2 - skip, repeated_pattern_size); + swizzle_domain.split(-2, repeated_pattern_size); } - swizzle_domain.split(-1 - skip, n_cols); + swizzle_domain.split(-1, n_cols); // -4 -3 -2 -1 // [gigarow id, gigarow, matrix id, matrix] - swizzle_domain.split(-2 - skip, num_gigabanks); + swizzle_domain.split(-2, num_gigabanks); // -5 -4 -3 -2 -1 // [gigarow id, gigarow, y outer, gigabank id, matrix] // Note that megabanks inside a gigabank are not contiguous, so the gigabank @@ -436,7 +427,7 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { // -5 -4 -3 -2 -1 // [gigarow id, gigarow, y outer, gigabank id, matrix] int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - swizzle_domain.split(axis_of_gigarow_id - skip, num_gigabanks); + swizzle_domain.split(axis_of_gigarow_id, num_gigabanks); // -6 -5 -4 -3 -2 -1 // [wave id, wave, gigarow, y outer, gigabank id, matrix] @@ -458,29 +449,25 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { // applying the swizzle, and this check is to detect and handle this // specific case. We should remove this special handling when we fix our CA // mapping. - if (shared_mem_tv->getMemoryType() == MemoryType::Shared) { - int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - using SwizzleTypeMaybeLegacy = - std::conditional_t; - if (isPowOf2(num_gigabanks)) { - swizzle_domain.swizzle( - SwizzleTypeMaybeLegacy::XOR, axis_of_gigarow_id - skip, -2 - skip); - } else { - swizzle_domain.swizzle( - SwizzleTypeMaybeLegacy::CyclicShift, - axis_of_gigarow_id - skip, - -2 - skip); - } + using SwizzleTypeMaybeLegacy = + std::conditional_t; + if (isPowOf2(num_gigabanks)) { + swizzle_domain.swizzle(SwizzleTypeMaybeLegacy::XOR, axis_of_gigarow_id, -2); + } else { + swizzle_domain.swizzle( + SwizzleTypeMaybeLegacy::CyclicShift, axis_of_gigarow_id, -2); } - if (repeated_pattern_size > 1) { - swizzle_domain.merge(-6 - skip); - } - swizzle_domain.merge(-5 - skip); + if (legacy) { + if (repeated_pattern_size > 1) { + swizzle_domain.merge(-6); + } + swizzle_domain.merge(-5); - // merge back tile_size_y - swizzle_domain.merge(-3 - skip); - swizzle_domain.merge(-2 - skip); + // merge back tile_size_y + swizzle_domain.merge(-3); + swizzle_domain.merge(-2); + } return swizzle_domain; }