From 9c9c34ce683dd0fa6b72b61767f4b9233b76c223 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 26 Nov 2024 09:34:15 -0800 Subject: [PATCH] TMA: Don't merge dims when proven wrong (#3472) In our TMA support, there is a mechanism "Define box by compositing" which is used to reduce the dimensionality of TMA when possible. The reason for doing so is because our hardware has a limitation that the tensor dimension can not be more than 5D, and by collapsing some dimensions, we will be able to support tensors higher than 5D by viewing it as a tensor of less than 5D. While I do believe this mechanism makes great sense in general, mindlessly collapsing as much as possible can cause trouble. For example, if the box size is `(2, 256)`, collapsing it into `(512,)` is a mistake because our hardware has another limitation that the box size can not be larger than `256`. This PR makes the "Define box by compositing" mechanism smarter: instead of always collapsing, we check if collapsing can cause trouble. If so, we will not collapse. Note that this check can not guarantee 100% safety. For example, if the box size is a symbolic value `v1`, then there is no way to get its value in lowering, therefore, it is not possible to know if merging it with another dim can make the box size greater than 256. But I believe this is fine in practice, because in my knowledge, we are not interested in symbolic box size right now, so the box size is almost always constant. In the future, assume that we are moving towards analytical heuristics, that is, the box size is a `Val` that is a function of our input shapes, and represented by our IR node. In this world, I think dynamic concretization is necessary to make "Define box by compositing" bug-free. For example, let's say our box size is `(v1, v2, v3, v4, v5, v6)`, depending on the values of these variables, we might want to do different things: - `(v1, v2, v3, v4, v5, v6) = (256, 2, 128, 256, 256, 256)` --> merge `v2` with `v3` and use 5D TMA - `(v1, v2, v3, v4, v5, v6) = (256, 256, 2, 128, 256, 256)` --> merge `v3` with `v4` and use 5D TMA - `(v1, v2, v3, v4, v5, v6) = (256, 256, 256, 128, 256, 256)` --> raise an error, not supported --- csrc/device_lower/analysis/tma.cpp | 87 +++++++++++++++++++++++++----- tests/cpp/test_memory.cpp | 49 ++++++++++++++++- 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/csrc/device_lower/analysis/tma.cpp b/csrc/device_lower/analysis/tma.cpp index eb2463b3923..a13f7525664 100644 --- a/csrc/device_lower/analysis/tma.cpp +++ b/csrc/device_lower/analysis/tma.cpp @@ -800,16 +800,22 @@ class DomainMerger { std::unordered_set& bulk_groups_; std::unordered_set& nonbulk_groups_; std::list& dim_info_; + MmaInputSmemSwizzle swizzle_; + int64_t item_size_bytes_; public: DomainMerger( std::list> raw_tma_domain, std::unordered_set& bulk_groups, std::unordered_set& nonbulk_groups, - std::list& dim_info) + std::list& dim_info, + MmaInputSmemSwizzle swizzle, + int64_t item_size_bytes) : bulk_groups_(bulk_groups), nonbulk_groups_(nonbulk_groups), - dim_info_(dim_info) { + dim_info_(dim_info), + swizzle_(swizzle), + item_size_bytes_(item_size_bytes) { ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph(); contiguity_and_stride_.reserve(raw_tma_domain.size()); for (auto& item : raw_tma_domain) { @@ -868,6 +874,49 @@ class DomainMerger { return C; } + bool shouldMerge(int64_t i) { + auto type0 = type(i); + auto type1 = type(i + 1); + + bool may_increasing_box_size = (type0 == CB && type1 == CB); + if (!may_increasing_box_size) { + return true; + } + + auto extent0 = (*this)[i]->front()->as()->extent(); + auto extent1 = (*this)[i + 1]->front()->as()->extent(); + Val* merged_extent = SimplifyingIrBuilder::mulExpr(extent0, extent1); + + bool merging_innermost = ((int64_t)size() == i + 2); + + // If merging makes the size of a dimension larger than 256, we should not + // merge. + constexpr int64_t largest_dim_size = + 256; // Dimension size must be <= 256 as limited by hardware. + Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr( + merged_extent, IrBuilder::create(largest_dim_size)); + if (simplifyExpr(too_large_after_merge)->isTrue()) { + return false; + } + + // If merging makes the inner size larger than the swizzle size, + // we should not merge + if (merging_innermost && swizzle_ != MmaInputSmemSwizzle::None) { + const int64_t swizzle_size = + getBytesFromSwizzle(swizzle_) / item_size_bytes_; + Val* merging_makes_gt_swizzle_size = SimplifyingIrBuilder::gtExpr( + merged_extent, IrBuilder::create(swizzle_size)); + if (simplifyExpr(merging_makes_gt_swizzle_size)->isTrue()) { + return false; + } + } + + // Because the shape is dynamic, we don't know if we should merge or + // not. For this case, we always assume merging is better than not + // merging. + return true; + } + void merge(int64_t i) { auto type0 = type(i); auto type1 = type(i + 1); @@ -941,9 +990,15 @@ std::vector run( std::unordered_set& bulk_groups, std::unordered_set& nonbulk_groups, std::list& dim_info, - int64_t item_size_bytes) { + int64_t item_size_bytes, + MmaInputSmemSwizzle swizzle) { DomainMerger tma_domain( - std::move(raw_tma_domain), bulk_groups, nonbulk_groups, dim_info); + std::move(raw_tma_domain), + bulk_groups, + nonbulk_groups, + dim_info, + swizzle, + item_size_bytes); // merge contiguous C groups and CB groups for (int64_t i = 0; i < (int64_t)tma_domain.size() - 1; i++) { if (!tma_domain.contiguity(i)) { @@ -951,8 +1006,10 @@ std::vector run( } if ((tma_domain.type(i) == C && tma_domain.type(i + 1) == C) || (tma_domain.type(i) == CB && tma_domain.type(i + 1) == CB)) { - tma_domain.merge(i); - i--; + if (tma_domain.shouldMerge(i)) { + tma_domain.merge(i); + i--; + } } } // merge contiguous C with SB/CB @@ -962,8 +1019,10 @@ std::vector run( } if (tma_domain.type(i) == C && (tma_domain.type(i + 1) == SB || tma_domain.type(i + 1) == CB)) { - tma_domain.merge(i); - i--; + if (tma_domain.shouldMerge(i)) { + tma_domain.merge(i); + i--; + } } } @@ -1056,6 +1115,9 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) { "(this is always the case for nvFuser now)", ", the first element of elementStrides must be one."); + MmaInputSmemSwizzle swizzle = getSwizzleFromBytes( + getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes); + // Handle "defining box by compositing" by collapsing some dimensions in the // raw TMA domain to get the final TMA domain. auto final_tma_domain = collapse_tma_domain::run( @@ -1063,12 +1125,9 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) { bulk_groups, nonbulk_groups, inferred_dims, - dataTypeSize(gmem_tv->dtype())); - return TMAInfo( - std::move(final_tma_domain), - getSwizzleFromBytes( - getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes), - gmem_tv); + dataTypeSize(gmem_tv->dtype()), + swizzle); + return TMAInfo(std::move(final_tma_domain), swizzle, gmem_tv); } } // namespace diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 46e66c0ba9f..bd138a37045 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -881,7 +881,6 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) { } // Parallelize the tile axes tv1->axis(1)->parallelize(ParallelType::Bulk); - // tv2->axis(1)->parallelize(ParallelType::TIDx); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({32, 4, 2, 8, 8, 8, 2, 8, 4}, options); @@ -895,6 +894,54 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) { testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); } +TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({2, 256, 2, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + tv1->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulkTensorTile); + + // Use 1 thread and a single instruction to load the entire tensor to smem + for (auto id : tv1->getLoopDomain()) { + id->parallelize(ParallelType::Bulk); + } + + // Then use 32 threads to dump results out + tv2->axis(3)->parallelize(ParallelType::TIDx); + + // Schedule the allocation domain of tv1 to use 128B swizzle + AbstractTensor alloc1(tv1->getLoopDomain()); + alloc1.merge(0); + alloc1.merge(0); + // [1024, 32] + alloc1.split(1, 4); + alloc1.split(0, 8); + // [128, 8, 8, 4] + alloc1.swizzle(SwizzleType::XOR, 1, 2); + tv1->setAllocationDomain(alloc1.as(), true); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 256, 2, 32}, options); + KernelExecutor ke; + ke.compile(&fusion, {t0}, {}, matmul_cparams); + + // Because merging dims will violate hardware requirement, we do not merge + // dims. + EXPECT_EQ(TMADimChecker::getDim(ke.kernel()), 4); + + EXPECT_TRUE(PredicatedChecker::isPredicated(tv1, ke.kernel())); + + auto cg_outputs = ke.run({t0}); + testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); +} + TEST_F(TMAIndexingTest, DefineBoxByRotation1) { Fusion fusion; FusionGuard fg(&fusion);