diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 727d25beb53..0060e626fe6 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -273,7 +273,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Generates the kernel function declaration void genDeclaration(const std::string& kernel_name) { - code_ << "__global__ void " << kernel_name << "("; + code_ << "__global__ void "; + if (kernel_->hasManaged("cluster_dims")) { + auto cluster_dims = + kernel_->getManaged>( + "cluster_dims"); + code_ << "__cluster_dims__(" << std::get<0>(cluster_dims) << ", " + << std::get<1>(cluster_dims) << ", " << std::get<2>(cluster_dims) + << ") "; + } + code_ << kernel_name << "("; std::unordered_set unique_args; diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index d236cf2c072..fbb95d46df2 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -29,500 +29,6 @@ namespace nvfuser { -namespace { - -// This function returns a pair of integers. The first integer is the gcd -// between megabanks and row stride. The second integer is the repeat pattern -// size. If the gcd is 1, then no swizzle is necessary to resolve bank -// conflicts. In that case, the second integer is irrelevant and -1 is returned. -std::pair analyzeSwizzleSharedMemory( - TensorView* shared_mem_tv) { - NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - - // Check that the innermost 2 dimensions are concrete and static - // sized so that the swizzle function can be defined. - NVF_ERROR( - (int64_t)swizzle_domain.size() >= 2, - "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", - shared_mem_tv->toString()); - mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); - mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); - - // Extract the constant sizes of the swizzled tile - const int64_t tile_size_x = - swizzle_domain[-2]->extent()->evaluate().as(); - const int64_t tile_size_y = - swizzle_domain[-1]->extent()->evaluate().as(); - - // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. - // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit - // (i.e. float) - const int64_t data_type_size = dataTypeSize(*shared_mem_tv->getDataType()); - NVF_ERROR(data_type_size == 2 || data_type_size == 4); - - // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. - // For epilogue, threads in a warp is organized as 8 rows x 4 columns. - // Each thread vectorized write 2 items, so 8 items per row. - //--0--1--2--3 - //--4--5--6--7 - //--8--9--10-11 - //--12-13-14-15 - //--16-17-18-19 - //--20-21-22-23 - //--24-25-26-27 - //--28-29-30-31 - constexpr int64_t n_rows = 8; - constexpr int64_t n_cols = 8; - - // Column size of the tile needs to be multiples of 8 for ldmatrix to work. - NVF_ERROR( - tile_size_x >= n_rows && tile_size_x % n_rows == 0 && - tile_size_y >= n_cols && tile_size_y % n_cols == 0, - "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", - tile_size_x, - "x", - tile_size_y); - - /* Note [How to remove bank conflict for ldmatrix?] - * - * **This note is interleaved with code, I suggest reading this note like - * reading a jupyter notebook** - * - * Our task is to make sure different rows does not fall into the same - * bank of shared memory. 
- * - * Introduction to bank conflict can be found at page 54-72 of: - * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf - * - * When we talk about bank conflict removal, we are talking about the - * following task: - * "there are 32 banks, and each bank contains one 4-byte word, we want to - * make sure different lanes in a warp does not access different word - * addresses in the same bank" - * For example, if thread 0 is accessing word address 1, and thread 1 is - * accessing word address 33, then these two threads will have a bank - * conflict because they are accessing different word addresses in the same - * bank. However, if thread 0 is accessing byte address 4 and thread 1 is - * accessing byte address 6 then there will be no bank conflict because 4 - * and 6 both belong to word 1. - */ - - constexpr int64_t smem_bytes_per_word = 4; - constexpr int64_t smem_banks = 32; - - /* but here, for our convenience, because ldmatrix always use vectorized - * access of 8 items = 16 bytes = 4 words, we further group words into - * units: we consider each 4 words as a "unit", and each 4 banks as a - * "megabank". So we can rephrase our task as: - * "there are 8 megabanks, and each megabanks contains one 4-word unit, we - * want to make sure different lanes in a warp does not access different - * unit addresses in the same megabank" - * In this terminology, matrices are in the row major format, each matrix - * has 8 rows, and each row has exactly one unit. - */ - - constexpr int64_t items_per_unit = n_cols; - const int64_t bytes_per_unit = items_per_unit * data_type_size; - const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; - const int64_t num_megabanks = smem_banks / words_per_unit; - - /* In the following example, each CTA tile contains 2 rows and 3 colums of - * matrices, each 8x8 size: - * +----------+----------+----------+ - * | matrix 0 | matrix 1 | matrix 2 | - * +----------+----------+----------+ - * | matrix 3 | matrix 4 | matrix 5 | - * +----------+----------+----------+ - * The addresses of different rows in the same matrix are offset by 3 units. - * In this perspective, loading a matrix is a strided memory access with the - * following stride (in units): - */ - - // number of units per row - int64_t row_stride = tile_size_y / items_per_unit; - - /* So the bank conflicting problem is now converted to the following game: - * I have a clock that has one pointer and `num_megabanks` ticks. I start - * my game by making my pointer pointing to somewhere, and turn forward - * the pointer `n_rows` times, each time by `row_stride` ticks. - * This problem can be well modeled by modular arithmetic in number theory - * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. - * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. - * Additions and multiplications are defined in a cyclic manner: - * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... - * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... - * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] - * - * It worth mention that Z/nZ is a "commutative ring", that is, we can use - * addition and multiplication rules just like using normal integers: - * a + b = b + a, a * (b + c) = a * b + a * c, ... - * In short, we can reason about Z/nZ just like we are reasoning about - * integers, except that every number is automatically "% n". 
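 * (To make the arithmetic above concrete, here is a standalone numeric
 * sketch assuming half precision inputs, i.e. data_type_size == 2, and a
 * hypothetical tile_size_y of 48 elements:
 *
 *   words_per_unit = 8 items * 2 bytes / 4 bytes-per-word = 4
 *   num_megabanks  = 32 banks / 4 words-per-unit          = 8
 *   row_stride     = 48 / 8 items-per-unit                = 6
 *   std::gcd(num_megabanks, row_stride % num_megabanks)   = gcd(8, 6) = 2
 *
 * so each matrix touches only 8 / 2 = 4 distinct megabanks and incurs a
 * 2-way conflict, which is exactly the stride = 6 under Z/8Z example worked
 * out below.)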
- * - * Reference: - * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n - * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of - * -1, see [The Mathematics of Integer Arithmetic] for more detail. But - * we are only interested in non-negative numbers here, so there is no - * need to worry about this problem - */ - - // row_stride in Z/nZ, where n is num_megabanks: - // assert(row_stride >= 0); - // assert(num_megabanks >= 0); - int64_t row_stride_znz = row_stride % num_megabanks; - /* Consider the following function in Z/nZ: - * f(i; init) = init + i * stride - * where init is the initial position of the pointer in the clock when we - * start the game, and stride is the number of ticks we move forward each - * time, and i is the number of times we move forward. For a fixed init, we - * abbrivate f(i; init) as f(i). - * - * In our problem, f(i) is the megabank of the `i`th row of the matrix, and - * `init` is the megabank of the 0th row of the matrix. - * - * One very important property of f(i) is: - * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) - * This property is true because: - * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) - * - * The above property tells us, as we turn the clock forward: - * - initially, we will go to a never-visited tick in each turn, but, - * - at some point, we will return back to our original position, and, - * - after we return, we start repeat the pervious pattern again and again. - * - * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: - * i 0 1 2 3 4 5 6 7 - * f(i) 0 6 4 2 0 6 4 2 - * We can see that f(i) is repeating a pattern of four unique numbers - * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 - * different megabanks, and we have a 2-way conflict. - * - * The question of interest is, does the above observation generalize? That - * is, does f(i) always repeat a pattern of p unique numbers q times? Note - * that p and q must satisfy p * q = n. - * - * The answer to the above question is: yes! Consider the following - * equation: - * f(i1 + j) == f(i1) - * We want to know what is the smallest positive number j that makes the - * above equation true. Because this tells us in how many steps we will see - * repeat. This equation can be simplified as: - * f(i1 + j) == f(i1) + j * stride == f(i1) - * ==> j * stride == 0 - * - * An important tool to study this equation is multiplicative inverse: - * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse - * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it - * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in - * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has - * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because - * (2 * 8) % 15 = 1 - * - * stride has an multiplicative inverse if and only if stride coprime with - * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our - * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does - * not repeat, that is: there is no bank conflict. - */ - - int64_t g = std::gcd(num_megabanks, row_stride_znz); - if (g == 1) { - return {g, -1}; // No need to swizzle in this case. - } - - /* For the case where stride does not coprime with n, we note that - * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. 
We - * can write stride and n as: - * stride = s * g, n = m * g - * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we - * have: - * (j * stride) % n = 0 - * ==> (j * s) % m * g = 0 - * ==> (j * s) % m = 0 - * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we - * further get: - * j == 0 (in Z/mZ) - * That is, j is a multiple of m in Z. So the smallest positive j that make - * the equation hold is n / g. - * - * That is: f(i) always repeat a pattern of n/g unique numbers g times. - * In other word: we are using n/g megabanks, and we have a g-way bank - * conflict. - * - * Let's use the word "pattern" to refer to the set of values of `f` at - * different `i`, that is: - * pattern k = { f(i; init=k) | i in Z/nZ } - * For the example of stride = 6 under Z/8Z, we have the following patterns - * f(i): 01234567 - * pattern 0: x_x_x_x_ - * pattern 1: _x_x_x_x - * (x => occupied, _ => unoccupied) - */ - - int64_t repeated_pattern_size = num_megabanks / g; - - if (repeated_pattern_size >= n_rows) { - return {g, -1}; // No need to swizzle in this case. - } - - return {g, repeated_pattern_size}; -} - -//! Automatically generates the shared memory swizzled data layout for tma loads -//! in matmul mainloop. The shared memory data layout is always 2D currently. -//! This utility function assumes that the shared_mem_tv has the following -//! structure: [tile_row, tile_col] -//! Returns which swizzle format to use for mma inputs with tma loads. -MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) { - auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); - - if (g == 1) { - return MmaInputSmemSwizzle::None; // No need to swizzle in this case. - } - - // 128B swizzle results in 8 x 8 matrix given half precision inputs. - constexpr int64_t n_rows = 8; - - NVF_ERROR( - n_rows % repeated_pattern_size == 0, - "Can not partition matrix into megarows"); - int64_t num_gigarows = n_rows / repeated_pattern_size; - int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size - - /* To further simplify the problem, if we assume: */ - NVF_ERROR( - num_gigarows % num_gigabanks == 0, - "Requires non-square swizzle, which is not supported yet"); - - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - // Extract the constant sizes of the swizzled tile - const int64_t inner_dim_size = - swizzle_domain[-1]->extent()->evaluate().as(); - - auto dtype = shared_mem_tv->getDataType().value(); - const int64_t B128_elements = 128 / dataTypeSize(dtype); - const int64_t B64_elements = 64 / dataTypeSize(dtype); - const int64_t B32_elements = 32 / dataTypeSize(dtype); - - if (inner_dim_size % B128_elements == 0) { - return MmaInputSmemSwizzle::B128; - } else if (inner_dim_size % B64_elements == 0) { - return MmaInputSmemSwizzle::B64; - } else if (inner_dim_size % B32_elements == 0) { - return MmaInputSmemSwizzle::B32; - } else { - NVF_THROW("Unsupported swizzle size for TMA shared memory mma inputs"); - } -} - -//! Automatically generates the shared memory swizzled data layout for matmul -//! epilogue. -//! The shared mem data layout is always 2D currently, and this utility -//! function assumes that the shared_mem_tv has the following structure: -//! [tile_row, tile_col] -//! Returns the domain with swizzle. For the case of legacy swizzle, this -//! domain must be set as loop domain. For the case of new swizzle, this domain -//! must be set as allocation domain. 
-template -AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { - auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); - - // Create Abstract Tensor from shared memory tensor loop domain. - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - - if (g == 1) { - return swizzle_domain; // No need to swizzle in this case. - } - - /* Now we know that we have a g-way bank conflict. How do we remove this - * bank conflict? The answer is to mix the storage of different matrices. - * We first split the matrices along the row axis into g pieces, each piece - * has n/g rows. With this split, each piece occupies exactly one pattern. - * We want to use some non-traditional storage to let different pieces of - * the same matrix to occupy different patterns. - * - * Because Z/nZ has n items, each pattern has n/g different items, so we - * have in total g different patterns. We want to find the corresponding - * `init` values of these g different patterns. - * - * Consider two different init values `init1` and `init2`. When do they - * represent the same pattern? They represent the same pattern if and only - * if `f(0; init2)` falls on the pattern of `init1`, that is, there exist an - * i such that - * f(i; init1) == f(0; init2) - * which simplifies to - * init1 + i * stride == init2 - * ==> init2 - init1 == i * stride - * What values can `i * stride` be? It can be an arbitrary multiple of g: - * i * stride in Z/nZ is (i * stride) % n in Z. Let m = n/g, according to - * Theorem 4.13 in [The Mathematics of Integer Arithmetic] - * (i * stride) % n = (i * s) % m * g - * Because s coprime with m, we know that for an arbitrary value `j` in - * Z/mZ, we can take `i = minv(s) * j` to make `i * s == j`. - * - * That said, for init values that are off by a multiple of g they - * correspond to the same pattern, otherwise they belongs to different - * patterns. So, we can use - * init = 0, 1, ..., g - 1 - * to canonically represent g patterns. Let's call the above - * `init` values "pattern id". - * - * Now we have the idea about how to remove bank conflict: We can do an - * inner split of our row dimension by `repeated_pattern_size` to get - * (repeat, pattern), then different indices of the "repeat" dimension will - * be using the same megabank, and different indices of the "pattern" - * dimension will be using different megabank. We don't need to touch the - * "pattern" dimension, but we need to play with the "repeat" dimension to - * interleave it with matrice ids so that each matrix is distributed across - * different banks. - * - * For example, if we have repeated_pattern_size = 4, we would want to do - * something like below: - * +----------+----------+ - * 0| | | - * 1| matrix 0 | matrix 1 | - * 2| | | - * 3| | | - * +----------+----------+ - * 4| | | - * 5| matrix 1 | matrix 0 | - * 6| | | - * 7| | | - * +----------+----------+ - * - * We can consider each repeated_pattern_size rows as a gigarow, and each - * repeated_pattern_size megabanks as a gigabank. Note that megabank is a - * contiguous chunk of banks, but gigabank is not contiguous. Indeed, - * nearby megabanks in a gigabank has a distance of `g` megabanks - */ - - // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. - // For epilogue, threads in a warp is organized as 8 rows x 4 columns. - // Each thread vectorized write 2 items, so 8 items per row. 
- //--0--1--2--3 - //--4--5--6--7 - //--8--9--10-11 - //--12-13-14-15 - //--16-17-18-19 - //--20-21-22-23 - //--24-25-26-27 - //--28-29-30-31 - constexpr int64_t n_rows = 8; - constexpr int64_t n_cols = 8; - - NVF_ERROR( - n_rows % repeated_pattern_size == 0, - "Can not partition matrix into megarows"); - int64_t num_gigarows = n_rows / repeated_pattern_size; - int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size - - // -2 -1 - // [row, col] - if (repeated_pattern_size > 1) { - swizzle_domain.split(-2, repeated_pattern_size); - } - swizzle_domain.split(-1, n_cols); - // -4 -3 -2 -1 - // [gigarow id, gigarow, matrix id, matrix] - swizzle_domain.split(-2, num_gigabanks); - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] - // Note that megabanks inside a gigabank are not contiguous, so the gigabank - // id is -2 instead of -3 - - /* We want to evenly distribute gigarows across gigabanks, for example, if - * we have 7 gigarows and 3 gigabanks, then we might distribute them as: - * +---+ - * |x | - * | x | - * | x| - * |x | - * | x | - * | x| - * |x | - * +---+ - * considering all matrices, this is a swizzle function like: - * +---+ - * |012| - * |201| - * |120| - * |012| - * |201| - * |120| - * |012| - * +---+ - * which is a cyclic shift. - * - * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and - * row_stride_znz (which is row_stride % num_megabanks), g should also - * divide row_stride, because according to the fundamental - * division-with-remainder property (see doc/math/integer-division.md): - * row_stride = q * num_megabanks + row_stride_znz - * which means, we can just consider each num_gigabanks matrices as a group, - * and we always have complete groups (i.e. no group has less than - * num_gigabanks matrices). Interleaving the memory of matrices within each - * group should be enough to fully remove bank conflict. - */ - - /* To further simplify the problem, if we assume: */ - NVF_ERROR( - num_gigarows % num_gigabanks == 0, - "Requires non-square swizzle, which is not supported yet"); - /* Then we can partition gigarows into full waves, each wave has - * num_gigabanks gigarows. This partition creates square dimensions, making - * the swizzle implementation easier */ - - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] - int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - swizzle_domain.split(axis_of_gigarow_id, num_gigabanks); - // -6 -5 -4 -3 -2 -1 - // [wave id, wave, gigarow, y outer, gigabank id, matrix] - - // swizzle wave with gigabank id to make threads in a wave access different - // gigabank. Apply swizzle only when shared_mem_tv is stored in shared - // memory. - // TODO: This is a temporary workaround for the following issue: - // For the mma output, we have the following schedule: - // rFactor: [...., X, Y] -> mma-swizzle transformations -> loop - // For epilogue smem tensor, the schedule is - // rFactor: [...., X, Y] -> split -> [...., X1, X2, X3, Y1, Y2, Y3] - // -> swizzle X2, Y2 -> [...., X1, X2', X3, Y1, Y2', Y3] - // -> merge back -> [...., X', Y'] - // -> mma-swizzle transformations -> loop - // The mma-swizzle transformations for the mma output and epilogue smem - // tensor are the same. In indexing, we do require {X, X'} and {Y, Y'} to be - // mapped in CA map, however, we currently can not handle that. 
So we have - // to do the same split and merge to the mma output without actually - // applying the swizzle, and this check is to detect and handle this - // specific case. We should remove this special handling when we fix our CA - // mapping. - using SwizzleTypeMaybeLegacy = - std::conditional_t; - if (isPowOf2(num_gigabanks)) { - swizzle_domain.swizzle(SwizzleTypeMaybeLegacy::XOR, axis_of_gigarow_id, -2); - } else { - swizzle_domain.swizzle( - SwizzleTypeMaybeLegacy::CyclicShift, axis_of_gigarow_id, -2); - } - - if (legacy) { - if (repeated_pattern_size > 1) { - swizzle_domain.merge(-6); - } - swizzle_domain.merge(-5); - - // merge back tile_size_y - swizzle_domain.merge(-3); - swizzle_domain.merge(-2); - } - - return swizzle_domain; -} - -} // namespace - MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { ValGroup vg = graph_->toGroup(id); auto it = id_roles_.find(vg); @@ -813,7 +319,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() { tv->promoteReuse(); } mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv); - MmaInputSmemSwizzle swizzle_type = tmaSwizzleSharedMemory(tv); + MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(tv); tv->applyMmaSwizzleForTMALoad(swizzle_type); } }; @@ -1077,14 +583,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setAllocationDomain(s.as(), true); } - MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem); + MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem); // Schedule shared memory cache; Output from StMatrix - scheduleStMatrixForMmaOutput( + mma_utils::scheduleStMatrixForMmaOutput( d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n); // Schedule global memory output; Output from TMA Store - scheduleTMAStoreForMmaOutput(d, swizzle); + mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle); } } } @@ -1240,94 +746,4 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { */ } -void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle, - int64_t tile_m, - int64_t tile_n) { - NVF_ERROR( - ((tile_m == 16 && tile_n == 16) || (tile_m == 16 && tile_n == 8)), - "We only support 16x16 and 16x16 stmatrix now"); - - NVF_ERROR( - tv->dtype() == DataType::Half, "we only support half type in stmatrix"); - - // [M, N] -> [128(TIDx), N/8 , 2 , 2] - auto s = - mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); - - if (swizzle != MmaInputSmemSwizzle::None) { - // Create tma store allocation domain with swizzle - scheduleTMAStoreForMmaOutput(tv, swizzle); - } - - tv->setLoopDomain(s.as()); - - if (tile_m == 16 && tile_n == 16) { - // Let [M, N] be [64, 32] - // After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2] - // [128(TIDx), 4(n), 2, 2] -> [128(TIDx), 2(no), 2(ni), 2, 2] - tv->split(-3, 2); - // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 2(ni), 2, 2] - tv->reorder({{-4, 0}}); - // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 8 (vectorize)] - tv->merge(-3); - tv->merge(-2); - } else if (tile_m == 16 && tile_n == 8) { - // Let [M, N] be [64, 16] - // After scheduleMmaOutputAllocation: [128(TIDx), 2, 2, 2] - // [128(TIDx), 2, 2, 2] -> [2, 128(TIDx), 2, 2] - tv->reorder({{-3, 0}}); - // [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)] - tv->merge(-2); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); -} - -void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle) { - // [BDX, BDY, TDY, MI, 
NI] - // skip all but last 2 iterDomains - int64_t num_ids_to_skip = - static_cast(tv->getLoopDomain().size() - 2); - - NVF_ERROR(num_ids_to_skip >= 0); - if (swizzle == MmaInputSmemSwizzle::None) { - // For no-swizzle case, the entire tile are divided into 8x8 core matrices, - // and each core matrix resides in a contiguous 8*8*2 bytes region in shared - // memory. [K, N] - tv->split(-2, 8); - tv->split(-1, 8); - // [Ko, K8, No, N8] - tv->reorder({{-2, -3}}); - } else { - auto dtype = tv->getDataType().value(); - - // In the comments below I assume K=16, N=32, swizzle=32, dtype = half. - - // split the inner-dim - // [K(16), N(32)] -> [K(16), NO(2), NI(16)] - tv->split(-1, getBytesFromSwizzle(swizzle) / dataTypeSize(dtype)); - - // [NO, K, NI] - the TMA Box is [K, NI] - tv->reorder({{-2, -3}}); - - // [NO, K, NI] -> - // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] - tv->swizzleTMABox(swizzle); - num_ids_to_skip += 1; - } - - // The shared memory producer must have the swizzled allocation domain. - // The global memory consumer must have the ParallelType::Bulk iterDomains. - if (tv->getMemoryType() == MemoryType::Shared) { - // Set the allocation to the loop domain. - tv->setAllocationDomain(tv->getLoopDomain(), true); - } else { - mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs( - tv, num_ids_to_skip); - } -} - } // namespace nvfuser diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index a736823f75c..bf7bc1df0f5 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -178,20 +178,6 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { void setUpCircularBuffering(); - //! Schedules the copy operation of output of a Mma op which resided in the - //! registers to shared memory. - void scheduleStMatrixForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle, - int64_t tile_m, - int64_t tile_n); - - //! Schedules the copy operation of output of a Mma op which resided in the - //! shared memory to global memory. - void scheduleTMAStoreForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle); - // Map TensorView's iterDomain to its ValGroup. // Then, find the MatmulDimRole for the ValGroup. // Return MatmulDimRole for IterDomain diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 13fd337e113..e9a24851a5a 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1256,38 +1256,53 @@ inline void resolveTvToMatmulDimRolesMapping( } // anonymous namespace -void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n) { - NVF_ERROR( - tv->getMemoryType() == MemoryType::Global, - "TMA Store should write to global memory"); +void scheduleTMAStoreForMmaOutput(TensorView* tv, MmaInputSmemSwizzle swizzle) { + // [BDX, BDY, TDY, MI, NI] + // skip all but last 2 iterDomains + int64_t num_ids_to_skip = + static_cast(tv->getLoopDomain().size() - 2); - NVF_ERROR( - tv->definition()->isA(), - "This tensor should be the result of a LoadStoreOp"); + NVF_ERROR(num_ids_to_skip >= 0); + if (swizzle == MmaInputSmemSwizzle::None) { + // For no-swizzle case, the entire tile are divided into 8x8 core matrices, + // and each core matrix resides in a contiguous 8*8*2 bytes region in shared + // memory. 
[K, N] + tv->split(-2, 8); + tv->split(-1, 8); + // [Ko, K8, No, N8] + tv->reorder({{-2, -3}}); + } else { + auto dtype = tv->getDataType().value(); - NVF_ERROR( - tv->definition()->as()->opType() == - LoadStoreOpType::CpAsyncBulkTensorTile, - "This is not a TMA operation"); + // In the comments below I assume K=16, N=32, swizzle=32, dtype = half. - NVF_ERROR( - tv->definition() - ->as() - ->in() - ->as() - ->getMemoryType() == MemoryType::Shared, - "Producer should be in shared memory"); - - // [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)] - tv->split(-2, m); - tv->split(-1, n); - // [MO(1), MI(m), NO(1), NI(n)] -> [MO(1), NO(1), MI(m), NI(n)] - tv->reorder({{-2, -3}}); - mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs(tv, 2); + // split the inner-dim + // [K(16), N(32)] -> [K(16), NO(2), NI(16)] + tv->split(-1, getBytesFromSwizzle(swizzle) / dataTypeSize(dtype)); + + // [NO, K, NI] - the TMA Box is [K, NI] + tv->reorder({{-2, -3}}); + + // [NO, K, NI] -> + // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] + tv->swizzleTMABox(swizzle); + num_ids_to_skip += 1; + } + + // The shared memory producer must have the swizzled allocation domain. + // The global memory consumer must have the ParallelType::Bulk iterDomains. + if (tv->getMemoryType() == MemoryType::Shared) { + // Set the allocation to the loop domain. + tv->setAllocationDomain(tv->getLoopDomain(), true); + } else { + mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs( + tv, num_ids_to_skip); + } } void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n) { NVF_ERROR( @@ -1300,6 +1315,12 @@ void scheduleStMatrixForMmaOutput( // [M, N] -> [128(TIDx), N/8 , 2 , 2] auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); + + if (swizzle != MmaInputSmemSwizzle::None) { + // Create tma store allocation domain with swizzle + mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle); + } + tv->setLoopDomain(s.as()); if (tile_m == 16 && tile_n == 16) { @@ -2174,6 +2195,288 @@ std::optional> allPatternRoles( id_roles, tensor_roles_opt.getData()}; } +namespace { +// This function returns a pair of integers. The first integer is the gcd +// between megabanks and row stride. The second integer is the repeat pattern +// size. If the gcd is 1, then no swizzle is necessary to resolve bank +// conflicts. In that case, the second integer is irrelevant and -1 is returned. +std::pair analyzeSwizzleSharedMemory( + TensorView* shared_mem_tv) { + NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); + AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); + + // Check that the innermost 2 dimensions are concrete and static + // sized so that the swizzle function can be defined. + NVF_ERROR( + (int64_t)swizzle_domain.size() >= 2, + "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", + shared_mem_tv->toString()); + mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); + + // Extract the constant sizes of the swizzled tile + const int64_t tile_size_x = + swizzle_domain[-2]->extent()->evaluate().as(); + const int64_t tile_size_y = + swizzle_domain[-1]->extent()->evaluate().as(); + + // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. + // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit + // (i.e. 
float) + const int64_t data_type_size = dataTypeSize(*shared_mem_tv->getDataType()); + NVF_ERROR(data_type_size == 2 || data_type_size == 4); + + // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. + // For epilogue, threads in a warp is organized as 8 rows x 4 columns. + // Each thread vectorized write 2 items, so 8 items per row. + //--0--1--2--3 + //--4--5--6--7 + //--8--9--10-11 + //--12-13-14-15 + //--16-17-18-19 + //--20-21-22-23 + //--24-25-26-27 + //--28-29-30-31 + constexpr int64_t n_rows = 8; + constexpr int64_t n_cols = 8; + + // Column size of the tile needs to be multiples of 8 for ldmatrix to work. + NVF_ERROR( + tile_size_x >= n_rows && tile_size_x % n_rows == 0 && + tile_size_y >= n_cols && tile_size_y % n_cols == 0, + "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", + tile_size_x, + "x", + tile_size_y); + + /* Note [How to remove bank conflict for ldmatrix?] + * + * **This note is interleaved with code, I suggest reading this note like + * reading a jupyter notebook** + * + * Our task is to make sure different rows does not fall into the same + * bank of shared memory. + * + * Introduction to bank conflict can be found at page 54-72 of: + * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + * + * When we talk about bank conflict removal, we are talking about the + * following task: + * "there are 32 banks, and each bank contains one 4-byte word, we want to + * make sure different lanes in a warp does not access different word + * addresses in the same bank" + * For example, if thread 0 is accessing word address 1, and thread 1 is + * accessing word address 33, then these two threads will have a bank + * conflict because they are accessing different word addresses in the same + * bank. However, if thread 0 is accessing byte address 4 and thread 1 is + * accessing byte address 6 then there will be no bank conflict because 4 + * and 6 both belong to word 1. + */ + + constexpr int64_t smem_bytes_per_word = 4; + constexpr int64_t smem_banks = 32; + + /* but here, for our convenience, because ldmatrix always use vectorized + * access of 8 items = 16 bytes = 4 words, we further group words into + * units: we consider each 4 words as a "unit", and each 4 banks as a + * "megabank". So we can rephrase our task as: + * "there are 8 megabanks, and each megabanks contains one 4-word unit, we + * want to make sure different lanes in a warp does not access different + * unit addresses in the same megabank" + * In this terminology, matrices are in the row major format, each matrix + * has 8 rows, and each row has exactly one unit. + */ + + constexpr int64_t items_per_unit = n_cols; + const int64_t bytes_per_unit = items_per_unit * data_type_size; + const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; + const int64_t num_megabanks = smem_banks / words_per_unit; + + /* In the following example, each CTA tile contains 2 rows and 3 colums of + * matrices, each 8x8 size: + * +----------+----------+----------+ + * | matrix 0 | matrix 1 | matrix 2 | + * +----------+----------+----------+ + * | matrix 3 | matrix 4 | matrix 5 | + * +----------+----------+----------+ + * The addresses of different rows in the same matrix are offset by 3 units. 
+ * In this perspective, loading a matrix is a strided memory access with the + * following stride (in units): + */ + + // number of units per row + int64_t row_stride = tile_size_y / items_per_unit; + + /* So the bank conflicting problem is now converted to the following game: + * I have a clock that has one pointer and `num_megabanks` ticks. I start + * my game by making my pointer pointing to somewhere, and turn forward + * the pointer `n_rows` times, each time by `row_stride` ticks. + * This problem can be well modeled by modular arithmetic in number theory + * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. + * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. + * Additions and multiplications are defined in a cyclic manner: + * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... + * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... + * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] + * + * It worth mention that Z/nZ is a "commutative ring", that is, we can use + * addition and multiplication rules just like using normal integers: + * a + b = b + a, a * (b + c) = a * b + a * c, ... + * In short, we can reason about Z/nZ just like we are reasoning about + * integers, except that every number is automatically "% n". + * + * Reference: + * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of + * -1, see [The Mathematics of Integer Arithmetic] for more detail. But + * we are only interested in non-negative numbers here, so there is no + * need to worry about this problem + */ + + // row_stride in Z/nZ, where n is num_megabanks: + // assert(row_stride >= 0); + // assert(num_megabanks >= 0); + int64_t row_stride_znz = row_stride % num_megabanks; + /* Consider the following function in Z/nZ: + * f(i; init) = init + i * stride + * where init is the initial position of the pointer in the clock when we + * start the game, and stride is the number of ticks we move forward each + * time, and i is the number of times we move forward. For a fixed init, we + * abbrivate f(i; init) as f(i). + * + * In our problem, f(i) is the megabank of the `i`th row of the matrix, and + * `init` is the megabank of the 0th row of the matrix. + * + * One very important property of f(i) is: + * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) + * This property is true because: + * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) + * + * The above property tells us, as we turn the clock forward: + * - initially, we will go to a never-visited tick in each turn, but, + * - at some point, we will return back to our original position, and, + * - after we return, we start repeat the pervious pattern again and again. + * + * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: + * i 0 1 2 3 4 5 6 7 + * f(i) 0 6 4 2 0 6 4 2 + * We can see that f(i) is repeating a pattern of four unique numbers + * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 + * different megabanks, and we have a 2-way conflict. + * + * The question of interest is, does the above observation generalize? That + * is, does f(i) always repeat a pattern of p unique numbers q times? Note + * that p and q must satisfy p * q = n. + * + * The answer to the above question is: yes! Consider the following + * equation: + * f(i1 + j) == f(i1) + * We want to know what is the smallest positive number j that makes the + * above equation true. 
Because this tells us in how many steps we will see + * repeat. This equation can be simplified as: + * f(i1 + j) == f(i1) + j * stride == f(i1) + * ==> j * stride == 0 + * + * An important tool to study this equation is multiplicative inverse: + * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse + * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it + * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in + * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has + * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because + * (2 * 8) % 15 = 1 + * + * stride has an multiplicative inverse if and only if stride coprime with + * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our + * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does + * not repeat, that is: there is no bank conflict. + */ + + int64_t g = std::gcd(num_megabanks, row_stride_znz); + if (g == 1) { + return {g, -1}; // No need to swizzle in this case. + } + + /* For the case where stride does not coprime with n, we note that + * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We + * can write stride and n as: + * stride = s * g, n = m * g + * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we + * have: + * (j * stride) % n = 0 + * ==> (j * s) % m * g = 0 + * ==> (j * s) % m = 0 + * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we + * further get: + * j == 0 (in Z/mZ) + * That is, j is a multiple of m in Z. So the smallest positive j that make + * the equation hold is n / g. + * + * That is: f(i) always repeat a pattern of n/g unique numbers g times. + * In other word: we are using n/g megabanks, and we have a g-way bank + * conflict. + * + * Let's use the word "pattern" to refer to the set of values of `f` at + * different `i`, that is: + * pattern k = { f(i; init=k) | i in Z/nZ } + * For the example of stride = 6 under Z/8Z, we have the following patterns + * f(i): 01234567 + * pattern 0: x_x_x_x_ + * pattern 1: _x_x_x_x + * (x => occupied, _ => unoccupied) + */ + + int64_t repeated_pattern_size = num_megabanks / g; + + if (repeated_pattern_size >= n_rows) { + return {g, -1}; // No need to swizzle in this case. + } + + return {g, repeated_pattern_size}; +} +} // namespace + +MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) { + auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); + + if (g == 1) { + return MmaInputSmemSwizzle::None; // No need to swizzle in this case. + } + + // 128B swizzle results in 8 x 8 matrix given half precision inputs. 
+ constexpr int64_t n_rows = 8; + + NVF_ERROR( + n_rows % repeated_pattern_size == 0, + "Can not partition matrix into megarows"); + int64_t num_gigarows = n_rows / repeated_pattern_size; + int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size + + /* To further simplify the problem, if we assume: */ + NVF_ERROR( + num_gigarows % num_gigabanks == 0, + "Requires non-square swizzle, which is not supported yet"); + + AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); + // Extract the constant sizes of the swizzled tile + const int64_t inner_dim_size = + swizzle_domain[-1]->extent()->evaluate().as(); + + auto dtype = shared_mem_tv->getDataType().value(); + const int64_t B128_elements = 128 / dataTypeSize(dtype); + const int64_t B64_elements = 64 / dataTypeSize(dtype); + const int64_t B32_elements = 32 / dataTypeSize(dtype); + + if (inner_dim_size % B128_elements == 0) { + return MmaInputSmemSwizzle::B128; + } else if (inner_dim_size % B64_elements == 0) { + return MmaInputSmemSwizzle::B64; + } else if (inner_dim_size % B32_elements == 0) { + return MmaInputSmemSwizzle::B32; + } else { + NVF_THROW("Unsupported swizzle size for TMA shared memory mma inputs"); + } +} + } // namespace mma_utils std::string toString(const mma_utils::AbstractMatmulTensor& abten) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 13d5524101a..c88fe4926e3 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -245,14 +245,14 @@ class MmaSwizzler { }; //! Schedules the copy operation of output of a Mma op which resided in the -//! shared memory to global memory. This assumes the outout of Mma in the -//! shared memory is of the form [M, N]. -//! This is tiled to [MO(1), NO(1), MI(m), NI(n)]. The inner two dims are -//! marked parallel type bulk. -void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n); +//! shared memory to global memory. +void scheduleTMAStoreForMmaOutput(TensorView* tv, MmaInputSmemSwizzle swizzle); +//! Schedules the copy operation of output of a Mma op which resided in the +//! registers to shared memory. void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n); @@ -482,6 +482,13 @@ inline void checkConcreteStaticDim(const AbstractId& abs_id) { id->toString()); } +//! Automatically generates the shared memory swizzled data layout for tma loads +//! in matmul mainloop. The shared memory data layout is always 2D currently. +//! This utility function assumes that the shared_mem_tv has the following +//! structure: [tile_row, tile_col] +//! Returns which swizzle format to use for mma inputs with tma loads. 
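//!
//! A typical call sequence, as a sketch mirroring the Hopper matmul epilogue
//! scheduling in this change (here `d_smem` is the shared memory epilogue
//! cache written by StMatrix and `d` is the global memory output written by
//! the TMA store):
//!
//!   MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
//!   mma_utils::scheduleStMatrixForMmaOutput(
//!       d_smem, swizzle, /*tile_m=*/16, /*tile_n=*/16);
//!   mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);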
+MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv); + } // namespace mma_utils std::string toString(const mma_utils::AbstractMatmulTensor& abten); diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index e451f8073d7..f04a2f2271e 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -100,8 +100,11 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { class LoopDomainScheduler { public: - LoopDomainScheduler(std::vector ref_loop_dom) - : ref_loop_dom_(std::move(ref_loop_dom)) { + LoopDomainScheduler( + std::vector ref_loop_dom, + bool update_loop_domain_only = false) + : ref_loop_dom_(std::move(ref_loop_dom)), + update_loop_domain_only_(update_loop_domain_only) { NVF_ERROR(!ref_loop_dom_.empty()); // For now, ref must not be a broadcast domain @@ -174,6 +177,9 @@ class LoopDomainScheduler { private: std::vector ref_loop_dom_; + // If true, uses the current loop domain as the starting domain and + // updates it to make it look like the given reference loop domain + bool update_loop_domain_only_ = false; std::unique_ptr id_model_; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; @@ -188,9 +194,14 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { // All of the existing IDs are reused as much as possible to // minimize creating new IDs. - auto all_ids = tv->domain()->allIDs(); + std::unordered_map group_to_id; ValGroups all_id_groups; + // When update_mode_ is true, only the loop domain IDs are reused as + // we attempt to transform the current loop domain to look like the + // reference loop domain. + auto all_ids = + update_loop_domain_only_ ? tv->getLoopDomain() : tv->domain()->allIDs(); for (auto id : all_ids) { const auto& group = graph().toGroup(id); group_to_id.emplace(group, id); @@ -297,9 +308,10 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { // See LoopDomainSchedulingTest.ReshapeTraversalDirection for a // concrete example. ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { - // Find the path to the root domain of the tensor. It is important - // to use the root domain if available since there can be multiple - // forward paths to the logical domain in the ValGraph. For example, + // If not with the update mode, find the path to the root domain of + // the tensor. It is important to use the root domain if available since there + // can be multiple forward paths to the logical domain in the ValGraph. For + // example, // // t0 = [i0] // t1 = reshape(t0, {i0}, {i0/4, 4}) @@ -316,8 +328,12 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { // mean the t2 logical domain would have another definition (exactly mapped // with the t4 merge reshape). This issue can be avoided by using the root // domain of tv2 as the target of path finding. - ValGroups tv_target_domains = - graph().toGroups(TensorDomain::noBroadcasts(tv->getMaybeRootDomain())); + // + // In the case of the update mode, the target should be just the + // current loop domain of the tensor. + ValGroups tv_target_domains = graph().toGroups(TensorDomain::noBroadcasts( + update_loop_domain_only_ ? 
tv->getLoopDomain() + : tv->getMaybeRootDomain())); // If all the target domains are an ancestor of the reference // domains, just a single backward BFS should be enough to find a @@ -337,6 +353,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { .first; } + // In the case of the update mode, the path from the reference is + // assumed to just a backward traversal path. + NVF_ERROR( + !update_loop_domain_only_, + "Trying to update the current loop domain but could not find a valid path from the reference: ", + tv->toString()); + // Find the forward path from the ancestors to the target tensor auto forward_path = ValGraphBFS::getExprGroupsBetween( graph(), @@ -373,12 +396,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, - const std::vector& ref_loop_dom) { + const std::vector& ref_loop_dom, + bool update_loop_domain_only) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 6dd79240ed4..5939c9d31e2 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -20,9 +20,14 @@ namespace scheduler_tools { // Create the loop domain of given tensors as specified by the // reference. The new loop domain is connected to the existing IDs of // each tensor by replaying exprs found in the Exact ValGraph. +// +// If update_loop_domain_only is true, uses the current loop domain as +// the starting domain and updates it to make it look like the given +// reference loop domain. void scheduleLoopDomainsLike( const std::vector& tvs, - const std::vector& ref_loop_dom); + const std::vector& ref_loop_dom, + bool update_loop_domain_only = false); // Replay a transform expr on the loop domain of each of the given // tensors. 
If the input of the transform is exact mapped with the loop diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 87657e50997..3e9f7f553c8 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3664,6 +3664,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { const int64_t cta_m = 2 * getM(macro); const int64_t cta_n = 1 * getN(macro); + constexpr std::tuple cluster_dims{2, 1, 1}; + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); fusion.addInput(tv0); @@ -3679,6 +3681,11 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { auto tv3 = castOp(DataType::Half, tv2); fusion.addOutput(tv3); + if constexpr ( + cluster_dims != std::tuple{1, 1, 1}) { + fusion.manage("cluster_dims", cluster_dims); + } + auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 1 == mma_ops.size(), @@ -3773,17 +3780,23 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { tv3c->setLoopDomain(s.as()); tv3c->setAllocationDomain(s.as(), true); - // We'll use stmatrix.x4 to store from reg to shared memory - fusion.manage("st_matrix_m_tile", (int64_t)16); - fusion.manage("st_matrix_n_tile", (int64_t)16); + constexpr int64_t stmatrix_tile_m = 16; + constexpr int64_t stmatrix_tile_n = 16; + fusion.manage("st_matrix_m_tile", stmatrix_tile_m); + fusion.manage("st_matrix_n_tile", stmatrix_tile_n); fusion.manage("st_matrix_m", getM(macro)); fusion.manage("st_matrix_n", getN(macro)); + MmaInputSmemSwizzle store_swizzle = + mma_utils::tmaSwizzleSharedMemory(tv3_shmem); + // This internally calls - // mma_utils::MmaSwizzler::scheduleMmaOutputAllocation - mma_utils::scheduleStMatrixForMmaOutput(tv3_shmem, 16, 16); + // Schedule shared memory cache; Output from StMatrix + mma_utils::scheduleStMatrixForMmaOutput( + tv3_shmem, store_swizzle, stmatrix_tile_m, stmatrix_tile_n); - mma_utils::scheduleTMAStoreForMmaOutput(tv3, M, N); + // Schedule global memory output; Output from TMA Store + mma_utils::scheduleTMAStoreForMmaOutput(tv3, store_swizzle); } inlineMost(); diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 7679d52590a..ce359b41d32 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -2864,7 +2864,8 @@ TEST_P(StMatrixTest, Regular) { tv1->setLoopDomain(s.as()); tv1->setAllocationDomain(s.as(), true); - mma_utils::scheduleStMatrixForMmaOutput(tv2, tile_m, tile_n); + mma_utils::scheduleStMatrixForMmaOutput( + tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n); tv3->merge(0); tv3->split(0, 32); diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index 83cda79a43e..7e5ed33a8a6 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -531,9 +531,10 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { tv2->axis(-3)->parallelize(ParallelType::Mma); } - mma_utils::scheduleStMatrixForMmaOutput(tv3, tile_m, tile_n); + MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3); + mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n); - mma_utils::scheduleTMAStoreForMmaOutput(tv4, getM(macro), getN(macro)); + mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle); auto inputs = matmulAtInput3DHopperRS( getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a99e1bed4da..7d9e357d18b 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4373,6 +4373,199 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { 
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } +// RoPE-like rotation patten +TEST_F(ResizeTest, SliceRotateCat) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = set(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv5 = cat({tv4, tv2}, 1); + + fusion.addOutput(tv5); + + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv5; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } + + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// RoPE-like rotation and residual patten +TEST_F(ResizeTest, SliceRotateCatResidual) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = set(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv5 = cat({tv4, tv2}, 1); + + auto tv6 = add(tv0, tv5); + + fusion.addOutput(tv6); + + // Propagate the left half of slice and pad + 
scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv6; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } + + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()) + << "Invalid computeAt position of " << tv->toString(); + } + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + // Consumer-based scheduling of pad TEST_F(ResizeTest, PropagatePadToInputs) { Fusion fusion; diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py index de4734e6c90..00eb4b9eeb7 100644 --- a/tests/python/test_transformer_engine.py +++ b/tests/python/test_transformer_engine.py @@ -22,6 +22,13 @@ class ComputeType(Enum): BACKWARD = auto() +class Parallelism(Enum): + # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism + TENSOR_PARALLEL = auto() + # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism + SEQUENCE_PARALLEL = auto() + + @pytest.fixture(scope="module") def setup_process_group(mpi_test) -> None: # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. @@ -47,7 +54,12 @@ def setup_process_group(mpi_test) -> None: [ComputeType.FORWARD, ComputeType.BACKWARD], ids=["forward", "backward"], ) -def test_transformer_layer(setup_process_group, benchmark, compute_type): +@pytest.mark.parametrize( + "parallelism", + [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL], + ids=["tp", "sp"], +) +def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism): # Hyperparameters for GPT-3 hidden_size = 12288 num_heads = 96 @@ -69,12 +81,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type): # benchmark fails to execute on H100 with the default format (SBHD). 
attn_input_format="bshd", set_parallel_mode=True, + sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL), tp_group=dist.group.WORLD, ) transformer_layer.to(dtype).to("cuda") + match parallelism: + case Parallelism.TENSOR_PARALLEL: + local_sequence_length = sequence_length + case Parallelism.SEQUENCE_PARALLEL: + assert sequence_length % size == 0 + local_sequence_length = sequence_length // size + x = torch.randn( - batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda" + batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda" ) match compute_type: @@ -93,7 +113,9 @@ def benchmark_fn(profile): # Warmup. y = benchmark_fn(False) - assert y.size() == torch.Size([batch_size, sequence_length, hidden_size]) + assert y.size() == torch.Size( + [batch_size, local_sequence_length, hidden_size] + ) benchmark.pedantic(benchmark_fn, args=(True,), rounds=5) case ComputeType.BACKWARD: