From 456b3192d01b7120d78cc1e257961e4ceabafbe6 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 10 Dec 2024 13:07:01 -0800 Subject: [PATCH 1/4] Move StMatrix and TMA Store swizzle schedule functions to mma_utils (#3552) This PR is stack on https://github.com/NVIDIA/Fuser/pull/3553. ### Changes 1. Moved `analyzeSwizzleSharedMemory`, `tmaSwizzleSharedMemory`, `scheduleStMatrixForMmaOutput`, and `scheduleTMAStoreForMmaOutput` to `mma_utils`. 2. Deleted `swizzleSharedMemory` from hopper matmul scheduler. 3. Updated tests to use new `mma_utils`. --- csrc/scheduler/hopper_multi_matmul.cpp | 592 +------------------------ csrc/scheduler/hopper_multi_matmul.h | 14 - csrc/scheduler/mma_utils.cpp | 353 +++++++++++++-- csrc/scheduler/mma_utils.h | 17 +- tests/cpp/test_matmul.cpp | 18 +- tests/cpp/test_memory.cpp | 3 +- tests/cpp/test_mma.cpp | 5 +- 7 files changed, 361 insertions(+), 641 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index d236cf2c072..fbb95d46df2 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -29,500 +29,6 @@ namespace nvfuser { -namespace { - -// This function returns a pair of integers. The first integer is the gcd -// between megabanks and row stride. The second integer is the repeat pattern -// size. If the gcd is 1, then no swizzle is necessary to resolve bank -// conflicts. In that case, the second integer is irrelevant and -1 is returned. -std::pair analyzeSwizzleSharedMemory( - TensorView* shared_mem_tv) { - NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - - // Check that the innermost 2 dimensions are concrete and static - // sized so that the swizzle function can be defined. - NVF_ERROR( - (int64_t)swizzle_domain.size() >= 2, - "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", - shared_mem_tv->toString()); - mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); - mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); - - // Extract the constant sizes of the swizzled tile - const int64_t tile_size_x = - swizzle_domain[-2]->extent()->evaluate().as(); - const int64_t tile_size_y = - swizzle_domain[-1]->extent()->evaluate().as(); - - // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. - // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit - // (i.e. float) - const int64_t data_type_size = dataTypeSize(*shared_mem_tv->getDataType()); - NVF_ERROR(data_type_size == 2 || data_type_size == 4); - - // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. - // For epilogue, threads in a warp is organized as 8 rows x 4 columns. - // Each thread vectorized write 2 items, so 8 items per row. - //--0--1--2--3 - //--4--5--6--7 - //--8--9--10-11 - //--12-13-14-15 - //--16-17-18-19 - //--20-21-22-23 - //--24-25-26-27 - //--28-29-30-31 - constexpr int64_t n_rows = 8; - constexpr int64_t n_cols = 8; - - // Column size of the tile needs to be multiples of 8 for ldmatrix to work. - NVF_ERROR( - tile_size_x >= n_rows && tile_size_x % n_rows == 0 && - tile_size_y >= n_cols && tile_size_y % n_cols == 0, - "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", - tile_size_x, - "x", - tile_size_y); - - /* Note [How to remove bank conflict for ldmatrix?] 
- * - * **This note is interleaved with code, I suggest reading this note like - * reading a jupyter notebook** - * - * Our task is to make sure different rows does not fall into the same - * bank of shared memory. - * - * Introduction to bank conflict can be found at page 54-72 of: - * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf - * - * When we talk about bank conflict removal, we are talking about the - * following task: - * "there are 32 banks, and each bank contains one 4-byte word, we want to - * make sure different lanes in a warp does not access different word - * addresses in the same bank" - * For example, if thread 0 is accessing word address 1, and thread 1 is - * accessing word address 33, then these two threads will have a bank - * conflict because they are accessing different word addresses in the same - * bank. However, if thread 0 is accessing byte address 4 and thread 1 is - * accessing byte address 6 then there will be no bank conflict because 4 - * and 6 both belong to word 1. - */ - - constexpr int64_t smem_bytes_per_word = 4; - constexpr int64_t smem_banks = 32; - - /* but here, for our convenience, because ldmatrix always use vectorized - * access of 8 items = 16 bytes = 4 words, we further group words into - * units: we consider each 4 words as a "unit", and each 4 banks as a - * "megabank". So we can rephrase our task as: - * "there are 8 megabanks, and each megabanks contains one 4-word unit, we - * want to make sure different lanes in a warp does not access different - * unit addresses in the same megabank" - * In this terminology, matrices are in the row major format, each matrix - * has 8 rows, and each row has exactly one unit. - */ - - constexpr int64_t items_per_unit = n_cols; - const int64_t bytes_per_unit = items_per_unit * data_type_size; - const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; - const int64_t num_megabanks = smem_banks / words_per_unit; - - /* In the following example, each CTA tile contains 2 rows and 3 colums of - * matrices, each 8x8 size: - * +----------+----------+----------+ - * | matrix 0 | matrix 1 | matrix 2 | - * +----------+----------+----------+ - * | matrix 3 | matrix 4 | matrix 5 | - * +----------+----------+----------+ - * The addresses of different rows in the same matrix are offset by 3 units. - * In this perspective, loading a matrix is a strided memory access with the - * following stride (in units): - */ - - // number of units per row - int64_t row_stride = tile_size_y / items_per_unit; - - /* So the bank conflicting problem is now converted to the following game: - * I have a clock that has one pointer and `num_megabanks` ticks. I start - * my game by making my pointer pointing to somewhere, and turn forward - * the pointer `n_rows` times, each time by `row_stride` ticks. - * This problem can be well modeled by modular arithmetic in number theory - * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. - * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. - * Additions and multiplications are defined in a cyclic manner: - * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... - * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... - * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] - * - * It worth mention that Z/nZ is a "commutative ring", that is, we can use - * addition and multiplication rules just like using normal integers: - * a + b = b + a, a * (b + c) = a * b + a * c, ... 
- * In short, we can reason about Z/nZ just like we are reasoning about - * integers, except that every number is automatically "% n". - * - * Reference: - * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n - * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of - * -1, see [The Mathematics of Integer Arithmetic] for more detail. But - * we are only interested in non-negative numbers here, so there is no - * need to worry about this problem - */ - - // row_stride in Z/nZ, where n is num_megabanks: - // assert(row_stride >= 0); - // assert(num_megabanks >= 0); - int64_t row_stride_znz = row_stride % num_megabanks; - /* Consider the following function in Z/nZ: - * f(i; init) = init + i * stride - * where init is the initial position of the pointer in the clock when we - * start the game, and stride is the number of ticks we move forward each - * time, and i is the number of times we move forward. For a fixed init, we - * abbrivate f(i; init) as f(i). - * - * In our problem, f(i) is the megabank of the `i`th row of the matrix, and - * `init` is the megabank of the 0th row of the matrix. - * - * One very important property of f(i) is: - * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) - * This property is true because: - * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) - * - * The above property tells us, as we turn the clock forward: - * - initially, we will go to a never-visited tick in each turn, but, - * - at some point, we will return back to our original position, and, - * - after we return, we start repeat the pervious pattern again and again. - * - * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: - * i 0 1 2 3 4 5 6 7 - * f(i) 0 6 4 2 0 6 4 2 - * We can see that f(i) is repeating a pattern of four unique numbers - * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 - * different megabanks, and we have a 2-way conflict. - * - * The question of interest is, does the above observation generalize? That - * is, does f(i) always repeat a pattern of p unique numbers q times? Note - * that p and q must satisfy p * q = n. - * - * The answer to the above question is: yes! Consider the following - * equation: - * f(i1 + j) == f(i1) - * We want to know what is the smallest positive number j that makes the - * above equation true. Because this tells us in how many steps we will see - * repeat. This equation can be simplified as: - * f(i1 + j) == f(i1) + j * stride == f(i1) - * ==> j * stride == 0 - * - * An important tool to study this equation is multiplicative inverse: - * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse - * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it - * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in - * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has - * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because - * (2 * 8) % 15 = 1 - * - * stride has an multiplicative inverse if and only if stride coprime with - * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our - * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does - * not repeat, that is: there is no bank conflict. - */ - - int64_t g = std::gcd(num_megabanks, row_stride_znz); - if (g == 1) { - return {g, -1}; // No need to swizzle in this case. 
- } - - /* For the case where stride does not coprime with n, we note that - * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We - * can write stride and n as: - * stride = s * g, n = m * g - * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we - * have: - * (j * stride) % n = 0 - * ==> (j * s) % m * g = 0 - * ==> (j * s) % m = 0 - * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we - * further get: - * j == 0 (in Z/mZ) - * That is, j is a multiple of m in Z. So the smallest positive j that make - * the equation hold is n / g. - * - * That is: f(i) always repeat a pattern of n/g unique numbers g times. - * In other word: we are using n/g megabanks, and we have a g-way bank - * conflict. - * - * Let's use the word "pattern" to refer to the set of values of `f` at - * different `i`, that is: - * pattern k = { f(i; init=k) | i in Z/nZ } - * For the example of stride = 6 under Z/8Z, we have the following patterns - * f(i): 01234567 - * pattern 0: x_x_x_x_ - * pattern 1: _x_x_x_x - * (x => occupied, _ => unoccupied) - */ - - int64_t repeated_pattern_size = num_megabanks / g; - - if (repeated_pattern_size >= n_rows) { - return {g, -1}; // No need to swizzle in this case. - } - - return {g, repeated_pattern_size}; -} - -//! Automatically generates the shared memory swizzled data layout for tma loads -//! in matmul mainloop. The shared memory data layout is always 2D currently. -//! This utility function assumes that the shared_mem_tv has the following -//! structure: [tile_row, tile_col] -//! Returns which swizzle format to use for mma inputs with tma loads. -MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) { - auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); - - if (g == 1) { - return MmaInputSmemSwizzle::None; // No need to swizzle in this case. - } - - // 128B swizzle results in 8 x 8 matrix given half precision inputs. - constexpr int64_t n_rows = 8; - - NVF_ERROR( - n_rows % repeated_pattern_size == 0, - "Can not partition matrix into megarows"); - int64_t num_gigarows = n_rows / repeated_pattern_size; - int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size - - /* To further simplify the problem, if we assume: */ - NVF_ERROR( - num_gigarows % num_gigabanks == 0, - "Requires non-square swizzle, which is not supported yet"); - - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - // Extract the constant sizes of the swizzled tile - const int64_t inner_dim_size = - swizzle_domain[-1]->extent()->evaluate().as(); - - auto dtype = shared_mem_tv->getDataType().value(); - const int64_t B128_elements = 128 / dataTypeSize(dtype); - const int64_t B64_elements = 64 / dataTypeSize(dtype); - const int64_t B32_elements = 32 / dataTypeSize(dtype); - - if (inner_dim_size % B128_elements == 0) { - return MmaInputSmemSwizzle::B128; - } else if (inner_dim_size % B64_elements == 0) { - return MmaInputSmemSwizzle::B64; - } else if (inner_dim_size % B32_elements == 0) { - return MmaInputSmemSwizzle::B32; - } else { - NVF_THROW("Unsupported swizzle size for TMA shared memory mma inputs"); - } -} - -//! Automatically generates the shared memory swizzled data layout for matmul -//! epilogue. -//! The shared mem data layout is always 2D currently, and this utility -//! function assumes that the shared_mem_tv has the following structure: -//! [tile_row, tile_col] -//! Returns the domain with swizzle. For the case of legacy swizzle, this -//! 
domain must be set as loop domain. For the case of new swizzle, this domain -//! must be set as allocation domain. -template -AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { - auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); - - // Create Abstract Tensor from shared memory tensor loop domain. - AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); - - if (g == 1) { - return swizzle_domain; // No need to swizzle in this case. - } - - /* Now we know that we have a g-way bank conflict. How do we remove this - * bank conflict? The answer is to mix the storage of different matrices. - * We first split the matrices along the row axis into g pieces, each piece - * has n/g rows. With this split, each piece occupies exactly one pattern. - * We want to use some non-traditional storage to let different pieces of - * the same matrix to occupy different patterns. - * - * Because Z/nZ has n items, each pattern has n/g different items, so we - * have in total g different patterns. We want to find the corresponding - * `init` values of these g different patterns. - * - * Consider two different init values `init1` and `init2`. When do they - * represent the same pattern? They represent the same pattern if and only - * if `f(0; init2)` falls on the pattern of `init1`, that is, there exist an - * i such that - * f(i; init1) == f(0; init2) - * which simplifies to - * init1 + i * stride == init2 - * ==> init2 - init1 == i * stride - * What values can `i * stride` be? It can be an arbitrary multiple of g: - * i * stride in Z/nZ is (i * stride) % n in Z. Let m = n/g, according to - * Theorem 4.13 in [The Mathematics of Integer Arithmetic] - * (i * stride) % n = (i * s) % m * g - * Because s coprime with m, we know that for an arbitrary value `j` in - * Z/mZ, we can take `i = minv(s) * j` to make `i * s == j`. - * - * That said, for init values that are off by a multiple of g they - * correspond to the same pattern, otherwise they belongs to different - * patterns. So, we can use - * init = 0, 1, ..., g - 1 - * to canonically represent g patterns. Let's call the above - * `init` values "pattern id". - * - * Now we have the idea about how to remove bank conflict: We can do an - * inner split of our row dimension by `repeated_pattern_size` to get - * (repeat, pattern), then different indices of the "repeat" dimension will - * be using the same megabank, and different indices of the "pattern" - * dimension will be using different megabank. We don't need to touch the - * "pattern" dimension, but we need to play with the "repeat" dimension to - * interleave it with matrice ids so that each matrix is distributed across - * different banks. - * - * For example, if we have repeated_pattern_size = 4, we would want to do - * something like below: - * +----------+----------+ - * 0| | | - * 1| matrix 0 | matrix 1 | - * 2| | | - * 3| | | - * +----------+----------+ - * 4| | | - * 5| matrix 1 | matrix 0 | - * 6| | | - * 7| | | - * +----------+----------+ - * - * We can consider each repeated_pattern_size rows as a gigarow, and each - * repeated_pattern_size megabanks as a gigabank. Note that megabank is a - * contiguous chunk of banks, but gigabank is not contiguous. Indeed, - * nearby megabanks in a gigabank has a distance of `g` megabanks - */ - - // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. - // For epilogue, threads in a warp is organized as 8 rows x 4 columns. - // Each thread vectorized write 2 items, so 8 items per row. 
- //--0--1--2--3 - //--4--5--6--7 - //--8--9--10-11 - //--12-13-14-15 - //--16-17-18-19 - //--20-21-22-23 - //--24-25-26-27 - //--28-29-30-31 - constexpr int64_t n_rows = 8; - constexpr int64_t n_cols = 8; - - NVF_ERROR( - n_rows % repeated_pattern_size == 0, - "Can not partition matrix into megarows"); - int64_t num_gigarows = n_rows / repeated_pattern_size; - int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size - - // -2 -1 - // [row, col] - if (repeated_pattern_size > 1) { - swizzle_domain.split(-2, repeated_pattern_size); - } - swizzle_domain.split(-1, n_cols); - // -4 -3 -2 -1 - // [gigarow id, gigarow, matrix id, matrix] - swizzle_domain.split(-2, num_gigabanks); - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] - // Note that megabanks inside a gigabank are not contiguous, so the gigabank - // id is -2 instead of -3 - - /* We want to evenly distribute gigarows across gigabanks, for example, if - * we have 7 gigarows and 3 gigabanks, then we might distribute them as: - * +---+ - * |x | - * | x | - * | x| - * |x | - * | x | - * | x| - * |x | - * +---+ - * considering all matrices, this is a swizzle function like: - * +---+ - * |012| - * |201| - * |120| - * |012| - * |201| - * |120| - * |012| - * +---+ - * which is a cyclic shift. - * - * Note that because num_gigabanks (a.k.a. g) divide num_megabanks and - * row_stride_znz (which is row_stride % num_megabanks), g should also - * divide row_stride, because according to the fundamental - * division-with-remainder property (see doc/math/integer-division.md): - * row_stride = q * num_megabanks + row_stride_znz - * which means, we can just consider each num_gigabanks matrices as a group, - * and we always have complete groups (i.e. no group has less than - * num_gigabanks matrices). Interleaving the memory of matrices within each - * group should be enough to fully remove bank conflict. - */ - - /* To further simplify the problem, if we assume: */ - NVF_ERROR( - num_gigarows % num_gigabanks == 0, - "Requires non-square swizzle, which is not supported yet"); - /* Then we can partition gigarows into full waves, each wave has - * num_gigabanks gigarows. This partition creates square dimensions, making - * the swizzle implementation easier */ - - // -5 -4 -3 -2 -1 - // [gigarow id, gigarow, y outer, gigabank id, matrix] - int axis_of_gigarow_id = repeated_pattern_size > 1 ? -5 : -4; - swizzle_domain.split(axis_of_gigarow_id, num_gigabanks); - // -6 -5 -4 -3 -2 -1 - // [wave id, wave, gigarow, y outer, gigabank id, matrix] - - // swizzle wave with gigabank id to make threads in a wave access different - // gigabank. Apply swizzle only when shared_mem_tv is stored in shared - // memory. - // TODO: This is a temporary workaround for the following issue: - // For the mma output, we have the following schedule: - // rFactor: [...., X, Y] -> mma-swizzle transformations -> loop - // For epilogue smem tensor, the schedule is - // rFactor: [...., X, Y] -> split -> [...., X1, X2, X3, Y1, Y2, Y3] - // -> swizzle X2, Y2 -> [...., X1, X2', X3, Y1, Y2', Y3] - // -> merge back -> [...., X', Y'] - // -> mma-swizzle transformations -> loop - // The mma-swizzle transformations for the mma output and epilogue smem - // tensor are the same. In indexing, we do require {X, X'} and {Y, Y'} to be - // mapped in CA map, however, we currently can not handle that. 
So we have - // to do the same split and merge to the mma output without actually - // applying the swizzle, and this check is to detect and handle this - // specific case. We should remove this special handling when we fix our CA - // mapping. - using SwizzleTypeMaybeLegacy = - std::conditional_t; - if (isPowOf2(num_gigabanks)) { - swizzle_domain.swizzle(SwizzleTypeMaybeLegacy::XOR, axis_of_gigarow_id, -2); - } else { - swizzle_domain.swizzle( - SwizzleTypeMaybeLegacy::CyclicShift, axis_of_gigarow_id, -2); - } - - if (legacy) { - if (repeated_pattern_size > 1) { - swizzle_domain.merge(-6); - } - swizzle_domain.merge(-5); - - // merge back tile_size_y - swizzle_domain.merge(-3); - swizzle_domain.merge(-2); - } - - return swizzle_domain; -} - -} // namespace - MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { ValGroup vg = graph_->toGroup(id); auto it = id_roles_.find(vg); @@ -813,7 +319,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() { tv->promoteReuse(); } mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv); - MmaInputSmemSwizzle swizzle_type = tmaSwizzleSharedMemory(tv); + MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(tv); tv->applyMmaSwizzleForTMALoad(swizzle_type); } }; @@ -1077,14 +583,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setAllocationDomain(s.as(), true); } - MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem); + MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem); // Schedule shared memory cache; Output from StMatrix - scheduleStMatrixForMmaOutput( + mma_utils::scheduleStMatrixForMmaOutput( d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n); // Schedule global memory output; Output from TMA Store - scheduleTMAStoreForMmaOutput(d, swizzle); + mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle); } } } @@ -1240,94 +746,4 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { */ } -void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle, - int64_t tile_m, - int64_t tile_n) { - NVF_ERROR( - ((tile_m == 16 && tile_n == 16) || (tile_m == 16 && tile_n == 8)), - "We only support 16x16 and 16x16 stmatrix now"); - - NVF_ERROR( - tv->dtype() == DataType::Half, "we only support half type in stmatrix"); - - // [M, N] -> [128(TIDx), N/8 , 2 , 2] - auto s = - mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); - - if (swizzle != MmaInputSmemSwizzle::None) { - // Create tma store allocation domain with swizzle - scheduleTMAStoreForMmaOutput(tv, swizzle); - } - - tv->setLoopDomain(s.as()); - - if (tile_m == 16 && tile_n == 16) { - // Let [M, N] be [64, 32] - // After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2] - // [128(TIDx), 4(n), 2, 2] -> [128(TIDx), 2(no), 2(ni), 2, 2] - tv->split(-3, 2); - // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 2(ni), 2, 2] - tv->reorder({{-4, 0}}); - // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 8 (vectorize)] - tv->merge(-3); - tv->merge(-2); - } else if (tile_m == 16 && tile_n == 8) { - // Let [M, N] be [64, 16] - // After scheduleMmaOutputAllocation: [128(TIDx), 2, 2, 2] - // [128(TIDx), 2, 2, 2] -> [2, 128(TIDx), 2, 2] - tv->reorder({{-3, 0}}); - // [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)] - tv->merge(-2); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); -} - -void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle) { - // [BDX, BDY, TDY, MI, 
NI] - // skip all but last 2 iterDomains - int64_t num_ids_to_skip = - static_cast(tv->getLoopDomain().size() - 2); - - NVF_ERROR(num_ids_to_skip >= 0); - if (swizzle == MmaInputSmemSwizzle::None) { - // For no-swizzle case, the entire tile are divided into 8x8 core matrices, - // and each core matrix resides in a contiguous 8*8*2 bytes region in shared - // memory. [K, N] - tv->split(-2, 8); - tv->split(-1, 8); - // [Ko, K8, No, N8] - tv->reorder({{-2, -3}}); - } else { - auto dtype = tv->getDataType().value(); - - // In the comments below I assume K=16, N=32, swizzle=32, dtype = half. - - // split the inner-dim - // [K(16), N(32)] -> [K(16), NO(2), NI(16)] - tv->split(-1, getBytesFromSwizzle(swizzle) / dataTypeSize(dtype)); - - // [NO, K, NI] - the TMA Box is [K, NI] - tv->reorder({{-2, -3}}); - - // [NO, K, NI] -> - // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] - tv->swizzleTMABox(swizzle); - num_ids_to_skip += 1; - } - - // The shared memory producer must have the swizzled allocation domain. - // The global memory consumer must have the ParallelType::Bulk iterDomains. - if (tv->getMemoryType() == MemoryType::Shared) { - // Set the allocation to the loop domain. - tv->setAllocationDomain(tv->getLoopDomain(), true); - } else { - mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs( - tv, num_ids_to_skip); - } -} - } // namespace nvfuser diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index a736823f75c..bf7bc1df0f5 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -178,20 +178,6 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { void setUpCircularBuffering(); - //! Schedules the copy operation of output of a Mma op which resided in the - //! registers to shared memory. - void scheduleStMatrixForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle, - int64_t tile_m, - int64_t tile_n); - - //! Schedules the copy operation of output of a Mma op which resided in the - //! shared memory to global memory. - void scheduleTMAStoreForMmaOutput( - TensorView* tv, - MmaInputSmemSwizzle swizzle); - // Map TensorView's iterDomain to its ValGroup. // Then, find the MatmulDimRole for the ValGroup. // Return MatmulDimRole for IterDomain diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 13fd337e113..e9a24851a5a 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1256,38 +1256,53 @@ inline void resolveTvToMatmulDimRolesMapping( } // anonymous namespace -void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n) { - NVF_ERROR( - tv->getMemoryType() == MemoryType::Global, - "TMA Store should write to global memory"); +void scheduleTMAStoreForMmaOutput(TensorView* tv, MmaInputSmemSwizzle swizzle) { + // [BDX, BDY, TDY, MI, NI] + // skip all but last 2 iterDomains + int64_t num_ids_to_skip = + static_cast(tv->getLoopDomain().size() - 2); - NVF_ERROR( - tv->definition()->isA(), - "This tensor should be the result of a LoadStoreOp"); + NVF_ERROR(num_ids_to_skip >= 0); + if (swizzle == MmaInputSmemSwizzle::None) { + // For no-swizzle case, the entire tile are divided into 8x8 core matrices, + // and each core matrix resides in a contiguous 8*8*2 bytes region in shared + // memory. 
[K, N] + tv->split(-2, 8); + tv->split(-1, 8); + // [Ko, K8, No, N8] + tv->reorder({{-2, -3}}); + } else { + auto dtype = tv->getDataType().value(); - NVF_ERROR( - tv->definition()->as()->opType() == - LoadStoreOpType::CpAsyncBulkTensorTile, - "This is not a TMA operation"); + // In the comments below I assume K=16, N=32, swizzle=32, dtype = half. - NVF_ERROR( - tv->definition() - ->as() - ->in() - ->as() - ->getMemoryType() == MemoryType::Shared, - "Producer should be in shared memory"); - - // [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)] - tv->split(-2, m); - tv->split(-1, n); - // [MO(1), MI(m), NO(1), NI(n)] -> [MO(1), NO(1), MI(m), NI(n)] - tv->reorder({{-2, -3}}); - mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs(tv, 2); + // split the inner-dim + // [K(16), N(32)] -> [K(16), NO(2), NI(16)] + tv->split(-1, getBytesFromSwizzle(swizzle) / dataTypeSize(dtype)); + + // [NO, K, NI] - the TMA Box is [K, NI] + tv->reorder({{-2, -3}}); + + // [NO, K, NI] -> + // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] + tv->swizzleTMABox(swizzle); + num_ids_to_skip += 1; + } + + // The shared memory producer must have the swizzled allocation domain. + // The global memory consumer must have the ParallelType::Bulk iterDomains. + if (tv->getMemoryType() == MemoryType::Shared) { + // Set the allocation to the loop domain. + tv->setAllocationDomain(tv->getLoopDomain(), true); + } else { + mma_utils::MmaSwizzler::parallelizeAsBulkSkippingFirstIDs( + tv, num_ids_to_skip); + } } void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n) { NVF_ERROR( @@ -1300,6 +1315,12 @@ void scheduleStMatrixForMmaOutput( // [M, N] -> [128(TIDx), N/8 , 2 , 2] auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); + + if (swizzle != MmaInputSmemSwizzle::None) { + // Create tma store allocation domain with swizzle + mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle); + } + tv->setLoopDomain(s.as()); if (tile_m == 16 && tile_n == 16) { @@ -2174,6 +2195,288 @@ std::optional> allPatternRoles( id_roles, tensor_roles_opt.getData()}; } +namespace { +// This function returns a pair of integers. The first integer is the gcd +// between megabanks and row stride. The second integer is the repeat pattern +// size. If the gcd is 1, then no swizzle is necessary to resolve bank +// conflicts. In that case, the second integer is irrelevant and -1 is returned. +std::pair analyzeSwizzleSharedMemory( + TensorView* shared_mem_tv) { + NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); + AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); + + // Check that the innermost 2 dimensions are concrete and static + // sized so that the swizzle function can be defined. + NVF_ERROR( + (int64_t)swizzle_domain.size() >= 2, + "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", + shared_mem_tv->toString()); + mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); + + // Extract the constant sizes of the swizzled tile + const int64_t tile_size_x = + swizzle_domain[-2]->extent()->evaluate().as(); + const int64_t tile_size_y = + swizzle_domain[-1]->extent()->evaluate().as(); + + // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e. + // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit + // (i.e. 
float) + const int64_t data_type_size = dataTypeSize(*shared_mem_tv->getDataType()); + NVF_ERROR(data_type_size == 2 || data_type_size == 4); + + // For main loop, ldmatrix loads a n_rows x n_cols = 8 x 8 matrix each time. + // For epilogue, threads in a warp is organized as 8 rows x 4 columns. + // Each thread vectorized write 2 items, so 8 items per row. + //--0--1--2--3 + //--4--5--6--7 + //--8--9--10-11 + //--12-13-14-15 + //--16-17-18-19 + //--20-21-22-23 + //--24-25-26-27 + //--28-29-30-31 + constexpr int64_t n_rows = 8; + constexpr int64_t n_cols = 8; + + // Column size of the tile needs to be multiples of 8 for ldmatrix to work. + NVF_ERROR( + tile_size_x >= n_rows && tile_size_x % n_rows == 0 && + tile_size_y >= n_cols && tile_size_y % n_cols == 0, + "Prolog swizzle for ldmatrix, illegal tile size for prolog swizzle", + tile_size_x, + "x", + tile_size_y); + + /* Note [How to remove bank conflict for ldmatrix?] + * + * **This note is interleaved with code, I suggest reading this note like + * reading a jupyter notebook** + * + * Our task is to make sure different rows does not fall into the same + * bank of shared memory. + * + * Introduction to bank conflict can be found at page 54-72 of: + * https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + * + * When we talk about bank conflict removal, we are talking about the + * following task: + * "there are 32 banks, and each bank contains one 4-byte word, we want to + * make sure different lanes in a warp does not access different word + * addresses in the same bank" + * For example, if thread 0 is accessing word address 1, and thread 1 is + * accessing word address 33, then these two threads will have a bank + * conflict because they are accessing different word addresses in the same + * bank. However, if thread 0 is accessing byte address 4 and thread 1 is + * accessing byte address 6 then there will be no bank conflict because 4 + * and 6 both belong to word 1. + */ + + constexpr int64_t smem_bytes_per_word = 4; + constexpr int64_t smem_banks = 32; + + /* but here, for our convenience, because ldmatrix always use vectorized + * access of 8 items = 16 bytes = 4 words, we further group words into + * units: we consider each 4 words as a "unit", and each 4 banks as a + * "megabank". So we can rephrase our task as: + * "there are 8 megabanks, and each megabanks contains one 4-word unit, we + * want to make sure different lanes in a warp does not access different + * unit addresses in the same megabank" + * In this terminology, matrices are in the row major format, each matrix + * has 8 rows, and each row has exactly one unit. + */ + + constexpr int64_t items_per_unit = n_cols; + const int64_t bytes_per_unit = items_per_unit * data_type_size; + const int64_t words_per_unit = bytes_per_unit / smem_bytes_per_word; + const int64_t num_megabanks = smem_banks / words_per_unit; + + /* In the following example, each CTA tile contains 2 rows and 3 colums of + * matrices, each 8x8 size: + * +----------+----------+----------+ + * | matrix 0 | matrix 1 | matrix 2 | + * +----------+----------+----------+ + * | matrix 3 | matrix 4 | matrix 5 | + * +----------+----------+----------+ + * The addresses of different rows in the same matrix are offset by 3 units. 
+ * In this perspective, loading a matrix is a strided memory access with the + * following stride (in units): + */ + + // number of units per row + int64_t row_stride = tile_size_y / items_per_unit; + + /* So the bank conflicting problem is now converted to the following game: + * I have a clock that has one pointer and `num_megabanks` ticks. I start + * my game by making my pointer pointing to somewhere, and turn forward + * the pointer `n_rows` times, each time by `row_stride` ticks. + * This problem can be well modeled by modular arithmetic in number theory + * using the concept "integers modulo n" a.k.a. "Z/nZ"[1]. + * Take n = 6 as an example, Z/6Z only has 6 elements: 0, 1, 2, 3, 4, 5. + * Additions and multiplications are defined in a cyclic manner: + * 5 + 1 = 0, 5 + 2 = 1, 5 + 3 = 2, 5 + 4 = 3, ... + * 2 * 1 = 2, 2 * 2 = 4, 2 * 3 = 0, 2 * 4 = 2, ... + * With this definition, Z is mapped to Z/nZ naturally by i -> i % n [2] + * + * It worth mention that Z/nZ is a "commutative ring", that is, we can use + * addition and multiplication rules just like using normal integers: + * a + b = b + a, a * (b + c) = a * b + a * c, ... + * In short, we can reason about Z/nZ just like we are reasoning about + * integers, except that every number is automatically "% n". + * + * Reference: + * [1] https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n + * [2] The % is under Euclidean definition, that is -1 % 6 is 5 instead of + * -1, see [The Mathematics of Integer Arithmetic] for more detail. But + * we are only interested in non-negative numbers here, so there is no + * need to worry about this problem + */ + + // row_stride in Z/nZ, where n is num_megabanks: + // assert(row_stride >= 0); + // assert(num_megabanks >= 0); + int64_t row_stride_znz = row_stride % num_megabanks; + /* Consider the following function in Z/nZ: + * f(i; init) = init + i * stride + * where init is the initial position of the pointer in the clock when we + * start the game, and stride is the number of ticks we move forward each + * time, and i is the number of times we move forward. For a fixed init, we + * abbrivate f(i; init) as f(i). + * + * In our problem, f(i) is the megabank of the `i`th row of the matrix, and + * `init` is the megabank of the 0th row of the matrix. + * + * One very important property of f(i) is: + * - if f(i1) == f(i2), then for every j, f(i1 + j) = f(i2 + j) + * This property is true because: + * f(i1 + j) = f(i1) + j * stride = f(i2) + j * stride = f(i2 + j) + * + * The above property tells us, as we turn the clock forward: + * - initially, we will go to a never-visited tick in each turn, but, + * - at some point, we will return back to our original position, and, + * - after we return, we start repeat the pervious pattern again and again. + * + * As an example, consider f(i) where init = 0, stride = 6, under Z/8Z: + * i 0 1 2 3 4 5 6 7 + * f(i) 0 6 4 2 0 6 4 2 + * We can see that f(i) is repeating a pattern of four unique numbers + * "0 6 4 2" twice. In our bank conflict problem, this means we are using 4 + * different megabanks, and we have a 2-way conflict. + * + * The question of interest is, does the above observation generalize? That + * is, does f(i) always repeat a pattern of p unique numbers q times? Note + * that p and q must satisfy p * q = n. + * + * The answer to the above question is: yes! Consider the following + * equation: + * f(i1 + j) == f(i1) + * We want to know what is the smallest positive number j that makes the + * above equation true. 
Because this tells us in how many steps we will see + * repeat. This equation can be simplified as: + * f(i1 + j) == f(i1) + j * stride == f(i1) + * ==> j * stride == 0 + * + * An important tool to study this equation is multiplicative inverse: + * https://en.wikipedia.org/wiki/Modular_multiplicative_inverse + * A number i has multiplicative inverse `minv(i)` in Z/nZ if and only if it + * coprime with n. `minv(i)` is the number that `i * minv(i) == 1`. So in + * Z/nZ, the equation `ax = b` has solution `x = minv(a)*b` if a has + * multiplicative inverse. For example, in Z/15Z, `minv(2) = 8` because + * (2 * 8) % 15 = 1 + * + * stride has an multiplicative inverse if and only if stride coprime with + * n, that is, g := gcd(stride, n) == 1. In such case, the solution to our + * equation j * stride == 0 is j = minv(stride) * 0 = 0, that is: f(i) does + * not repeat, that is: there is no bank conflict. + */ + + int64_t g = std::gcd(num_megabanks, row_stride_znz); + if (g == 1) { + return {g, -1}; // No need to swizzle in this case. + } + + /* For the case where stride does not coprime with n, we note that + * j * stride == 0 in Z/nZ is equivalent to (j * stride) % n = 0 in Z. We + * can write stride and n as: + * stride = s * g, n = m * g + * According to Theorem 4.13 in [The Mathematics of Integer Arithmetic], we + * have: + * (j * stride) % n = 0 + * ==> (j * s) % m * g = 0 + * ==> (j * s) % m = 0 + * which is equivalent to j * s == 0 in Z/mZ. Because s coprime with m, we + * further get: + * j == 0 (in Z/mZ) + * That is, j is a multiple of m in Z. So the smallest positive j that make + * the equation hold is n / g. + * + * That is: f(i) always repeat a pattern of n/g unique numbers g times. + * In other word: we are using n/g megabanks, and we have a g-way bank + * conflict. + * + * Let's use the word "pattern" to refer to the set of values of `f` at + * different `i`, that is: + * pattern k = { f(i; init=k) | i in Z/nZ } + * For the example of stride = 6 under Z/8Z, we have the following patterns + * f(i): 01234567 + * pattern 0: x_x_x_x_ + * pattern 1: _x_x_x_x + * (x => occupied, _ => unoccupied) + */ + + int64_t repeated_pattern_size = num_megabanks / g; + + if (repeated_pattern_size >= n_rows) { + return {g, -1}; // No need to swizzle in this case. + } + + return {g, repeated_pattern_size}; +} +} // namespace + +MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) { + auto&& [g, repeated_pattern_size] = analyzeSwizzleSharedMemory(shared_mem_tv); + + if (g == 1) { + return MmaInputSmemSwizzle::None; // No need to swizzle in this case. + } + + // 128B swizzle results in 8 x 8 matrix given half precision inputs. 
+ constexpr int64_t n_rows = 8; + + NVF_ERROR( + n_rows % repeated_pattern_size == 0, + "Can not partition matrix into megarows"); + int64_t num_gigarows = n_rows / repeated_pattern_size; + int64_t num_gigabanks = g; // also = num_megabanks / repeated_pattern_size + + /* To further simplify the problem, if we assume: */ + NVF_ERROR( + num_gigarows % num_gigabanks == 0, + "Requires non-square swizzle, which is not supported yet"); + + AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain()); + // Extract the constant sizes of the swizzled tile + const int64_t inner_dim_size = + swizzle_domain[-1]->extent()->evaluate().as(); + + auto dtype = shared_mem_tv->getDataType().value(); + const int64_t B128_elements = 128 / dataTypeSize(dtype); + const int64_t B64_elements = 64 / dataTypeSize(dtype); + const int64_t B32_elements = 32 / dataTypeSize(dtype); + + if (inner_dim_size % B128_elements == 0) { + return MmaInputSmemSwizzle::B128; + } else if (inner_dim_size % B64_elements == 0) { + return MmaInputSmemSwizzle::B64; + } else if (inner_dim_size % B32_elements == 0) { + return MmaInputSmemSwizzle::B32; + } else { + NVF_THROW("Unsupported swizzle size for TMA shared memory mma inputs"); + } +} + } // namespace mma_utils std::string toString(const mma_utils::AbstractMatmulTensor& abten) { diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 13d5524101a..c88fe4926e3 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -245,14 +245,14 @@ class MmaSwizzler { }; //! Schedules the copy operation of output of a Mma op which resided in the -//! shared memory to global memory. This assumes the outout of Mma in the -//! shared memory is of the form [M, N]. -//! This is tiled to [MO(1), NO(1), MI(m), NI(n)]. The inner two dims are -//! marked parallel type bulk. -void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n); +//! shared memory to global memory. +void scheduleTMAStoreForMmaOutput(TensorView* tv, MmaInputSmemSwizzle swizzle); +//! Schedules the copy operation of output of a Mma op which resided in the +//! registers to shared memory. void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n); @@ -482,6 +482,13 @@ inline void checkConcreteStaticDim(const AbstractId& abs_id) { id->toString()); } +//! Automatically generates the shared memory swizzled data layout for tma loads +//! in matmul mainloop. The shared memory data layout is always 2D currently. +//! This utility function assumes that the shared_mem_tv has the following +//! structure: [tile_row, tile_col] +//! Returns which swizzle format to use for mma inputs with tma loads. 
+MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv); + } // namespace mma_utils std::string toString(const mma_utils::AbstractMatmulTensor& abten); diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 87657e50997..00830bd9a37 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3773,17 +3773,23 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { tv3c->setLoopDomain(s.as()); tv3c->setAllocationDomain(s.as(), true); - // We'll use stmatrix.x4 to store from reg to shared memory - fusion.manage("st_matrix_m_tile", (int64_t)16); - fusion.manage("st_matrix_n_tile", (int64_t)16); + constexpr int64_t stmatrix_tile_m = 16; + constexpr int64_t stmatrix_tile_n = 16; + fusion.manage("st_matrix_m_tile", stmatrix_tile_m); + fusion.manage("st_matrix_n_tile", stmatrix_tile_n); fusion.manage("st_matrix_m", getM(macro)); fusion.manage("st_matrix_n", getN(macro)); + MmaInputSmemSwizzle store_swizzle = + mma_utils::tmaSwizzleSharedMemory(tv3_shmem); + // This internally calls - // mma_utils::MmaSwizzler::scheduleMmaOutputAllocation - mma_utils::scheduleStMatrixForMmaOutput(tv3_shmem, 16, 16); + // Schedule shared memory cache; Output from StMatrix + mma_utils::scheduleStMatrixForMmaOutput( + tv3_shmem, store_swizzle, stmatrix_tile_m, stmatrix_tile_n); - mma_utils::scheduleTMAStoreForMmaOutput(tv3, M, N); + // Schedule global memory output; Output from TMA Store + mma_utils::scheduleTMAStoreForMmaOutput(tv3, store_swizzle); } inlineMost(); diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 7679d52590a..ce359b41d32 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -2864,7 +2864,8 @@ TEST_P(StMatrixTest, Regular) { tv1->setLoopDomain(s.as()); tv1->setAllocationDomain(s.as(), true); - mma_utils::scheduleStMatrixForMmaOutput(tv2, tile_m, tile_n); + mma_utils::scheduleStMatrixForMmaOutput( + tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n); tv3->merge(0); tv3->split(0, 32); diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index 83cda79a43e..7e5ed33a8a6 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -531,9 +531,10 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) { tv2->axis(-3)->parallelize(ParallelType::Mma); } - mma_utils::scheduleStMatrixForMmaOutput(tv3, tile_m, tile_n); + MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3); + mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n); - mma_utils::scheduleTMAStoreForMmaOutput(tv4, getM(macro), getN(macro)); + mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle); auto inputs = matmulAtInput3DHopperRS( getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype)); From 214d598c853d5f76b28bd51a391b52bdfeabe585 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 10 Dec 2024 14:01:41 -0800 Subject: [PATCH 2/4] Very naive and stupid CGA support (#3557) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a super naive CGA support. It is by no means how we should design CGA, and not even an incremental step. But this PR is simple enough and it does provide us with an additional parameter to tune about. 
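For reference, the codegen change below simply prefixes the kernel name with a compile-time cluster-size attribute whenever the fusion manages a `cluster_dims` tuple. A minimal hand-written sketch of what that declaration looks like for `{2, 1, 1}` follows; the kernel name, parameter list, and body are placeholders, not actual nvFuser output:

```
// Sketch only: shows the __cluster_dims__ attribute this change emits when
// "cluster_dims" = {2, 1, 1} is managed on the kernel. Name, parameters, and
// body are placeholders; requires sm_90 (Hopper) to compile and launch.
__global__ void __cluster_dims__(2, 1, 1) example_kernel(float* out) {
  // Thread blocks are grouped into 2 x 1 x 1 clusters; this body does no
  // cluster-level cooperation and just writes one element per thread.
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = 0.0f;
}
```

With the compile-time attribute, the kernel can still be launched with the regular `<<<...>>>` configuration, provided the grid dimensions are divisible by the cluster dimensions.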
Perf on H100: ``` Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name -------- --------------- --------- -------- -------- -------- -------- ----------- ---------------------------------------------------------------------------------------------------- 33.4 134047 1 134047.0 134047.0 134047 134047 0.0 ::nvfuser_none_f0_c0_r0_g0(::Tensor<::__half, (int)3, (int)3>, … 22.9 92031 1 92031.0 92031.0 92031 92031 0.0 nvjet_hsh_128x256_64x4_2x1_v_bz_coopA_NTN ``` nvFuser/cuBLAS: 68.7% --- csrc/codegen.cpp | 11 ++++++++++- tests/cpp/test_matmul.cpp | 7 +++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 727d25beb53..0060e626fe6 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -273,7 +273,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Generates the kernel function declaration void genDeclaration(const std::string& kernel_name) { - code_ << "__global__ void " << kernel_name << "("; + code_ << "__global__ void "; + if (kernel_->hasManaged("cluster_dims")) { + auto cluster_dims = + kernel_->getManaged>( + "cluster_dims"); + code_ << "__cluster_dims__(" << std::get<0>(cluster_dims) << ", " + << std::get<1>(cluster_dims) << ", " << std::get<2>(cluster_dims) + << ") "; + } + code_ << kernel_name << "("; std::unordered_set unique_args; diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 00830bd9a37..3e9f7f553c8 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3664,6 +3664,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { const int64_t cta_m = 2 * getM(macro); const int64_t cta_n = 1 * getN(macro); + constexpr std::tuple cluster_dims{2, 1, 1}; + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); fusion.addInput(tv0); @@ -3679,6 +3681,11 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { auto tv3 = castOp(DataType::Half, tv2); fusion.addOutput(tv3); + if constexpr ( + cluster_dims != std::tuple{1, 1, 1}) { + fusion.manage("cluster_dims", cluster_dims); + } + auto mma_ops = ir_utils::getOpsOfType(&fusion); NVF_CHECK( 1 == mma_ops.size(), From 89c47f695b296eb4ffd27984bd4c953fc3f3264b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Dec 2024 14:46:53 -0800 Subject: [PATCH 3/4] WAR for supporting the rotation + residual pattern (#3555) Stacked on top of #3549. This is also a WAR for #3455 and necessary to schedule RoPE-like rotation patterns. Because of the issue, a tensor may have two IDs that are exactly mapped. For example, when an ID is sliced to half and then is padded back to the same size, and the final output ID is used with the initial input ID, the initial input and the final output IDs get mapped together. This can make it difficult to use `scheduleLoopDomainsLike`. For example, if a reference has a split that is done with the final output ID, and we want to replay the split on other tensors, it becomes ambiguous whether the split is done with the initial input or the final output since both are exactly mapped. To avoid this ambiguity, this PR adds a flag to indicate that we just want to update the current loop domain with a reference domain. As seen in the added tests, this flag is used to propagate the scheduling of a reference tensor once all resize ops are propagated to inputs. Specifically, the overall scheduling follows this pattern: 1. Propagate all slice and pad ops to fusion inputs 2. Pick and schedule a reference tensor 3. 
Propagate the scheduling of the reference tensor to the other tensors `scheduleLoopDomainsLike` with the flag is used at step 3. For that step, we know that we don't need to schedule each tensor with a complex replay path, like some backward ops followed by some other forward ops, but we just need to update the current loop domain by replaying the diff with the reference domain. --- .../scheduler/tools/loop_domain_scheduler.cpp | 44 +++- csrc/scheduler/tools/loop_domain_scheduler.h | 7 +- tests/cpp/test_resize.cpp | 193 ++++++++++++++++++ 3 files changed, 233 insertions(+), 11 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index e451f8073d7..f04a2f2271e 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -100,8 +100,11 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { class LoopDomainScheduler { public: - LoopDomainScheduler(std::vector ref_loop_dom) - : ref_loop_dom_(std::move(ref_loop_dom)) { + LoopDomainScheduler( + std::vector ref_loop_dom, + bool update_loop_domain_only = false) + : ref_loop_dom_(std::move(ref_loop_dom)), + update_loop_domain_only_(update_loop_domain_only) { NVF_ERROR(!ref_loop_dom_.empty()); // For now, ref must not be a broadcast domain @@ -174,6 +177,9 @@ class LoopDomainScheduler { private: std::vector ref_loop_dom_; + // If true, uses the current loop domain as the starting domain and + // updates it to make it look like the given reference loop domain + bool update_loop_domain_only_ = false; std::unique_ptr id_model_; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; @@ -188,9 +194,14 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { // All of the existing IDs are reused as much as possible to // minimize creating new IDs. - auto all_ids = tv->domain()->allIDs(); + std::unordered_map group_to_id; ValGroups all_id_groups; + // When update_mode_ is true, only the loop domain IDs are reused as + // we attempt to transform the current loop domain to look like the + // reference loop domain. + auto all_ids = + update_loop_domain_only_ ? tv->getLoopDomain() : tv->domain()->allIDs(); for (auto id : all_ids) { const auto& group = graph().toGroup(id); group_to_id.emplace(group, id); @@ -297,9 +308,10 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { // See LoopDomainSchedulingTest.ReshapeTraversalDirection for a // concrete example. ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { - // Find the path to the root domain of the tensor. It is important - // to use the root domain if available since there can be multiple - // forward paths to the logical domain in the ValGraph. For example, + // If not with the update mode, find the path to the root domain of + // the tensor. It is important to use the root domain if available since there + // can be multiple forward paths to the logical domain in the ValGraph. For + // example, // // t0 = [i0] // t1 = reshape(t0, {i0}, {i0/4, 4}) @@ -316,8 +328,12 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { // mean the t2 logical domain would have another definition (exactly mapped // with the t4 merge reshape). This issue can be avoided by using the root // domain of tv2 as the target of path finding. 
- ValGroups tv_target_domains = - graph().toGroups(TensorDomain::noBroadcasts(tv->getMaybeRootDomain())); + // + // In the case of the update mode, the target should be just the + // current loop domain of the tensor. + ValGroups tv_target_domains = graph().toGroups(TensorDomain::noBroadcasts( + update_loop_domain_only_ ? tv->getLoopDomain() + : tv->getMaybeRootDomain())); // If all the target domains are an ancestor of the reference // domains, just a single backward BFS should be enough to find a @@ -337,6 +353,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { .first; } + // In the case of the update mode, the path from the reference is + // assumed to just a backward traversal path. + NVF_ERROR( + !update_loop_domain_only_, + "Trying to update the current loop domain but could not find a valid path from the reference: ", + tv->toString()); + // Find the forward path from the ancestors to the target tensor auto forward_path = ValGraphBFS::getExprGroupsBetween( graph(), @@ -373,12 +396,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, - const std::vector& ref_loop_dom) { + const std::vector& ref_loop_dom, + bool update_loop_domain_only) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 6dd79240ed4..5939c9d31e2 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -20,9 +20,14 @@ namespace scheduler_tools { // Create the loop domain of given tensors as specified by the // reference. The new loop domain is connected to the existing IDs of // each tensor by replaying exprs found in the Exact ValGraph. +// +// If update_loop_domain_only is true, uses the current loop domain as +// the starting domain and updates it to make it look like the given +// reference loop domain. void scheduleLoopDomainsLike( const std::vector& tvs, - const std::vector& ref_loop_dom); + const std::vector& ref_loop_dom, + bool update_loop_domain_only = false); // Replay a transform expr on the loop domain of each of the given // tensors. 
If the input of the transform is exact mapped with the loop diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a99e1bed4da..7d9e357d18b 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4373,6 +4373,199 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } +// RoPE-like rotation patten +TEST_F(ResizeTest, SliceRotateCat) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = set(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv5 = cat({tv4, tv2}, 1); + + fusion.addOutput(tv5); + + // Propagate the left half of slice and pad + scheduler_tools::propagateResizeToInputs(tv2->definition()); + auto pad_left = + dynamic_cast(tv5->definition()->input(0)->definition()); + scheduler_tools::propagateResizeToInputs(pad_left); + + // Propagate the right half of slice and pad + scheduler_tools::propagateResizeToInputs(tv4->definition()); + auto pad_right = + dynamic_cast(tv5->definition()->input(1)->definition()); + scheduler_tools::propagateResizeToInputs(pad_right); + + auto ref_tv = tv5; + + // Fusion should have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + // Schedule the reference + ref_tv->flatten(); + // For TIDx + ref_tv->split(0, 128); + // For BIDx + ref_tv->split(0, 4); + + { + IdModel id_model(&fusion, false); + id_model.buildExactGraph(); + std::ofstream ofs("exact_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } + + scheduler_tools::scheduleLoopDomainsLike( + fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true); + + // Fusion should still have a uniform loop domain + checkLoopDomainEquivalence(ref_tv); + + inlineMost(); + + // All tensors, except for fusion inputs, should be fully inlined + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims()); + } + + ref_tv->axis(-1)->parallelize(ParallelType::TIDx); + ref_tv->axis(-2)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({16, 100}, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// RoPE-like rotation and residual patten +TEST_F(ResizeTest, SliceRotateCatResidual) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({-1, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = set(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), 
+// RoPE-like rotation and residual pattern
+TEST_F(ResizeTest, SliceRotateCatResidual) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  std::vector<int64_t> shape({-1, 100});
+
+  EnableOptionsGuard enable_options_guard;
+  EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
+
+  auto tv0 = makeConcreteTensor(shape);
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+
+  auto tv2 = slice(
+      tv1,
+      {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()},
+       {fusion.zeroVal(), IrBuilder::create<Val>(shape[1] / 2)}});
+
+  auto tv3 = set(tv0);
+
+  auto tv4 = slice(
+      tv3,
+      {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()},
+       {IrBuilder::create<Val>(shape[1] / 2),
+        IrBuilder::create<Val>(shape[1])}});
+
+  auto tv5 = cat({tv4, tv2}, 1);
+
+  auto tv6 = add(tv0, tv5);
+
+  fusion.addOutput(tv6);
+
+  // Propagate the left half of slice and pad
+  scheduler_tools::propagateResizeToInputs(tv2->definition());
+  auto pad_left =
+      dynamic_cast<PadOp*>(tv5->definition()->input(1)->definition());
+  scheduler_tools::propagateResizeToInputs(pad_left);
+
+  // Propagate the right half of slice and pad
+  scheduler_tools::propagateResizeToInputs(tv4->definition());
+  auto pad_right =
+      dynamic_cast<PadOp*>(tv5->definition()->input(0)->definition());
+  scheduler_tools::propagateResizeToInputs(pad_right);
+
+  auto ref_tv = tv6;
+
+  // Fusion should have a uniform loop domain
+  checkLoopDomainEquivalence(ref_tv);
+
+  // Schedule the reference
+  ref_tv->flatten();
+  // For TIDx
+  ref_tv->split(0, 128);
+  // For BIDx
+  ref_tv->split(0, 4);
+
+  {
+    IdModel id_model(&fusion, false);
+    id_model.buildExactGraph();
+    std::ofstream ofs("exact_graph.dot", std::ofstream::trunc);
+    auto dot_string =
+        id_model.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph();
+    ofs << dot_string;
+    ofs.close();
+  }
+
+  scheduler_tools::scheduleLoopDomainsLike(
+      fusion.allTvs(), ref_tv->getLoopDomain(), /*update_mode=*/true);
+
+  // Fusion should still have a uniform loop domain
+  checkLoopDomainEquivalence(ref_tv);
+
+  inlineMost();
+
+  // All tensors, except for fusion inputs, should be fully inlined
+  for (auto tv : fusion.allTvs()) {
+    if (tv->isFusionInput()) {
+      continue;
+    }
+    EXPECT_EQ(tv->getComputeAtPosition(), tv->nDims())
+        << "Invalid computeAt position of " << tv->toString();
+  }
+
+  ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
+  ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 100}, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
 // Consumer-based scheduling of pad
 TEST_F(ResizeTest, PropagatePadToInputs) {
   Fusion fusion;

From 8c82f30798105a95c20272a541c132612d828623 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Tue, 10 Dec 2024 15:59:53 -0800
Subject: [PATCH 4/4] Benchmark sequence parallelism in test_transformer_engine (#3546)

```
$ nvidia-smi -L
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

$ mpirun -np 8 --output-filename /tmp/test_transformer_engine pytest tests/python/test_transformer_engine.py --only-mpi

$ cat /tmp/test_transformer_engine/1/rank.0/stdout
------------------------------------------------------------------------------------------ benchmark: 4 tests ------------------------------------------------------------------------------------------
Name (time in ms)                          Min                Max               Mean              StdDev            Median               IQR            Outliers        OPS             Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_layer[sp-forward]      2.2564 (1.0)      55.7794 (11.73)    13.2931 (3.01)     23.7547 (125.77)    2.6707 (1.05)     14.1577 (88.73)        1;1       75.2268 (0.33)        5           1
test_transformer_layer[tp-forward]      2.3941 (1.06)     18.6497 (3.92)      6.7947 (1.54)      7.0469 (37.31)     2.5476 (1.0)       8.2456 (51.68)        1;0      147.1742 (0.65)        5           1
test_transformer_layer[tp-backward]     4.2568 (1.89)      4.8231 (1.01)      4.4578 (1.01)      0.2570 (1.36)      4.2940 (1.69)      0.4091 (2.56)         1;0      224.3258 (0.99)        5           1
test_transformer_layer[sp-backward]     4.3135 (1.91)      4.7558 (1.0)       4.4221 (1.0)       0.1889 (1.0)       4.3292 (1.70)      0.1596 (1.0)          1;1      226.1393 (1.0)         5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

Latency is neutral as expected.
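As an illustrative aside (not part of the patch), the difference between the two modes boils down to the per-rank input shape: tensor parallelism gives every rank the full sequence, while sequence parallelism shards the sequence dimension across ranks. The sketch below assumes a batch size of 1 and a 2048-token sequence purely for illustration; only `hidden_size` mirrors the GPT-3 hyperparameter in the diff.

```python
# Minimal sketch of the per-rank input shape under the two parallelism modes.
# batch_size and sequence_length below are assumed values, not taken from the test.
import torch

world_size = 8          # e.g., one rank per H100 in the 8-GPU run above
batch_size = 1          # assumed
sequence_length = 2048  # assumed
hidden_size = 12288     # GPT-3 hyperparameter from the diff

# Tensor parallelism: every rank receives the full sequence.
tp_shape = torch.Size([batch_size, sequence_length, hidden_size])

# Sequence parallelism: the sequence dimension is sharded across ranks,
# so each rank holds only sequence_length // world_size tokens.
assert sequence_length % world_size == 0
sp_shape = torch.Size([batch_size, sequence_length // world_size, hidden_size])

print(tp_shape)  # torch.Size([1, 2048, 12288])
print(sp_shape)  # torch.Size([1, 256, 12288])
```

This is why the test below asserts `sequence_length % size == 0` and sizes both the input and the expected output with `local_sequence_length`.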
---
 tests/python/test_transformer_engine.py | 28 ++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py
index de4734e6c90..00eb4b9eeb7 100644
--- a/tests/python/test_transformer_engine.py
+++ b/tests/python/test_transformer_engine.py
@@ -22,6 +22,13 @@ class ComputeType(Enum):
     BACKWARD = auto()
 
 
+class Parallelism(Enum):
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
+    TENSOR_PARALLEL = auto()
+    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
+    SEQUENCE_PARALLEL = auto()
+
+
 @pytest.fixture(scope="module")
 def setup_process_group(mpi_test) -> None:
     # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
@@ -47,7 +54,12 @@ def setup_process_group(mpi_test) -> None:
     [ComputeType.FORWARD, ComputeType.BACKWARD],
     ids=["forward", "backward"],
 )
-def test_transformer_layer(setup_process_group, benchmark, compute_type):
+@pytest.mark.parametrize(
+    "parallelism",
+    [Parallelism.TENSOR_PARALLEL, Parallelism.SEQUENCE_PARALLEL],
+    ids=["tp", "sp"],
+)
+def test_transformer_layer(setup_process_group, benchmark, compute_type, parallelism):
     # Hyperparameters for GPT-3
     hidden_size = 12288
     num_heads = 96
@@ -69,12 +81,20 @@ def test_transformer_layer(setup_process_group, benchmark, compute_type):
         # benchmark fails to execute on H100 with the default format (SBHD).
         attn_input_format="bshd",
         set_parallel_mode=True,
+        sequence_parallel=(parallelism == Parallelism.SEQUENCE_PARALLEL),
         tp_group=dist.group.WORLD,
     )
     transformer_layer.to(dtype).to("cuda")
 
+    match parallelism:
+        case Parallelism.TENSOR_PARALLEL:
+            local_sequence_length = sequence_length
+        case Parallelism.SEQUENCE_PARALLEL:
+            assert sequence_length % size == 0
+            local_sequence_length = sequence_length // size
+
     x = torch.randn(
-        batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda"
+        batch_size, local_sequence_length, hidden_size, dtype=dtype, device="cuda"
     )
 
     match compute_type:
@@ -93,7 +113,9 @@ def benchmark_fn(profile):
 
             # Warmup.
             y = benchmark_fn(False)
-            assert y.size() == torch.Size([batch_size, sequence_length, hidden_size])
+            assert y.size() == torch.Size(
+                [batch_size, local_sequence_length, hidden_size]
+            )
 
             benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
         case ComputeType.BACKWARD: